mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
d04b60517a
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
540 lines
16 KiB
Go
540 lines
16 KiB
Go
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
|
package gorm
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
|
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
|
|
"github.com/rs/zerolog/log"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// Store represents the GORM database connection with sqlite-vec support.
|
|
type Store struct {
|
|
healthCacheTime time.Time
|
|
DB *gorm.DB
|
|
sqlDB *sql.DB
|
|
metrics *PoolMetrics
|
|
cachedHealth *HealthInfo
|
|
healthCacheTTL time.Duration
|
|
healthCacheMu sync.RWMutex
|
|
}
|
|
|
|
// Config holds database configuration.
|
|
type Config struct {
|
|
Path string // Path to SQLite database file
|
|
MaxConns int // Maximum number of open connections (default: 4)
|
|
LogLevel logger.LogLevel // GORM log level (logger.Silent for production)
|
|
}
|
|
|
|
// NewStore creates a new Store with WAL mode enabled and sqlite-vec registered.
|
|
// CRITICAL: WAL mode and foreign keys are enabled via pragmas for concurrent reads.
|
|
func NewStore(cfg Config) (*Store, error) {
|
|
// 1. Register sqlite-vec extension (must be done before opening database)
|
|
sqlite_vec.Auto()
|
|
|
|
// 2. Build connection string (foreign keys enabled in DSN)
|
|
// Use sqlite3 driver (mattn/go-sqlite3) which has FTS5 support
|
|
dsn := cfg.Path + "?_foreign_keys=ON"
|
|
|
|
// 3. Open raw database connection with mattn/go-sqlite3 (has FTS5 support)
|
|
sqlDB, err := sql.Open("sqlite3", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open database: %w", err)
|
|
}
|
|
|
|
// 4. Wrap with GORM using existing connection
|
|
db, err := gorm.Open(sqlite.Dialector{
|
|
Conn: sqlDB,
|
|
}, &gorm.Config{
|
|
Logger: logger.Default.LogMode(cfg.LogLevel),
|
|
// PrepareStmt enables prepared statement caching for performance
|
|
PrepareStmt: true,
|
|
// Disable default timestamp fields (we manage created_at manually)
|
|
NowFunc: nil,
|
|
})
|
|
if err != nil {
|
|
_ = sqlDB.Close() // Explicitly ignore close error during cleanup
|
|
return nil, fmt.Errorf("open gorm: %w", err)
|
|
}
|
|
|
|
// 5. Configure connection pool (same settings as current implementation)
|
|
maxConns := cfg.MaxConns
|
|
if maxConns <= 0 {
|
|
maxConns = 4
|
|
}
|
|
sqlDB.SetMaxOpenConns(maxConns)
|
|
sqlDB.SetMaxIdleConns(maxConns)
|
|
sqlDB.SetConnMaxLifetime(0) // Never expire (SQLite connections are cheap)
|
|
|
|
// 6. Verify connection
|
|
if err := sqlDB.Ping(); err != nil {
|
|
return nil, fmt.Errorf("ping database: %w", err)
|
|
}
|
|
|
|
store := &Store{
|
|
DB: db,
|
|
sqlDB: sqlDB,
|
|
metrics: NewPoolMetrics(100), // Track last 100 latency samples
|
|
healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds
|
|
}
|
|
|
|
// 7. Run migrations FIRST (before PRAGMA commands)
|
|
if err := runMigrations(db, sqlDB); err != nil {
|
|
return nil, fmt.Errorf("run migrations: %w", err)
|
|
}
|
|
|
|
// 8. CRITICAL: Set WAL mode and other performance pragmas
|
|
// Use raw sqlDB to avoid GORM transaction issues
|
|
pragmas := []string{
|
|
"PRAGMA journal_mode=WAL",
|
|
"PRAGMA synchronous=NORMAL",
|
|
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
|
|
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
|
|
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
|
|
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
|
|
}
|
|
for _, pragma := range pragmas {
|
|
if _, err := sqlDB.Exec(pragma); err != nil {
|
|
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
|
|
}
|
|
}
|
|
// Set busy timeout to 5 seconds to handle concurrent writes
|
|
// This allows SQLite to retry when database is locked instead of failing immediately
|
|
if _, err := sqlDB.Exec("PRAGMA busy_timeout=5000"); err != nil {
|
|
return nil, fmt.Errorf("set busy timeout: %w", err)
|
|
}
|
|
|
|
// 9. Warm the connection pool
|
|
store.WarmPool(maxConns)
|
|
|
|
return store, nil
|
|
}
|
|
|
|
// WarmPool pre-creates connections to avoid cold start latency.
|
|
func (s *Store) WarmPool(numConns int) {
|
|
if numConns <= 0 {
|
|
numConns = 4
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < numConns; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := s.sqlDB.Conn(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
// Execute a simple query to ensure the connection is fully initialized
|
|
_ = conn.PingContext(ctx)
|
|
// Return connection to pool (don't close it)
|
|
_ = conn.Close()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
log.Debug().Int("connections", numConns).Msg("Connection pool warmed")
|
|
}
|
|
|
|
// Close closes the database connection.
|
|
func (s *Store) Close() error {
|
|
return s.sqlDB.Close()
|
|
}
|
|
|
|
// Ping verifies the database connection is alive.
|
|
func (s *Store) Ping() error {
|
|
return s.sqlDB.Ping()
|
|
}
|
|
|
|
// GetRawDB returns the underlying *sql.DB for operations GORM can't handle.
|
|
// Use this for:
|
|
// - FTS5 full-text search queries (MATCH operator)
|
|
// - sqlite-vec vector operations
|
|
// - Complex raw SQL queries
|
|
func (s *Store) GetRawDB() *sql.DB {
|
|
return s.sqlDB
|
|
}
|
|
|
|
// GetDB returns the GORM DB instance for standard queries.
|
|
func (s *Store) GetDB() *gorm.DB {
|
|
return s.DB
|
|
}
|
|
|
|
// Stats returns database connection pool statistics.
|
|
func (s *Store) Stats() sql.DBStats {
|
|
return s.sqlDB.Stats()
|
|
}
|
|
|
|
// Optimize runs VACUUM and ANALYZE to optimize the database.
|
|
// Should be called periodically (e.g., daily) during low activity.
|
|
func (s *Store) Optimize(ctx context.Context) error {
|
|
log.Info().Msg("Starting database optimization")
|
|
start := time.Now()
|
|
|
|
// ANALYZE updates statistics for query optimizer
|
|
if _, err := s.sqlDB.ExecContext(ctx, "ANALYZE"); err != nil {
|
|
return fmt.Errorf("analyze: %w", err)
|
|
}
|
|
|
|
// PRAGMA optimize runs optimization based on query statistics
|
|
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA optimize"); err != nil {
|
|
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
|
|
}
|
|
|
|
log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete")
|
|
return nil
|
|
}
|
|
|
|
// HealthCheck performs a comprehensive health check with latency measurement.
|
|
// Returns detailed health information including connection pool stats and query latency.
|
|
// Results are cached for healthCacheTTL (default 5 seconds) to reduce database load
|
|
// from frequent monitoring calls.
|
|
func (s *Store) HealthCheck(ctx context.Context) *HealthInfo {
|
|
// Fast path: check cache with read lock
|
|
s.healthCacheMu.RLock()
|
|
if s.cachedHealth != nil && time.Since(s.healthCacheTime) < s.healthCacheTTL {
|
|
cached := s.cachedHealth
|
|
s.healthCacheMu.RUnlock()
|
|
return cached
|
|
}
|
|
s.healthCacheMu.RUnlock()
|
|
|
|
// Slow path: perform actual health check
|
|
info := s.performHealthCheck(ctx)
|
|
|
|
// Cache the result
|
|
s.healthCacheMu.Lock()
|
|
s.cachedHealth = info
|
|
s.healthCacheTime = time.Now()
|
|
s.healthCacheMu.Unlock()
|
|
|
|
return info
|
|
}
|
|
|
|
// HealthCheckForce performs a health check bypassing the cache.
|
|
// Use this when you need real-time health data (e.g., debugging, alerting).
|
|
func (s *Store) HealthCheckForce(ctx context.Context) *HealthInfo {
|
|
info := s.performHealthCheck(ctx)
|
|
|
|
// Update the cache with fresh data
|
|
s.healthCacheMu.Lock()
|
|
s.cachedHealth = info
|
|
s.healthCacheTime = time.Now()
|
|
s.healthCacheMu.Unlock()
|
|
|
|
return info
|
|
}
|
|
|
|
// performHealthCheck does the actual health check work.
|
|
func (s *Store) performHealthCheck(ctx context.Context) *HealthInfo {
|
|
info := &HealthInfo{
|
|
Status: "healthy",
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
// Check pool stats
|
|
stats := s.sqlDB.Stats()
|
|
info.PoolStats = PoolStats{
|
|
OpenConnections: stats.OpenConnections,
|
|
InUse: stats.InUse,
|
|
Idle: stats.Idle,
|
|
WaitCount: stats.WaitCount,
|
|
WaitDuration: stats.WaitDuration,
|
|
MaxIdleClosed: stats.MaxIdleClosed,
|
|
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
|
}
|
|
|
|
// Record pool stats for metrics tracking
|
|
if s.metrics != nil {
|
|
s.metrics.RecordPoolStats(stats)
|
|
}
|
|
|
|
// Measure query latency with a simple SELECT
|
|
start := time.Now()
|
|
var dummy int
|
|
err := s.sqlDB.QueryRowContext(ctx, "SELECT 1").Scan(&dummy)
|
|
info.QueryLatency = time.Since(start)
|
|
|
|
// Record latency for historical tracking
|
|
if s.metrics != nil {
|
|
s.metrics.RecordLatency(info.QueryLatency)
|
|
info.HistoricalMetrics = s.metrics.GetMetricsSummary()
|
|
}
|
|
|
|
if err != nil {
|
|
info.Status = "unhealthy"
|
|
info.Error = err.Error()
|
|
return info
|
|
}
|
|
|
|
// Check for connection saturation (degraded if pool is heavily used)
|
|
if stats.InUse > 0 && float64(stats.InUse)/float64(stats.OpenConnections) > 0.8 {
|
|
info.Status = "degraded"
|
|
info.Warning = "Connection pool heavily utilized"
|
|
}
|
|
|
|
// Check for wait contention
|
|
if stats.WaitCount > 100 && stats.WaitDuration > 100*time.Millisecond {
|
|
info.Status = "degraded"
|
|
info.Warning = "Connection pool contention detected"
|
|
}
|
|
|
|
// Check query latency (warn if > 10ms for simple query)
|
|
if info.QueryLatency > 10*time.Millisecond {
|
|
if info.Status == "healthy" {
|
|
info.Status = "degraded"
|
|
}
|
|
info.Warning = fmt.Sprintf("Slow query latency: %v", info.QueryLatency)
|
|
}
|
|
|
|
// Check historical latency trend (degraded if P95 is high)
|
|
if s.metrics != nil && info.HistoricalMetrics.P95Latency > 50*time.Millisecond {
|
|
if info.Status == "healthy" {
|
|
info.Status = "degraded"
|
|
}
|
|
info.Warning = fmt.Sprintf("High P95 latency: %v", info.HistoricalMetrics.P95Latency)
|
|
}
|
|
|
|
return info
|
|
}
|
|
|
|
// HealthInfo contains database health check results.
|
|
type HealthInfo struct {
|
|
Timestamp time.Time `json:"timestamp"`
|
|
Status string `json:"status"`
|
|
Error string `json:"error,omitempty"`
|
|
Warning string `json:"warning,omitempty"`
|
|
HistoricalMetrics MetricsSummary `json:"historical_metrics,omitempty"`
|
|
PoolStats PoolStats `json:"pool_stats"`
|
|
QueryLatency time.Duration `json:"query_latency_ns"`
|
|
}
|
|
|
|
// PoolStats contains connection pool statistics.
|
|
type PoolStats struct {
|
|
OpenConnections int `json:"open_connections"`
|
|
InUse int `json:"in_use"`
|
|
Idle int `json:"idle"`
|
|
WaitCount int64 `json:"wait_count"`
|
|
WaitDuration time.Duration `json:"wait_duration_ns"`
|
|
MaxIdleClosed int64 `json:"max_idle_closed"`
|
|
MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
|
|
}
|
|
|
|
// QueryTimeout constants for different query types.
|
|
const (
|
|
// DefaultQueryTimeout is the default timeout for regular queries.
|
|
DefaultQueryTimeout = 5 * time.Second
|
|
// FastQueryTimeout is for queries that should be very fast (health checks, etc).
|
|
FastQueryTimeout = 1 * time.Second
|
|
// SlowQueryTimeout is for queries that may take longer (bulk operations, rebuilds).
|
|
SlowQueryTimeout = 30 * time.Second
|
|
)
|
|
|
|
// PoolMetrics tracks historical connection pool metrics with a sliding window.
|
|
type PoolMetrics struct {
|
|
lastSampleTime time.Time
|
|
latencySamples []time.Duration
|
|
latencyIdx int
|
|
latencyCount int
|
|
totalQueries int64
|
|
totalWaitTime time.Duration
|
|
peakInUse int
|
|
peakWaitCount int64
|
|
windowSize int
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewPoolMetrics creates a new pool metrics collector with the given window size.
|
|
func NewPoolMetrics(windowSize int) *PoolMetrics {
|
|
if windowSize <= 0 {
|
|
windowSize = 100 // Default: track last 100 samples
|
|
}
|
|
return &PoolMetrics{
|
|
latencySamples: make([]time.Duration, windowSize),
|
|
windowSize: windowSize,
|
|
lastSampleTime: time.Now(),
|
|
}
|
|
}
|
|
|
|
// RecordLatency records a query latency sample.
|
|
func (m *PoolMetrics) RecordLatency(latency time.Duration) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.latencySamples[m.latencyIdx] = latency
|
|
m.latencyIdx = (m.latencyIdx + 1) % m.windowSize
|
|
if m.latencyCount < m.windowSize {
|
|
m.latencyCount++
|
|
}
|
|
m.totalQueries++
|
|
m.lastSampleTime = time.Now()
|
|
}
|
|
|
|
// RecordPoolStats records pool statistics for peak tracking.
|
|
func (m *PoolMetrics) RecordPoolStats(stats sql.DBStats) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if stats.InUse > m.peakInUse {
|
|
m.peakInUse = stats.InUse
|
|
}
|
|
if stats.WaitCount > m.peakWaitCount {
|
|
m.peakWaitCount = stats.WaitCount
|
|
}
|
|
m.totalWaitTime += stats.WaitDuration
|
|
}
|
|
|
|
// GetMetricsSummary returns a summary of collected metrics.
|
|
func (m *PoolMetrics) GetMetricsSummary() MetricsSummary {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
summary := MetricsSummary{
|
|
TotalQueries: m.totalQueries,
|
|
SampleCount: m.latencyCount,
|
|
PeakInUse: m.peakInUse,
|
|
PeakWaitCount: m.peakWaitCount,
|
|
TotalWaitTime: m.totalWaitTime,
|
|
LastSampleTime: m.lastSampleTime,
|
|
}
|
|
|
|
if m.latencyCount == 0 {
|
|
return summary
|
|
}
|
|
|
|
// Calculate latency statistics
|
|
var total time.Duration
|
|
var min, max time.Duration = m.latencySamples[0], m.latencySamples[0]
|
|
|
|
for i := 0; i < m.latencyCount; i++ {
|
|
sample := m.latencySamples[i]
|
|
total += sample
|
|
if sample < min {
|
|
min = sample
|
|
}
|
|
if sample > max {
|
|
max = sample
|
|
}
|
|
}
|
|
|
|
summary.AvgLatency = total / time.Duration(m.latencyCount)
|
|
summary.MinLatency = min
|
|
summary.MaxLatency = max
|
|
|
|
// Calculate P95 latency (approximate using sorted samples)
|
|
if m.latencyCount >= 20 {
|
|
// Copy samples for sorting
|
|
samples := make([]time.Duration, m.latencyCount)
|
|
copy(samples, m.latencySamples[:m.latencyCount])
|
|
// Use slices.Sort for O(n log n) instead of O(n²) insertion sort
|
|
slices.Sort(samples)
|
|
p95Idx := int(float64(len(samples)) * 0.95)
|
|
summary.P95Latency = samples[p95Idx]
|
|
}
|
|
|
|
return summary
|
|
}
|
|
|
|
// MetricsSummary contains aggregated pool metrics.
|
|
type MetricsSummary struct {
|
|
LastSampleTime time.Time `json:"last_sample_time"`
|
|
TotalQueries int64 `json:"total_queries"`
|
|
SampleCount int `json:"sample_count"`
|
|
AvgLatency time.Duration `json:"avg_latency_ns"`
|
|
MinLatency time.Duration `json:"min_latency_ns"`
|
|
MaxLatency time.Duration `json:"max_latency_ns"`
|
|
P95Latency time.Duration `json:"p95_latency_ns,omitempty"`
|
|
PeakInUse int `json:"peak_in_use"`
|
|
PeakWaitCount int64 `json:"peak_wait_count"`
|
|
TotalWaitTime time.Duration `json:"total_wait_time_ns"`
|
|
}
|
|
|
|
// GetMetrics returns the current metrics without performing a health check.
|
|
func (s *Store) GetMetrics() MetricsSummary {
|
|
if s.metrics == nil {
|
|
return MetricsSummary{}
|
|
}
|
|
return s.metrics.GetMetricsSummary()
|
|
}
|
|
|
|
// ResetMetrics resets the metrics collector (useful for testing or after major changes).
|
|
func (s *Store) ResetMetrics() {
|
|
if s.metrics != nil {
|
|
s.metrics = NewPoolMetrics(s.metrics.windowSize)
|
|
}
|
|
}
|
|
|
|
// WithTimeout wraps a context with the given timeout and logs slow queries.
|
|
// Returns the wrapped context and a cancel function that should be called when done.
|
|
func (s *Store) WithTimeout(ctx context.Context, timeout time.Duration, operation string) (context.Context, context.CancelFunc) {
|
|
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
start := time.Now()
|
|
|
|
// Return wrapped cancel that logs if query was slow
|
|
return timeoutCtx, func() {
|
|
elapsed := time.Since(start)
|
|
cancel()
|
|
|
|
// Log slow queries (> 100ms)
|
|
if elapsed > 100*time.Millisecond {
|
|
log.Warn().
|
|
Str("operation", operation).
|
|
Dur("elapsed", elapsed).
|
|
Dur("timeout", timeout).
|
|
Msg("Slow database operation")
|
|
}
|
|
}
|
|
}
|
|
|
|
// ExecWithTimeout executes a raw SQL query with timeout.
|
|
// Returns error if query takes longer than timeout.
|
|
func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) error {
|
|
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "exec")
|
|
defer cancel()
|
|
|
|
_, err := s.sqlDB.ExecContext(timeoutCtx, query, args...)
|
|
if err != nil {
|
|
if err == context.DeadlineExceeded {
|
|
return fmt.Errorf("query timeout after %v: %s", timeout, query)
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// QueryRowWithTimeout executes a row query with timeout.
|
|
func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row {
|
|
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row")
|
|
// Note: cancel will be called when row.Scan() completes or errors
|
|
_ = cancel // Caller must ensure proper cleanup
|
|
return s.sqlDB.QueryRowContext(timeoutCtx, query, args...)
|
|
}
|
|
|
|
// TransactionWithTimeout wraps a transaction function with timeout handling.
|
|
// The transaction is automatically rolled back if the context times out.
|
|
func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error {
|
|
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "transaction")
|
|
defer cancel()
|
|
|
|
return s.DB.WithContext(timeoutCtx).Transaction(func(tx *gorm.DB) error {
|
|
// Check context before proceeding
|
|
select {
|
|
case <-timeoutCtx.Done():
|
|
return timeoutCtx.Err()
|
|
default:
|
|
}
|
|
return fn(tx)
|
|
})
|
|
}
|