mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
+54
-17
@@ -9,27 +9,13 @@ import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// EnsureSessionExists creates a session if it doesn't exist.
|
||||
// Uses INSERT OR IGNORE pattern for atomic idempotent creation (single query instead of COUNT + INSERT).
|
||||
// This is shared between stores to avoid duplication.
|
||||
func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project string) error {
|
||||
// Check if session exists
|
||||
var count int64
|
||||
err := db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("sdk_session_id = ?", sdkSessionID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil // Session exists
|
||||
}
|
||||
|
||||
// Auto-create session
|
||||
now := time.Now()
|
||||
session := &SDKSession{
|
||||
ClaudeSessionID: sdkSessionID,
|
||||
@@ -41,7 +27,14 @@ func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project
|
||||
PromptCounter: 0,
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Create(session).Error
|
||||
// Use INSERT OR IGNORE - single query, atomic operation
|
||||
// If session already exists (conflict on sdk_session_id), do nothing
|
||||
return db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "sdk_session_id"}},
|
||||
DoNothing: true,
|
||||
}).
|
||||
Create(session).Error
|
||||
}
|
||||
|
||||
// sqlNullString creates a sql.NullString from a string.
|
||||
@@ -52,8 +45,13 @@ func sqlNullString(s string) sql.NullString {
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// MaxPaginationLimit is the maximum allowed limit for pagination queries.
|
||||
// This protects against resource exhaustion from excessively large requests.
|
||||
const MaxPaginationLimit = 1000
|
||||
|
||||
// ParseLimitParam parses the "limit" query parameter from an HTTP request.
|
||||
// Returns defaultLimit if the parameter is missing or invalid.
|
||||
// Note: This does NOT enforce a maximum limit. Use ParseLimitParamWithMax for that.
|
||||
func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
@@ -62,3 +60,42 @@ func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
}
|
||||
return defaultLimit
|
||||
}
|
||||
|
||||
// ParseLimitParamWithMax parses the "limit" query parameter with a maximum cap.
|
||||
// Returns min(parsed, maxLimit) or defaultLimit if missing/invalid.
|
||||
// If maxLimit is 0, uses MaxPaginationLimit (1000).
|
||||
func ParseLimitParamWithMax(r *http.Request, defaultLimit, maxLimit int) int {
|
||||
if maxLimit <= 0 {
|
||||
maxLimit = MaxPaginationLimit
|
||||
}
|
||||
limit := ParseLimitParam(r, defaultLimit)
|
||||
if limit > maxLimit {
|
||||
return maxLimit
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// ParseOffsetParam parses the "offset" query parameter from an HTTP request.
|
||||
// Returns 0 if the parameter is missing or invalid.
|
||||
func ParseOffsetParam(r *http.Request) int {
|
||||
if o := r.URL.Query().Get("offset"); o != "" {
|
||||
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// PaginationParams holds pagination parameters.
|
||||
type PaginationParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ParsePaginationParams parses both limit and offset from an HTTP request.
|
||||
func ParsePaginationParams(r *http.Request, defaultLimit int) PaginationParams {
|
||||
return PaginationParams{
|
||||
Limit: ParseLimitParam(r, defaultLimit),
|
||||
Offset: ParseOffsetParam(r),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,6 +322,244 @@ func runMigrations(db *gorm.DB, sqlDB *sql.DB) error {
|
||||
return tx.Migrator().DropTable("observation_relations")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 012: Query optimization indexes
|
||||
// Adds covering and composite indexes for common query patterns
|
||||
{
|
||||
ID: "012_query_optimization_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Composite index for observation queries by project + scope + importance
|
||||
// Covers the common pattern: WHERE project = ? OR scope = 'global' ORDER BY importance_score DESC
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_importance
|
||||
ON observations(project, scope, importance_score DESC, created_at_epoch DESC)`,
|
||||
|
||||
// Covering index for observation retrieval (includes most used columns)
|
||||
// Allows index-only scans for listing queries
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_covering
|
||||
ON observations(project, scope, is_superseded, importance_score DESC)
|
||||
WHERE is_superseded = 0 OR is_superseded IS NULL`,
|
||||
|
||||
// Index for session summary lookups
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_project_importance
|
||||
ON session_summaries(project, importance_score DESC, created_at_epoch DESC)`,
|
||||
|
||||
// Index for prompt retrieval by session
|
||||
`CREATE INDEX IF NOT EXISTS idx_prompts_session_number
|
||||
ON user_prompts(claude_session_id, prompt_number)`,
|
||||
|
||||
// Index for pattern queries by frequency
|
||||
`CREATE INDEX IF NOT EXISTS idx_patterns_frequency
|
||||
ON patterns(frequency DESC, last_seen_at_epoch DESC)
|
||||
WHERE is_deprecated = 0`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist or fail for benign reasons
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_project_scope_importance",
|
||||
"DROP INDEX IF EXISTS idx_observations_project_covering",
|
||||
"DROP INDEX IF EXISTS idx_summaries_project_importance",
|
||||
"DROP INDEX IF EXISTS idx_prompts_session_number",
|
||||
"DROP INDEX IF EXISTS idx_patterns_frequency",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 013: Add archival columns to observations
|
||||
{
|
||||
ID: "013_observation_archival",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Add archival columns
|
||||
`ALTER TABLE observations ADD COLUMN is_archived INTEGER DEFAULT 0`,
|
||||
`ALTER TABLE observations ADD COLUMN archived_at_epoch INTEGER`,
|
||||
`ALTER TABLE observations ADD COLUMN archived_reason TEXT`,
|
||||
// Index for archived observations
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_archived ON observations(is_archived)`,
|
||||
// Composite index for filtering active (non-archived) observations
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_active
|
||||
ON observations(project, is_archived, is_superseded, importance_score DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: column may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
// SQLite doesn't support DROP COLUMN in older versions
|
||||
// but for newer versions we can try
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_active",
|
||||
"DROP INDEX IF EXISTS idx_observations_archived",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 014: Add performance-critical indexes for common query patterns
|
||||
{
|
||||
ID: "014_performance_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Index for batch ID lookups (IN queries)
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_id_covering
|
||||
ON observations(id, project, scope, importance_score)`,
|
||||
|
||||
// Index for vector search result fetching
|
||||
`CREATE INDEX IF NOT EXISTS idx_vectors_doc_type_project
|
||||
ON vectors(doc_type, project, scope)`,
|
||||
|
||||
// Index for session summaries by project
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_project_created
|
||||
ON session_summaries(project, created_at_epoch DESC)`,
|
||||
|
||||
// Index for user prompts by session
|
||||
`CREATE INDEX IF NOT EXISTS idx_prompts_session_created
|
||||
ON user_prompts(claude_session_id, created_at_epoch DESC)`,
|
||||
|
||||
// Index for patterns by type and project
|
||||
`CREATE INDEX IF NOT EXISTS idx_patterns_type_project
|
||||
ON patterns(type, project, frequency DESC)
|
||||
WHERE is_deprecated = 0`,
|
||||
|
||||
// Index for observation relations
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_source_type
|
||||
ON observation_relations(source_observation_id, relation_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_target_type
|
||||
ON observation_relations(target_observation_id, relation_type)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_id_covering",
|
||||
"DROP INDEX IF EXISTS idx_vectors_doc_type_project",
|
||||
"DROP INDEX IF EXISTS idx_summaries_project_created",
|
||||
"DROP INDEX IF EXISTS idx_prompts_session_created",
|
||||
"DROP INDEX IF EXISTS idx_patterns_type_project",
|
||||
"DROP INDEX IF EXISTS idx_relations_source_type",
|
||||
"DROP INDEX IF EXISTS idx_relations_target_type",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 015: Add optimized composite indexes for common query patterns
|
||||
{
|
||||
ID: "015_optimized_composite_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Composite index for GetRecentObservations with project+scope filtering
|
||||
// Covers: WHERE (project = ? OR scope = 'global') ORDER BY importance_score DESC, created_at_epoch DESC
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_created
|
||||
ON observations(project, scope, created_at_epoch DESC, importance_score DESC)`,
|
||||
|
||||
// Index for scope='global' queries (common pattern in search)
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_global_scope
|
||||
ON observations(scope, importance_score DESC, created_at_epoch DESC)
|
||||
WHERE scope = 'global'`,
|
||||
|
||||
// Index for vector search result deduplication by observation
|
||||
`CREATE INDEX IF NOT EXISTS idx_vectors_observation_lookup
|
||||
ON vectors(doc_type, sqlite_id, project)
|
||||
WHERE doc_type = 'observation'`,
|
||||
|
||||
// Index for FTS search result ordering
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_fts_ordering
|
||||
ON observations(project, importance_score DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_project_scope_created",
|
||||
"DROP INDEX IF EXISTS idx_observations_global_scope",
|
||||
"DROP INDEX IF EXISTS idx_vectors_observation_lookup",
|
||||
"DROP INDEX IF EXISTS idx_observations_fts_ordering",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 016: Add covering indexes for relation joins and active observations
|
||||
{
|
||||
ID: "016_relation_and_active_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Covering index for observation relation joins (common JOIN patterns)
|
||||
// Speeds up queries like: JOIN observation_relations ON source_id = obs.id WHERE relation_type = ?
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_source_type_target
|
||||
ON observation_relations(source_observation_id, relation_type, target_observation_id)`,
|
||||
|
||||
// Covering index for reverse relation lookups
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_target_type_source
|
||||
ON observation_relations(target_observation_id, relation_type, source_observation_id)`,
|
||||
|
||||
// Partial index for active (non-archived, non-superseded) observations
|
||||
// Optimizes activeObservationFilter queries
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_active
|
||||
ON observations(project, importance_score DESC, created_at_epoch DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_relations_source_type_target",
|
||||
"DROP INDEX IF EXISTS idx_relations_target_type_source",
|
||||
"DROP INDEX IF EXISTS idx_observations_active",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.Migrate(); err != nil {
|
||||
|
||||
@@ -45,30 +45,34 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
|
||||
}
|
||||
|
||||
// Observation represents a stored observation (learning).
|
||||
// Field order optimized for memory alignment (fieldalignment).
|
||||
type Observation struct {
|
||||
FileMtimes models.JSONInt64Map `gorm:"type:text"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Project string `gorm:"index:idx_observations_project;index:idx_observations_project_created,priority:1;not null"`
|
||||
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
|
||||
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
Title sql.NullString `gorm:"type:text"`
|
||||
Facts models.JSONStringArray `gorm:"type:text"`
|
||||
Narrative sql.NullString `gorm:"type:text"`
|
||||
Concepts models.JSONStringArray `gorm:"type:text"`
|
||||
FilesRead models.JSONStringArray `gorm:"type:text"`
|
||||
FilesModified models.JSONStringArray `gorm:"type:text"`
|
||||
Subtitle sql.NullString `gorm:"type:text"`
|
||||
Facts models.JSONStringArray `gorm:"type:text"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
PromptNumber sql.NullInt64
|
||||
Title sql.NullString `gorm:"type:text"`
|
||||
ArchivedReason sql.NullString
|
||||
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
|
||||
PromptNumber sql.NullInt64
|
||||
ArchivedAt sql.NullInt64 `gorm:"column:archived_at_epoch"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
|
||||
UserFeedback int `gorm:"default:0"`
|
||||
RetrievalCount int `gorm:"default:0"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;index:idx_observations_project_created,priority:2,sort:desc;not null"`
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded;index:idx_observations_active,priority:2"`
|
||||
IsArchived int `gorm:"default:0;index:idx_observations_archived;index:idx_observations_active,priority:1"`
|
||||
}
|
||||
|
||||
func (Observation) TableName() string { return "observations" }
|
||||
|
||||
@@ -5,39 +5,126 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MaxObservationsPerProject is the maximum number of observations to keep per project.
|
||||
const MaxObservationsPerProject = 100
|
||||
|
||||
// cleanupQueueSize is the buffer size for the cleanup queue.
|
||||
const cleanupQueueSize = 100
|
||||
|
||||
// commonWords is a package-level set for O(1) lookup of stop words.
|
||||
// Created once at init time to avoid repeated map allocations.
|
||||
var commonWords = map[string]struct{}{
|
||||
"the": {}, "and": {}, "or": {}, "but": {}, "in": {},
|
||||
"on": {}, "at": {}, "to": {}, "for": {}, "of": {},
|
||||
"with": {}, "by": {}, "from": {}, "as": {}, "is": {},
|
||||
"was": {}, "are": {}, "were": {}, "be": {}, "been": {},
|
||||
"being": {}, "have": {}, "has": {}, "had": {}, "do": {},
|
||||
"does": {}, "did": {}, "will": {}, "would": {}, "should": {},
|
||||
"could": {}, "may": {}, "might": {}, "must": {}, "can": {},
|
||||
}
|
||||
|
||||
// CleanupFunc is a callback for when observations are cleaned up.
|
||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// ObservationStore provides observation-related database operations using GORM.
|
||||
type ObservationStore struct {
|
||||
db *gorm.DB
|
||||
rawDB *sql.DB
|
||||
cleanupFunc CleanupFunc
|
||||
conflictStore interface{} // Placeholder for ConflictStore (Phase 4)
|
||||
relationStore interface{} // Placeholder for RelationStore (Phase 4)
|
||||
conflictStore any
|
||||
relationStore any
|
||||
db *gorm.DB
|
||||
rawDB *sql.DB
|
||||
cleanupFunc CleanupFunc
|
||||
cleanupQueue chan string
|
||||
stopCleanup chan struct{}
|
||||
cleanupWg sync.WaitGroup
|
||||
cleanupOnce sync.Once
|
||||
cleanupStarted atomic.Bool
|
||||
}
|
||||
|
||||
// NewObservationStore creates a new observation store.
|
||||
// The conflictStore and relationStore parameters are optional (can be nil) and will be used in Phase 4.
|
||||
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore interface{}) *ObservationStore {
|
||||
return &ObservationStore{
|
||||
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore any) *ObservationStore {
|
||||
s := &ObservationStore{
|
||||
db: store.DB,
|
||||
rawDB: store.GetRawDB(),
|
||||
cleanupFunc: cleanupFunc,
|
||||
conflictStore: conflictStore,
|
||||
relationStore: relationStore,
|
||||
cleanupQueue: make(chan string, cleanupQueueSize),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
// Start the cleanup worker
|
||||
s.startCleanupWorker()
|
||||
return s
|
||||
}
|
||||
|
||||
// startCleanupWorker starts the background cleanup worker.
|
||||
func (s *ObservationStore) startCleanupWorker() {
|
||||
s.cleanupOnce.Do(func() {
|
||||
s.cleanupStarted.Store(true)
|
||||
s.cleanupWg.Add(1)
|
||||
go s.cleanupWorker()
|
||||
})
|
||||
}
|
||||
|
||||
// cleanupWorker processes cleanup requests from the queue.
|
||||
func (s *ObservationStore) cleanupWorker() {
|
||||
defer s.cleanupWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCleanup:
|
||||
// Drain remaining items in queue before exiting
|
||||
for {
|
||||
select {
|
||||
case project := <-s.cleanupQueue:
|
||||
s.processCleanup(project)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case project := <-s.cleanupQueue:
|
||||
s.processCleanup(project)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processCleanup performs the actual cleanup for a project.
|
||||
func (s *ObservationStore) processCleanup(project string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
deletedIDs, err := s.CleanupOldObservations(ctx, project)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("project", project).Msg("Failed to cleanup old observations")
|
||||
return
|
||||
}
|
||||
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, deletedIDs)
|
||||
log.Debug().Str("project", project).Int("count", len(deletedIDs)).Msg("Cleaned up old observations")
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup worker and waits for it to finish.
|
||||
// Safe to call even if the worker was never started.
|
||||
func (s *ObservationStore) Close() {
|
||||
// Only stop if worker was actually started to avoid deadlock
|
||||
if s.cleanupStarted.Load() {
|
||||
close(s.stopCleanup)
|
||||
s.cleanupWg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,16 +173,15 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Cleanup old observations beyond the limit for this project (async to not block handler)
|
||||
// Queue cleanup of old observations beyond the limit for this project (async to not block handler)
|
||||
if project != "" {
|
||||
go func(proj string) {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}(project)
|
||||
select {
|
||||
case s.cleanupQueue <- project:
|
||||
// Successfully queued for cleanup
|
||||
default:
|
||||
// Queue is full, log a warning instead of silently dropping
|
||||
log.Warn().Str("project", project).Msg("Cleanup queue full, skipping cleanup for this observation")
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Conflict and relation detection intentionally omitted for now
|
||||
@@ -104,6 +190,89 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
||||
return dbObs.ID, nowEpoch, nil
|
||||
}
|
||||
|
||||
// ObservationUpdate contains fields that can be updated on an observation.
|
||||
// Only non-nil fields will be updated.
|
||||
type ObservationUpdate struct {
|
||||
Title *string // New title
|
||||
Subtitle *string // New subtitle
|
||||
Narrative *string // New narrative
|
||||
Facts *[]string // New facts (replaces existing)
|
||||
Concepts *[]string // New concepts (replaces existing)
|
||||
FilesRead *[]string // New files read (replaces existing)
|
||||
FilesModified *[]string // New files modified (replaces existing)
|
||||
Scope *string // New scope (project or global)
|
||||
}
|
||||
|
||||
// UpdateObservation updates an existing observation with the provided fields.
|
||||
// Only non-nil fields in the update struct will be modified.
|
||||
// Returns the updated observation or an error.
|
||||
func (s *ObservationStore) UpdateObservation(ctx context.Context, id int64, update *ObservationUpdate) (*models.Observation, error) {
|
||||
if update == nil {
|
||||
return nil, fmt.Errorf("update cannot be nil")
|
||||
}
|
||||
|
||||
// First, verify the observation exists
|
||||
var dbObs Observation
|
||||
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("observation not found: %d", id)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build update map with only provided fields
|
||||
updates := make(map[string]any)
|
||||
|
||||
if update.Title != nil {
|
||||
updates["title"] = sql.NullString{String: *update.Title, Valid: true}
|
||||
}
|
||||
if update.Subtitle != nil {
|
||||
updates["subtitle"] = sql.NullString{String: *update.Subtitle, Valid: true}
|
||||
}
|
||||
if update.Narrative != nil {
|
||||
updates["narrative"] = sql.NullString{String: *update.Narrative, Valid: true}
|
||||
}
|
||||
if update.Facts != nil {
|
||||
factsJSON, _ := json.Marshal(*update.Facts)
|
||||
updates["facts"] = string(factsJSON)
|
||||
}
|
||||
if update.Concepts != nil {
|
||||
conceptsJSON, _ := json.Marshal(*update.Concepts)
|
||||
updates["concepts"] = string(conceptsJSON)
|
||||
}
|
||||
if update.FilesRead != nil {
|
||||
filesReadJSON, _ := json.Marshal(*update.FilesRead)
|
||||
updates["files_read"] = string(filesReadJSON)
|
||||
}
|
||||
if update.FilesModified != nil {
|
||||
filesModifiedJSON, _ := json.Marshal(*update.FilesModified)
|
||||
updates["files_modified"] = string(filesModifiedJSON)
|
||||
}
|
||||
if update.Scope != nil {
|
||||
updates["scope"] = sql.NullString{String: *update.Scope, Valid: true}
|
||||
}
|
||||
|
||||
// Add updated_at timestamp
|
||||
updates["updated_at_epoch"] = sql.NullInt64{Int64: time.Now().Unix(), Valid: true}
|
||||
|
||||
if len(updates) == 0 {
|
||||
// Nothing to update, just return existing observation
|
||||
return toModelObservation(&dbObs), nil
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id).Updates(updates).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to update observation: %w", err)
|
||||
}
|
||||
|
||||
// Fetch the updated observation
|
||||
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservation(&dbObs), nil
|
||||
}
|
||||
|
||||
// GetObservationByID retrieves an observation by its ID.
|
||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||
var dbObs Observation
|
||||
@@ -134,6 +303,8 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
query = query.Order("created_at_epoch DESC")
|
||||
case "importance":
|
||||
query = query.Order("importance_score DESC, created_at_epoch DESC")
|
||||
case "score_desc":
|
||||
query = query.Order("importance_score DESC, created_at_epoch DESC")
|
||||
default:
|
||||
// Default: importance first, then recency
|
||||
query = query.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC")
|
||||
@@ -152,6 +323,60 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetObservationsByIDsPreserveOrder retrieves observations by IDs, preserving the input order.
|
||||
// This is useful when the caller has already sorted/ranked the IDs (e.g., by vector similarity).
|
||||
func (s *ObservationStore) GetObservationsByIDsPreserveOrder(ctx context.Context, ids []int64) ([]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Fetch all observations in a single query
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build ID -> observation map for O(1) lookups
|
||||
obsMap := make(map[int64]*Observation, len(dbObservations))
|
||||
for i := range dbObservations {
|
||||
obsMap[int64(dbObservations[i].ID)] = &dbObservations[i]
|
||||
}
|
||||
|
||||
// Reconstruct in original order
|
||||
result := make([]*models.Observation, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if obs, ok := obsMap[id]; ok {
|
||||
result = append(result, toModelObservation(obs))
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchGetObservationsWithScores retrieves observations with associated scores.
|
||||
// Returns a map of ID -> observation for efficient lookup.
|
||||
func (s *ObservationStore) BatchGetObservationsWithScores(ctx context.Context, ids []int64) (map[int64]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return make(map[int64]*models.Observation), nil
|
||||
}
|
||||
|
||||
// Fetch all observations in a single query
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build result map
|
||||
result := make(map[int64]*models.Observation, len(dbObservations))
|
||||
for i := range dbObservations {
|
||||
result[int64(dbObservations[i].ID)] = toModelObservation(&dbObservations[i])
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetRecentObservations retrieves recent observations for a project.
|
||||
// This includes project-scoped observations for the specified project AND global observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
@@ -169,13 +394,13 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||
// This excludes observations that have been marked as superseded by newer ones.
|
||||
// GetActiveObservations retrieves recent non-superseded, non-archived observations for a project.
|
||||
// This excludes observations that have been marked as superseded or archived.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(projectScopeFilter(project), notSupersededFilter(), importanceOrdering()).
|
||||
Scopes(projectScopeFilter(project), activeObservationFilter(), importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
@@ -245,7 +470,57 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetAllRecentObservationsPaginated retrieves recent observations with pagination.
|
||||
func (s *ObservationStore) GetAllRecentObservationsPaginated(ctx context.Context, limit, offset int) ([]*models.Observation, int64, error) {
|
||||
var dbObservations []Observation
|
||||
var total int64
|
||||
|
||||
// Get total count
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), total, nil
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrictPaginated retrieves observations strictly from a project with pagination.
|
||||
func (s *ObservationStore) GetObservationsByProjectStrictPaginated(ctx context.Context, project string, limit, offset int) ([]*models.Observation, int64, error) {
|
||||
var dbObservations []Observation
|
||||
var total int64
|
||||
|
||||
// Get total count for project
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("project = ?", project).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("project = ?", project).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), total, nil
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
// Note: For large datasets, prefer GetAllObservationsIterator to avoid memory issues.
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
@@ -259,6 +534,51 @@ func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Ob
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetAllObservationsIterator returns observations in batches to avoid loading all into memory.
|
||||
// The callback is called for each batch. Return false from callback to stop iteration.
|
||||
// batchSize controls how many observations are loaded at once (default 500 if <= 0).
|
||||
func (s *ObservationStore) GetAllObservationsIterator(ctx context.Context, batchSize int, callback func([]*models.Observation) bool) error {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 500
|
||||
}
|
||||
|
||||
var lastID int64 = 0
|
||||
for {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("id > ?", lastID).
|
||||
Order("id ASC").
|
||||
Limit(batchSize).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dbObservations) == 0 {
|
||||
break // No more observations
|
||||
}
|
||||
|
||||
// Update cursor for next batch
|
||||
lastID = dbObservations[len(dbObservations)-1].ID
|
||||
|
||||
// Convert and call callback
|
||||
observations := toModelObservations(dbObservations)
|
||||
if !callback(observations) {
|
||||
break // Callback requested stop
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations using FTS5.
|
||||
// Falls back to LIKE search if FTS5 fails.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
@@ -314,14 +634,26 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
||||
}
|
||||
|
||||
// searchObservationsLike performs fallback LIKE search on observations using GORM.
|
||||
// Limits to 2 keywords to prevent expensive OR queries that SQLite optimizes poorly.
|
||||
// This is a fallback path when FTS returns no results, so we prioritize performance.
|
||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Limit keywords to prevent excessive OR conditions that hurt query planning.
|
||||
// SQLite performs significantly better with fewer LIKE conditions.
|
||||
// Using 2 instead of 5 reduces query complexity from O(15) to O(6) conditions
|
||||
// (each keyword creates 3 LIKE conditions for title, subtitle, narrative).
|
||||
const maxKeywords = 2
|
||||
if len(keywords) > maxKeywords {
|
||||
keywords = keywords[:maxKeywords]
|
||||
}
|
||||
|
||||
// Build LIKE conditions for each keyword
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
// Pre-allocate for efficiency: maxKeywords conditions × 3 args each + 1 project arg
|
||||
conditions := make([]string, 0, len(keywords))
|
||||
args := make([]any, 0, len(keywords)*3+1)
|
||||
|
||||
for _, kw := range keywords {
|
||||
pattern := "%" + kw + "%"
|
||||
@@ -358,6 +690,217 @@ func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteObservation deletes a single observation by ID.
|
||||
func (s *ObservationStore) DeleteObservation(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).Delete(&Observation{}, id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsSuperseded marks an observation as superseded (stale).
|
||||
func (s *ObservationStore) MarkAsSuperseded(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_superseded", 1)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsSupersededBatch marks multiple observations as superseded in a single query.
|
||||
// Returns the number of observations updated and any error.
|
||||
func (s *ObservationStore) MarkAsSupersededBatch(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id IN ?", ids).
|
||||
Update("is_superseded", 1)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ArchiveObservation archives a single observation with an optional reason.
|
||||
func (s *ObservationStore) ArchiveObservation(ctx context.Context, id int64, reason string) error {
|
||||
updates := map[string]any{
|
||||
"is_archived": 1,
|
||||
"archived_at_epoch": time.Now().UnixMilli(),
|
||||
}
|
||||
if reason != "" {
|
||||
updates["archived_reason"] = reason
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnarchiveObservation restores an archived observation.
|
||||
func (s *ObservationStore) UnarchiveObservation(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"is_archived": 0,
|
||||
"archived_at_epoch": nil,
|
||||
"archived_reason": nil,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ArchiveOldObservations archives observations older than the specified age.
|
||||
// Returns the count of archived observations and their IDs.
|
||||
func (s *ObservationStore) ArchiveOldObservations(ctx context.Context, project string, maxAgeDays int, reason string) ([]int64, error) {
|
||||
if maxAgeDays <= 0 {
|
||||
maxAgeDays = 90 // Default: archive observations older than 90 days
|
||||
}
|
||||
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -maxAgeDays).UnixMilli()
|
||||
if reason == "" {
|
||||
reason = fmt.Sprintf("auto-archived: older than %d days", maxAgeDays)
|
||||
}
|
||||
|
||||
var idsToArchive []int64
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Find observations to archive (not already archived, older than cutoff)
|
||||
query := tx.Model(&Observation{}).
|
||||
Where("created_at_epoch < ?", cutoffEpoch).
|
||||
Where("COALESCE(is_archived, 0) = 0")
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
if err := query.Pluck("id", &idsToArchive).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToArchive) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Archive the observations
|
||||
now := time.Now().UnixMilli()
|
||||
return tx.Model(&Observation{}).
|
||||
Where("id IN ?", idsToArchive).
|
||||
Updates(map[string]any{
|
||||
"is_archived": 1,
|
||||
"archived_at_epoch": now,
|
||||
"archived_reason": reason,
|
||||
}).Error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return idsToArchive, nil
|
||||
}
|
||||
|
||||
// GetArchivedObservations retrieves archived observations for a project.
|
||||
func (s *ObservationStore) GetArchivedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
query := s.db.WithContext(ctx).
|
||||
Where("COALESCE(is_archived, 0) = 1")
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Order("archived_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetArchivalStats returns statistics about archived observations.
|
||||
// Optimized to use a single query instead of 4 separate queries.
|
||||
func (s *ObservationStore) GetArchivalStats(ctx context.Context, project string) (*ArchivalStats, error) {
|
||||
// Use a single query with conditional aggregation to get all stats at once.
|
||||
// This is much faster than 4 separate queries (saves 3 round trips).
|
||||
type statsResult struct {
|
||||
OldestEpoch *int64
|
||||
NewestEpoch *int64
|
||||
TotalCount int64
|
||||
ArchivedCount int64
|
||||
}
|
||||
|
||||
var result statsResult
|
||||
|
||||
query := s.db.WithContext(ctx).Model(&Observation{}).
|
||||
Select(`
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN COALESCE(is_archived, 0) = 1 THEN 1 ELSE 0 END) as archived_count,
|
||||
MIN(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as oldest_epoch,
|
||||
MAX(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as newest_epoch
|
||||
`)
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
if err := query.Scan(&result).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := &ArchivalStats{
|
||||
TotalCount: result.TotalCount,
|
||||
ArchivedCount: result.ArchivedCount,
|
||||
ActiveCount: result.TotalCount - result.ArchivedCount,
|
||||
}
|
||||
|
||||
if result.OldestEpoch != nil {
|
||||
stats.OldestArchivedEpoch = *result.OldestEpoch
|
||||
}
|
||||
if result.NewestEpoch != nil {
|
||||
stats.NewestArchivedEpoch = *result.NewestEpoch
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ArchivalStats contains statistics about archived observations.
|
||||
type ArchivalStats struct {
|
||||
TotalCount int64 `json:"total_count"`
|
||||
ActiveCount int64 `json:"active_count"`
|
||||
ArchivedCount int64 `json:"archived_count"`
|
||||
OldestArchivedEpoch int64 `json:"oldest_archived_epoch,omitempty"`
|
||||
NewestArchivedEpoch int64 `json:"newest_archived_epoch,omitempty"`
|
||||
}
|
||||
|
||||
// CleanupOldObservations removes observations beyond the limit for a project.
|
||||
// Returns the IDs of deleted observations.
|
||||
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
@@ -418,10 +961,12 @@ func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB {
|
||||
}
|
||||
}
|
||||
|
||||
// notSupersededFilter filters out superseded observations.
|
||||
func notSupersededFilter() func(*gorm.DB) *gorm.DB {
|
||||
// activeObservationFilter filters for active (non-archived, non-superseded) observations.
|
||||
// This is more efficient than chaining notSupersededFilter + notArchivedFilter
|
||||
// as it produces a single WHERE clause for the query optimizer.
|
||||
func activeObservationFilter() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("COALESCE(is_superseded, 0) = 0")
|
||||
return db.Where("COALESCE(is_archived, 0) = 0 AND COALESCE(is_superseded, 0) = 0")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -437,23 +982,17 @@ func importanceOrdering() func(*gorm.DB) *gorm.DB {
|
||||
// ====================
|
||||
|
||||
// extractKeywords extracts keywords from a search query.
|
||||
// Uses package-level commonWords map for O(1) stop word filtering.
|
||||
func extractKeywords(query string) []string {
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
var keywords []string
|
||||
|
||||
commonWords := map[string]bool{
|
||||
"the": true, "and": true, "or": true, "but": true, "in": true,
|
||||
"on": true, "at": true, "to": true, "for": true, "of": true,
|
||||
"with": true, "by": true, "from": true, "as": true, "is": true,
|
||||
"was": true, "are": true, "were": true, "be": true, "been": true,
|
||||
"being": true, "have": true, "has": true, "had": true, "do": true,
|
||||
"does": true, "did": true, "will": true, "would": true, "should": true,
|
||||
"could": true, "may": true, "might": true, "must": true, "can": true,
|
||||
}
|
||||
keywords := make([]string, 0, len(words)) // Pre-allocate for typical case
|
||||
|
||||
for _, word := range words {
|
||||
// Skip short words and common words
|
||||
if len(word) <= 3 || commonWords[word] {
|
||||
// Skip short words and common stop words
|
||||
if len(word) <= 3 {
|
||||
continue
|
||||
}
|
||||
if _, isCommon := commonWords[word]; isCommon {
|
||||
continue
|
||||
}
|
||||
keywords = append(keywords, word)
|
||||
@@ -464,7 +1003,8 @@ func extractKeywords(query string) []string {
|
||||
|
||||
// scanObservationRows scans multiple observations from raw SQL rows.
|
||||
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
var observations []*models.Observation
|
||||
// Pre-allocate with reasonable initial capacity to avoid repeated slice growth
|
||||
observations := make([]*models.Observation, 0, 64)
|
||||
for rows.Next() {
|
||||
obs, err := scanObservation(rows)
|
||||
if err != nil {
|
||||
@@ -476,7 +1016,7 @@ func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
}
|
||||
|
||||
// scanObservation scans a single observation from a row scanner.
|
||||
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
|
||||
func scanObservation(scanner interface{ Scan(...any) error }) (*models.Observation, error) {
|
||||
var obs models.Observation
|
||||
var factsJSON, conceptsJSON, filesReadJSON, filesModifiedJSON, fileMtimesJSON []byte
|
||||
var isSuperseded int
|
||||
|
||||
@@ -3,6 +3,8 @@ package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -59,12 +61,23 @@ func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||
// This is more efficient than individual updates for batch recalculation.
|
||||
// Uses a single SQL statement with CASE/WHEN for efficient batch updates.
|
||||
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
if len(scores) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For small batches, use simple individual updates
|
||||
if len(scores) <= 5 {
|
||||
return s.updateScoresIndividually(ctx, scores)
|
||||
}
|
||||
|
||||
// For larger batches, use CASE/WHEN SQL for single-query update
|
||||
return s.updateScoresBatch(ctx, scores)
|
||||
}
|
||||
|
||||
// updateScoresIndividually updates scores one at a time (efficient for small batches).
|
||||
func (s *ObservationStore) updateScoresIndividually(ctx context.Context, scores map[int64]float64) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
@@ -85,6 +98,36 @@ func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores ma
|
||||
})
|
||||
}
|
||||
|
||||
// updateScoresBatch updates multiple scores in a single SQL statement using CASE/WHEN.
|
||||
// This is much more efficient for large batches (O(1) queries instead of O(n)).
|
||||
func (s *ObservationStore) updateScoresBatch(ctx context.Context, scores map[int64]float64) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Build CASE/WHEN clause for importance_score
|
||||
// UPDATE observations SET
|
||||
// importance_score = CASE id WHEN 1 THEN 0.5 WHEN 2 THEN 0.8 ... END,
|
||||
// score_updated_at_epoch = ?
|
||||
// WHERE id IN (1, 2, ...)
|
||||
|
||||
ids := make([]int64, 0, len(scores))
|
||||
caseBuilder := strings.Builder{}
|
||||
caseBuilder.WriteString("CASE id ")
|
||||
|
||||
for id, score := range scores {
|
||||
ids = append(ids, id)
|
||||
caseBuilder.WriteString(fmt.Sprintf("WHEN %d THEN %f ", id, score))
|
||||
}
|
||||
caseBuilder.WriteString("END")
|
||||
|
||||
// Use raw SQL for the batch update
|
||||
sql := fmt.Sprintf(
|
||||
"UPDATE observations SET importance_score = %s, score_updated_at_epoch = ? WHERE id IN ?",
|
||||
caseBuilder.String(),
|
||||
)
|
||||
|
||||
return s.db.WithContext(ctx).Exec(sql, now, ids).Error
|
||||
}
|
||||
|
||||
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
|
||||
@@ -119,26 +119,44 @@ func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID st
|
||||
}
|
||||
|
||||
// IncrementPromptCounter increments the prompt counter and returns the new value.
|
||||
// Uses a single SQL query with RETURNING clause for optimal performance.
|
||||
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
// Atomic increment using GORM expression
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("id = ?", id).
|
||||
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
||||
// Use raw SQL with RETURNING to get updated value in single query
|
||||
// SQLite supports RETURNING since version 3.35.0 (2021-03-12)
|
||||
var newCounter int
|
||||
err := s.db.WithContext(ctx).Raw(`
|
||||
UPDATE sdk_sessions
|
||||
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
|
||||
WHERE id = ?
|
||||
RETURNING prompt_counter
|
||||
`, id).Scan(&newCounter).Error
|
||||
|
||||
if err != nil {
|
||||
// Fallback for older SQLite versions without RETURNING support
|
||||
if err.Error() == "near \"RETURNING\": syntax error" || newCounter == 0 {
|
||||
// Atomic increment
|
||||
updateErr := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("id = ?", id).
|
||||
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
||||
if updateErr != nil {
|
||||
return 0, updateErr
|
||||
}
|
||||
|
||||
// Fetch updated value
|
||||
var sess SDKSession
|
||||
fetchErr := s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if fetchErr != nil {
|
||||
return 0, fetchErr
|
||||
}
|
||||
return sess.PromptCounter, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Fetch updated value
|
||||
var sess SDKSession
|
||||
err = s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return sess.PromptCounter, nil
|
||||
return newCounter, nil
|
||||
}
|
||||
|
||||
// GetPromptCounter returns the current prompt counter for a session.
|
||||
|
||||
+426
-9
@@ -2,11 +2,16 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -14,8 +19,13 @@ import (
|
||||
|
||||
// Store represents the GORM database connection with sqlite-vec support.
|
||||
type Store struct {
|
||||
DB *gorm.DB
|
||||
sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL
|
||||
healthCacheTime time.Time
|
||||
DB *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
metrics *PoolMetrics
|
||||
cachedHealth *HealthInfo
|
||||
healthCacheTTL time.Duration
|
||||
healthCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config holds database configuration.
|
||||
@@ -71,8 +81,10 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
}
|
||||
|
||||
store := &Store{
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
metrics: NewPoolMetrics(100), // Track last 100 latency samples
|
||||
healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds
|
||||
}
|
||||
|
||||
// 7. Run migrations FIRST (before PRAGMA commands)
|
||||
@@ -80,13 +92,20 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
return nil, fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
// 8. CRITICAL: Set WAL mode and synchronous mode via raw SQL
|
||||
// 8. CRITICAL: Set WAL mode and other performance pragmas
|
||||
// Use raw sqlDB to avoid GORM transaction issues
|
||||
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
return nil, fmt.Errorf("set WAL mode: %w", err)
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
|
||||
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
|
||||
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
|
||||
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
|
||||
}
|
||||
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
|
||||
return nil, fmt.Errorf("set synchronous mode: %w", err)
|
||||
for _, pragma := range pragmas {
|
||||
if _, err := sqlDB.Exec(pragma); err != nil {
|
||||
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
|
||||
}
|
||||
}
|
||||
// Set busy timeout to 5 seconds to handle concurrent writes
|
||||
// This allows SQLite to retry when database is locked instead of failing immediately
|
||||
@@ -94,9 +113,40 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
return nil, fmt.Errorf("set busy timeout: %w", err)
|
||||
}
|
||||
|
||||
// 9. Warm the connection pool
|
||||
store.WarmPool(maxConns)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// WarmPool pre-creates connections to avoid cold start latency.
|
||||
func (s *Store) WarmPool(numConns int) {
|
||||
if numConns <= 0 {
|
||||
numConns = 4
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numConns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := s.sqlDB.Conn(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Execute a simple query to ensure the connection is fully initialized
|
||||
_ = conn.PingContext(ctx)
|
||||
// Return connection to pool (don't close it)
|
||||
_ = conn.Close()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
log.Debug().Int("connections", numConns).Msg("Connection pool warmed")
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (s *Store) Close() error {
|
||||
return s.sqlDB.Close()
|
||||
@@ -120,3 +170,370 @@ func (s *Store) GetRawDB() *sql.DB {
|
||||
func (s *Store) GetDB() *gorm.DB {
|
||||
return s.DB
|
||||
}
|
||||
|
||||
// Stats returns database connection pool statistics.
|
||||
func (s *Store) Stats() sql.DBStats {
|
||||
return s.sqlDB.Stats()
|
||||
}
|
||||
|
||||
// Optimize runs VACUUM and ANALYZE to optimize the database.
|
||||
// Should be called periodically (e.g., daily) during low activity.
|
||||
func (s *Store) Optimize(ctx context.Context) error {
|
||||
log.Info().Msg("Starting database optimization")
|
||||
start := time.Now()
|
||||
|
||||
// ANALYZE updates statistics for query optimizer
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "ANALYZE"); err != nil {
|
||||
return fmt.Errorf("analyze: %w", err)
|
||||
}
|
||||
|
||||
// PRAGMA optimize runs optimization based on query statistics
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA optimize"); err != nil {
|
||||
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
|
||||
}
|
||||
|
||||
log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck performs a comprehensive health check with latency measurement.
|
||||
// Returns detailed health information including connection pool stats and query latency.
|
||||
// Results are cached for healthCacheTTL (default 5 seconds) to reduce database load
|
||||
// from frequent monitoring calls.
|
||||
func (s *Store) HealthCheck(ctx context.Context) *HealthInfo {
|
||||
// Fast path: check cache with read lock
|
||||
s.healthCacheMu.RLock()
|
||||
if s.cachedHealth != nil && time.Since(s.healthCacheTime) < s.healthCacheTTL {
|
||||
cached := s.cachedHealth
|
||||
s.healthCacheMu.RUnlock()
|
||||
return cached
|
||||
}
|
||||
s.healthCacheMu.RUnlock()
|
||||
|
||||
// Slow path: perform actual health check
|
||||
info := s.performHealthCheck(ctx)
|
||||
|
||||
// Cache the result
|
||||
s.healthCacheMu.Lock()
|
||||
s.cachedHealth = info
|
||||
s.healthCacheTime = time.Now()
|
||||
s.healthCacheMu.Unlock()
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// HealthCheckForce performs a health check bypassing the cache.
|
||||
// Use this when you need real-time health data (e.g., debugging, alerting).
|
||||
func (s *Store) HealthCheckForce(ctx context.Context) *HealthInfo {
|
||||
info := s.performHealthCheck(ctx)
|
||||
|
||||
// Update the cache with fresh data
|
||||
s.healthCacheMu.Lock()
|
||||
s.cachedHealth = info
|
||||
s.healthCacheTime = time.Now()
|
||||
s.healthCacheMu.Unlock()
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// performHealthCheck does the actual health check work.
|
||||
func (s *Store) performHealthCheck(ctx context.Context) *HealthInfo {
|
||||
info := &HealthInfo{
|
||||
Status: "healthy",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Check pool stats
|
||||
stats := s.sqlDB.Stats()
|
||||
info.PoolStats = PoolStats{
|
||||
OpenConnections: stats.OpenConnections,
|
||||
InUse: stats.InUse,
|
||||
Idle: stats.Idle,
|
||||
WaitCount: stats.WaitCount,
|
||||
WaitDuration: stats.WaitDuration,
|
||||
MaxIdleClosed: stats.MaxIdleClosed,
|
||||
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||
}
|
||||
|
||||
// Record pool stats for metrics tracking
|
||||
if s.metrics != nil {
|
||||
s.metrics.RecordPoolStats(stats)
|
||||
}
|
||||
|
||||
// Measure query latency with a simple SELECT
|
||||
start := time.Now()
|
||||
var dummy int
|
||||
err := s.sqlDB.QueryRowContext(ctx, "SELECT 1").Scan(&dummy)
|
||||
info.QueryLatency = time.Since(start)
|
||||
|
||||
// Record latency for historical tracking
|
||||
if s.metrics != nil {
|
||||
s.metrics.RecordLatency(info.QueryLatency)
|
||||
info.HistoricalMetrics = s.metrics.GetMetricsSummary()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
info.Status = "unhealthy"
|
||||
info.Error = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
// Check for connection saturation (degraded if pool is heavily used)
|
||||
if stats.InUse > 0 && float64(stats.InUse)/float64(stats.OpenConnections) > 0.8 {
|
||||
info.Status = "degraded"
|
||||
info.Warning = "Connection pool heavily utilized"
|
||||
}
|
||||
|
||||
// Check for wait contention
|
||||
if stats.WaitCount > 100 && stats.WaitDuration > 100*time.Millisecond {
|
||||
info.Status = "degraded"
|
||||
info.Warning = "Connection pool contention detected"
|
||||
}
|
||||
|
||||
// Check query latency (warn if > 10ms for simple query)
|
||||
if info.QueryLatency > 10*time.Millisecond {
|
||||
if info.Status == "healthy" {
|
||||
info.Status = "degraded"
|
||||
}
|
||||
info.Warning = fmt.Sprintf("Slow query latency: %v", info.QueryLatency)
|
||||
}
|
||||
|
||||
// Check historical latency trend (degraded if P95 is high)
|
||||
if s.metrics != nil && info.HistoricalMetrics.P95Latency > 50*time.Millisecond {
|
||||
if info.Status == "healthy" {
|
||||
info.Status = "degraded"
|
||||
}
|
||||
info.Warning = fmt.Sprintf("High P95 latency: %v", info.HistoricalMetrics.P95Latency)
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// HealthInfo contains database health check results.
|
||||
type HealthInfo struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Warning string `json:"warning,omitempty"`
|
||||
HistoricalMetrics MetricsSummary `json:"historical_metrics,omitempty"`
|
||||
PoolStats PoolStats `json:"pool_stats"`
|
||||
QueryLatency time.Duration `json:"query_latency_ns"`
|
||||
}
|
||||
|
||||
// PoolStats contains connection pool statistics.
|
||||
type PoolStats struct {
|
||||
OpenConnections int `json:"open_connections"`
|
||||
InUse int `json:"in_use"`
|
||||
Idle int `json:"idle"`
|
||||
WaitCount int64 `json:"wait_count"`
|
||||
WaitDuration time.Duration `json:"wait_duration_ns"`
|
||||
MaxIdleClosed int64 `json:"max_idle_closed"`
|
||||
MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
|
||||
}
|
||||
|
||||
// QueryTimeout constants for different query types.
|
||||
const (
|
||||
// DefaultQueryTimeout is the default timeout for regular queries.
|
||||
DefaultQueryTimeout = 5 * time.Second
|
||||
// FastQueryTimeout is for queries that should be very fast (health checks, etc).
|
||||
FastQueryTimeout = 1 * time.Second
|
||||
// SlowQueryTimeout is for queries that may take longer (bulk operations, rebuilds).
|
||||
SlowQueryTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// PoolMetrics tracks historical connection pool metrics with a sliding window.
|
||||
type PoolMetrics struct {
|
||||
lastSampleTime time.Time
|
||||
latencySamples []time.Duration
|
||||
latencyIdx int
|
||||
latencyCount int
|
||||
totalQueries int64
|
||||
totalWaitTime time.Duration
|
||||
peakInUse int
|
||||
peakWaitCount int64
|
||||
windowSize int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolMetrics creates a new pool metrics collector with the given window size.
|
||||
func NewPoolMetrics(windowSize int) *PoolMetrics {
|
||||
if windowSize <= 0 {
|
||||
windowSize = 100 // Default: track last 100 samples
|
||||
}
|
||||
return &PoolMetrics{
|
||||
latencySamples: make([]time.Duration, windowSize),
|
||||
windowSize: windowSize,
|
||||
lastSampleTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordLatency records a query latency sample.
|
||||
func (m *PoolMetrics) RecordLatency(latency time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.latencySamples[m.latencyIdx] = latency
|
||||
m.latencyIdx = (m.latencyIdx + 1) % m.windowSize
|
||||
if m.latencyCount < m.windowSize {
|
||||
m.latencyCount++
|
||||
}
|
||||
m.totalQueries++
|
||||
m.lastSampleTime = time.Now()
|
||||
}
|
||||
|
||||
// RecordPoolStats records pool statistics for peak tracking.
|
||||
func (m *PoolMetrics) RecordPoolStats(stats sql.DBStats) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if stats.InUse > m.peakInUse {
|
||||
m.peakInUse = stats.InUse
|
||||
}
|
||||
if stats.WaitCount > m.peakWaitCount {
|
||||
m.peakWaitCount = stats.WaitCount
|
||||
}
|
||||
m.totalWaitTime += stats.WaitDuration
|
||||
}
|
||||
|
||||
// GetMetricsSummary returns a summary of collected metrics.
|
||||
func (m *PoolMetrics) GetMetricsSummary() MetricsSummary {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
summary := MetricsSummary{
|
||||
TotalQueries: m.totalQueries,
|
||||
SampleCount: m.latencyCount,
|
||||
PeakInUse: m.peakInUse,
|
||||
PeakWaitCount: m.peakWaitCount,
|
||||
TotalWaitTime: m.totalWaitTime,
|
||||
LastSampleTime: m.lastSampleTime,
|
||||
}
|
||||
|
||||
if m.latencyCount == 0 {
|
||||
return summary
|
||||
}
|
||||
|
||||
// Calculate latency statistics
|
||||
var total time.Duration
|
||||
var min, max time.Duration = m.latencySamples[0], m.latencySamples[0]
|
||||
|
||||
for i := 0; i < m.latencyCount; i++ {
|
||||
sample := m.latencySamples[i]
|
||||
total += sample
|
||||
if sample < min {
|
||||
min = sample
|
||||
}
|
||||
if sample > max {
|
||||
max = sample
|
||||
}
|
||||
}
|
||||
|
||||
summary.AvgLatency = total / time.Duration(m.latencyCount)
|
||||
summary.MinLatency = min
|
||||
summary.MaxLatency = max
|
||||
|
||||
// Calculate P95 latency (approximate using sorted samples)
|
||||
if m.latencyCount >= 20 {
|
||||
// Copy samples for sorting
|
||||
samples := make([]time.Duration, m.latencyCount)
|
||||
copy(samples, m.latencySamples[:m.latencyCount])
|
||||
// Use slices.Sort for O(n log n) instead of O(n²) insertion sort
|
||||
slices.Sort(samples)
|
||||
p95Idx := int(float64(len(samples)) * 0.95)
|
||||
summary.P95Latency = samples[p95Idx]
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
// MetricsSummary contains aggregated pool metrics.
|
||||
type MetricsSummary struct {
|
||||
LastSampleTime time.Time `json:"last_sample_time"`
|
||||
TotalQueries int64 `json:"total_queries"`
|
||||
SampleCount int `json:"sample_count"`
|
||||
AvgLatency time.Duration `json:"avg_latency_ns"`
|
||||
MinLatency time.Duration `json:"min_latency_ns"`
|
||||
MaxLatency time.Duration `json:"max_latency_ns"`
|
||||
P95Latency time.Duration `json:"p95_latency_ns,omitempty"`
|
||||
PeakInUse int `json:"peak_in_use"`
|
||||
PeakWaitCount int64 `json:"peak_wait_count"`
|
||||
TotalWaitTime time.Duration `json:"total_wait_time_ns"`
|
||||
}
|
||||
|
||||
// GetMetrics returns the current metrics without performing a health check.
|
||||
func (s *Store) GetMetrics() MetricsSummary {
|
||||
if s.metrics == nil {
|
||||
return MetricsSummary{}
|
||||
}
|
||||
return s.metrics.GetMetricsSummary()
|
||||
}
|
||||
|
||||
// ResetMetrics resets the metrics collector (useful for testing or after major changes).
|
||||
func (s *Store) ResetMetrics() {
|
||||
if s.metrics != nil {
|
||||
s.metrics = NewPoolMetrics(s.metrics.windowSize)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout wraps a context with the given timeout and logs slow queries.
|
||||
// Returns the wrapped context and a cancel function that should be called when done.
|
||||
func (s *Store) WithTimeout(ctx context.Context, timeout time.Duration, operation string) (context.Context, context.CancelFunc) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
start := time.Now()
|
||||
|
||||
// Return wrapped cancel that logs if query was slow
|
||||
return timeoutCtx, func() {
|
||||
elapsed := time.Since(start)
|
||||
cancel()
|
||||
|
||||
// Log slow queries (> 100ms)
|
||||
if elapsed > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Str("operation", operation).
|
||||
Dur("elapsed", elapsed).
|
||||
Dur("timeout", timeout).
|
||||
Msg("Slow database operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExecWithTimeout executes a raw SQL query with timeout.
|
||||
// Returns error if query takes longer than timeout.
|
||||
func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) error {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "exec")
|
||||
defer cancel()
|
||||
|
||||
_, err := s.sqlDB.ExecContext(timeoutCtx, query, args...)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
return fmt.Errorf("query timeout after %v: %s", timeout, query)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryRowWithTimeout executes a row query with timeout.
|
||||
func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row")
|
||||
// Note: cancel will be called when row.Scan() completes or errors
|
||||
_ = cancel // Caller must ensure proper cleanup
|
||||
return s.sqlDB.QueryRowContext(timeoutCtx, query, args...)
|
||||
}
|
||||
|
||||
// TransactionWithTimeout wraps a transaction function with timeout handling.
|
||||
// The transaction is automatically rolled back if the context times out.
|
||||
func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "transaction")
|
||||
defer cancel()
|
||||
|
||||
return s.DB.WithContext(timeoutCtx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check context before proceeding
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return timeoutCtx.Err()
|
||||
default:
|
||||
}
|
||||
return fn(tx)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user