mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
fix: bound SQLite WAL growth and prevent worker hangs (#49)
The worker's SQLite WAL could grow unbounded (observed 19MB) and wedge the DB, hanging Claude Code on every prompt. No checkpoint ever truncated the WAL (only PASSIVE auto-checkpoint, which cannot reclaim the file), the connection-scoped pragmas were set via a single Exec so only one pooled connection received them (e.g. busy_timeout=0 on the rest), and the maintenance service that would optimize/checkpoint was never wired up. - Register a sqlite3 ConnectHook driver so all pragmas (busy_timeout, journal_mode, synchronous, cache_size, foreign_keys, journal_size_limit) apply to every pooled connection; enable safe connection recycling. - Add Store.Checkpoint (TRUNCATE), checkpoint-on-Close, and a periodic size-gated checkpoint loop with configurable interval/threshold. - Wire up the previously-dead maintenance service; make trigger_maintenance actually run DB maintenance instead of only recalculating scores. - Harden the user-prompt hook to honor its deadline and fail open so a slow worker can never stall a prompt. - Add regression tests for WAL truncation, checkpoint-on-close, and per-connection pragmas.
This commit is contained in:
@@ -91,18 +91,21 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
observationCount int
|
||||
)
|
||||
|
||||
// Start search in background
|
||||
// Start search in background. Pass the deadline context so a wedged worker
|
||||
// aborts the request at the deadline instead of blocking for the full
|
||||
// hookClient timeout (10s). Errors are ignored -- fail open with no memory.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
searchResult, _ = hooks.GET(ctx.Port, searchURL)
|
||||
searchResult, _ = hooks.GETWithContext(deadline, ctx.Port, searchURL)
|
||||
}()
|
||||
|
||||
// Start session init in parallel (with observationCount=0; approximate is fine)
|
||||
// Start session init in parallel (with observationCount=0; approximate is fine).
|
||||
// Deadline context guards this call too.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
initResult, initErr = hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{
|
||||
initResult, initErr = hooks.POSTWithContextResult(deadline, ctx.Port, "/api/sessions/init", map[string]interface{}{
|
||||
"claudeSessionId": ctx.SessionID,
|
||||
"project": ctx.Project,
|
||||
"prompt": input.Prompt,
|
||||
@@ -113,7 +116,8 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
// Wait for both to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check deadline after network calls
|
||||
// Check deadline after network calls -- if exceeded, fail open (proceed with
|
||||
// no injected memory) rather than blocking or erroring the user's prompt.
|
||||
if deadline.Err() != nil {
|
||||
return "", nil
|
||||
}
|
||||
@@ -173,9 +177,11 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
contextToInject = contextBuilder
|
||||
}
|
||||
|
||||
// Check session init result
|
||||
// Check session init result. A session-init failure must never block the
|
||||
// prompt: degrade gracefully and still inject any memory we found.
|
||||
if initErr != nil {
|
||||
return "", initErr
|
||||
fmt.Fprintf(os.Stderr, "[user-prompt] Session init failed: %v\n", initErr)
|
||||
return contextToInject, nil
|
||||
}
|
||||
if initResult == nil {
|
||||
return contextToInject, nil // Non-JSON response from worker, skip session init
|
||||
@@ -201,13 +207,15 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
|
||||
fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber)
|
||||
|
||||
// Start SDK agent (depends on session init result, so kept sequential)
|
||||
_, err := hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
|
||||
// Start SDK agent (depends on session init result, so kept sequential).
|
||||
// Deadline-guarded so a wedged worker cannot stall past the hook budget.
|
||||
// Failure here must not block the prompt: degrade gracefully, still inject memory.
|
||||
if _, err := hooks.POSTWithContextResult(deadline, ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
|
||||
"userPrompt": input.Prompt,
|
||||
"promptNumber": promptNumber,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[user-prompt] SDK agent init failed: %v\n", err)
|
||||
return contextToInject, nil
|
||||
}
|
||||
|
||||
// Return context if we found relevant observations
|
||||
|
||||
+51
-38
@@ -39,44 +39,46 @@ var CriticalConcepts = []string{
|
||||
// Config holds the application configuration.
|
||||
// Field order optimized for memory alignment (fieldalignment).
|
||||
type Config struct {
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
DBPath string `json:"db_path"`
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
EmbeddingModel string `json:"embedding_model"`
|
||||
VectorStorageStrategy string `json:"vector_storage_strategy"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
GraphBranchFactor int `json:"graph_branch_factor"`
|
||||
GraphEdgeWeight float64 `json:"graph_edge_weight"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"`
|
||||
RerankingCandidates int `json:"reranking_candidates"`
|
||||
WorkerPort int `json:"worker_port"`
|
||||
DeduplicationThreshold float64 `json:"deduplication_threshold"`
|
||||
RerankingMinImprovement float64 `json:"reranking_min_improvement"`
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"`
|
||||
ContextSessionCount int `json:"context_session_count"`
|
||||
MaxConns int `json:"max_conns"`
|
||||
RerankingAlpha float64 `json:"reranking_alpha"`
|
||||
GraphMaxHops int `json:"graph_max_hops"`
|
||||
RerankingResults int `json:"reranking_results"`
|
||||
GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"`
|
||||
HubThreshold int `json:"hub_threshold"`
|
||||
ObservationRetentionDays int `json:"observation_retention_days"`
|
||||
MaintenanceIntervalHours int `json:"maintenance_interval_hours"`
|
||||
ContextMaxTokensStartup int `json:"context_max_tokens_startup"`
|
||||
ContextMaxTokensPrompt int `json:"context_max_tokens_prompt"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
RerankingPureMode bool `json:"reranking_pure_mode"`
|
||||
GraphEnabled bool `json:"graph_enabled"`
|
||||
DeduplicationEnabled bool `json:"deduplication_enabled"`
|
||||
MaintenanceEnabled bool `json:"maintenance_enabled"`
|
||||
RerankingEnabled bool `json:"reranking_enabled"`
|
||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||
CleanupStaleObservations bool `json:"cleanup_stale_observations"`
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
DBPath string `json:"db_path"`
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
EmbeddingModel string `json:"embedding_model"`
|
||||
VectorStorageStrategy string `json:"vector_storage_strategy"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
GraphBranchFactor int `json:"graph_branch_factor"`
|
||||
GraphEdgeWeight float64 `json:"graph_edge_weight"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"`
|
||||
RerankingCandidates int `json:"reranking_candidates"`
|
||||
WorkerPort int `json:"worker_port"`
|
||||
DeduplicationThreshold float64 `json:"deduplication_threshold"`
|
||||
RerankingMinImprovement float64 `json:"reranking_min_improvement"`
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"`
|
||||
ContextSessionCount int `json:"context_session_count"`
|
||||
MaxConns int `json:"max_conns"`
|
||||
RerankingAlpha float64 `json:"reranking_alpha"`
|
||||
GraphMaxHops int `json:"graph_max_hops"`
|
||||
RerankingResults int `json:"reranking_results"`
|
||||
GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"`
|
||||
HubThreshold int `json:"hub_threshold"`
|
||||
ObservationRetentionDays int `json:"observation_retention_days"`
|
||||
MaintenanceIntervalHours int `json:"maintenance_interval_hours"`
|
||||
WALCheckpointIntervalSeconds int `json:"wal_checkpoint_interval_seconds"`
|
||||
WALCheckpointThresholdBytes int64 `json:"wal_checkpoint_threshold_bytes"`
|
||||
ContextMaxTokensStartup int `json:"context_max_tokens_startup"`
|
||||
ContextMaxTokensPrompt int `json:"context_max_tokens_prompt"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
RerankingPureMode bool `json:"reranking_pure_mode"`
|
||||
GraphEnabled bool `json:"graph_enabled"`
|
||||
DeduplicationEnabled bool `json:"deduplication_enabled"`
|
||||
MaintenanceEnabled bool `json:"maintenance_enabled"`
|
||||
RerankingEnabled bool `json:"reranking_enabled"`
|
||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||
CleanupStaleObservations bool `json:"cleanup_stale_observations"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -181,6 +183,10 @@ func Default() *Config {
|
||||
MaintenanceIntervalHours: 6, // Run every 6 hours
|
||||
ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all)
|
||||
CleanupStaleObservations: false, // Don't auto-cleanup stale observations
|
||||
// WAL checkpoint loop tunables (issue #49). Defaults mirror the worker constants:
|
||||
// check the WAL every 60s and TRUNCATE-checkpoint once it reaches 4 MiB.
|
||||
WALCheckpointIntervalSeconds: 60,
|
||||
WALCheckpointThresholdBytes: 4 << 20, // 4 MiB
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,6 +290,13 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_TOKENS_PROMPT"].(float64); ok && v > 0 {
|
||||
cfg.ContextMaxTokensPrompt = int(v)
|
||||
}
|
||||
// WAL checkpoint loop tunables (issue #49)
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_WAL_CHECKPOINT_INTERVAL_SECONDS"].(float64); ok && v > 0 {
|
||||
cfg.WALCheckpointIntervalSeconds = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_WAL_CHECKPOINT_THRESHOLD_BYTES"].(float64); ok && v > 0 {
|
||||
cfg.WALCheckpointThresholdBytes = int64(v)
|
||||
}
|
||||
// Deduplication settings
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_DEDUP_ENABLED"].(bool); ok {
|
||||
cfg.DeduplicationEnabled = v
|
||||
|
||||
+145
-38
@@ -5,18 +5,80 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
|
||||
sqlite3 "github.com/mattn/go-sqlite3" // SQLite driver with FTS5 support
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// driverName is the name of the custom mattn/go-sqlite3 driver registered with a
|
||||
// ConnectHook that applies ALL connection pragmas (correctness + best-effort) to EVERY
|
||||
// pooled connection at open time. The stock "sqlite3" driver has no hook, so pragmas set
|
||||
// via a single post-open Exec only reach one arbitrary pooled connection (issue #49, F6).
|
||||
const driverName = "sqlite3_mnemonic"
|
||||
|
||||
// registerDriverOnce guards driver registration so it runs exactly once per process.
|
||||
// database/sql panics with "sql: Register called twice" on a duplicate name, and NewStore
|
||||
// may be called multiple times (e.g. after a config-change reinitialization).
|
||||
var registerDriverOnce sync.Once
|
||||
|
||||
// correctnessPragmas MUST succeed on every connection: getting any of them wrong changes
|
||||
// transactional/locking semantics, not just performance. A failure here aborts the open.
|
||||
var correctnessPragmas = []string{
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
"PRAGMA cache_size=-64000",
|
||||
}
|
||||
|
||||
// bestEffortPragmas are per-connection or database-wide optimizations. A failure is logged
|
||||
// and tolerated: the connection is still correct, just less tuned. (page_size only takes
|
||||
// effect on an empty database / next VACUUM, but applying it per-connection is harmless.)
|
||||
var bestEffortPragmas = []string{
|
||||
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
|
||||
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
|
||||
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
|
||||
"PRAGMA wal_autocheckpoint=1000", // Auto-checkpoint (PASSIVE) every 1000 WAL frames
|
||||
"PRAGMA journal_size_limit=8388608", // Backstop: cap -wal at 8MiB (issue #49)
|
||||
}
|
||||
|
||||
// connectHook applies all pragmas to a freshly opened connection. mattn/go-sqlite3 calls
|
||||
// it at the very end of Open, after DSN params and extensions, so it is authoritative.
|
||||
func connectHook(c *sqlite3.SQLiteConn) error {
|
||||
for _, pragma := range correctnessPragmas {
|
||||
if _, err := c.Exec(pragma, nil); err != nil {
|
||||
return fmt.Errorf("apply correctness pragma %q: %w", pragma, err)
|
||||
}
|
||||
}
|
||||
for _, pragma := range bestEffortPragmas {
|
||||
if _, err := c.Exec(pragma, nil); err != nil {
|
||||
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerDriver registers the custom driver once. sqlite_vec.Auto() registers the vec
|
||||
// extension globally via sqlite3_auto_extension, which applies to connections from any
|
||||
// sqlite3-based driver, so the new driver still gets vec + FTS5.
|
||||
func registerDriver() {
|
||||
registerDriverOnce.Do(func() {
|
||||
sql.Register(driverName, &sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(c *sqlite3.SQLiteConn) error {
|
||||
return connectHook(c)
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Store represents the GORM database connection with sqlite-vec support.
|
||||
type Store struct {
|
||||
healthCacheTime time.Time
|
||||
@@ -24,6 +86,7 @@ type Store struct {
|
||||
sqlDB *sql.DB
|
||||
metrics *PoolMetrics
|
||||
cachedHealth *HealthInfo
|
||||
path string
|
||||
healthCacheTTL time.Duration
|
||||
healthCacheMu sync.RWMutex
|
||||
}
|
||||
@@ -36,22 +99,32 @@ type Config struct {
|
||||
}
|
||||
|
||||
// NewStore creates a new Store with WAL mode enabled and sqlite-vec registered.
|
||||
// CRITICAL: WAL mode and foreign keys are enabled via pragmas for concurrent reads.
|
||||
// CRITICAL: all connection pragmas (WAL, foreign_keys, busy_timeout, etc.) are applied to
|
||||
// EVERY pooled connection via a driver ConnectHook (see registerDriver), so the pool is
|
||||
// uniformly configured and connections may be recycled safely (issue #49, F6).
|
||||
func NewStore(cfg Config) (*Store, error) {
|
||||
// 1. Register sqlite-vec extension (must be done before opening database)
|
||||
// 1. Register sqlite-vec extension (must be done before opening database).
|
||||
// sqlite_vec.Auto() uses sqlite3_auto_extension, which is global to all sqlite3-based
|
||||
// drivers, so connections from our custom driver also get the vec virtual table.
|
||||
sqlite_vec.Auto()
|
||||
|
||||
// 2. Build connection string (foreign keys enabled in DSN)
|
||||
// Use sqlite3 driver (mattn/go-sqlite3) which has FTS5 support
|
||||
// 2. Register the custom driver whose ConnectHook applies ALL pragmas to EVERY pooled
|
||||
// connection (issue #49, F6). Without this, pragmas set via a single post-open
|
||||
// sqlDB.Exec reach only one arbitrary pooled connection. The hook is authoritative.
|
||||
registerDriver()
|
||||
|
||||
// 3. Build a minimal DSN. _foreign_keys is kept as belt-and-suspenders (the hook sets
|
||||
// it too); all other pragmas are applied per-connection by the ConnectHook, so they no
|
||||
// longer need to live in the DSN.
|
||||
dsn := cfg.Path + "?_foreign_keys=ON"
|
||||
|
||||
// 3. Open raw database connection with mattn/go-sqlite3 (has FTS5 support)
|
||||
sqlDB, err := sql.Open("sqlite3", dsn)
|
||||
// 4. Open raw database connection with the custom driver (FTS5 + per-connection pragmas).
|
||||
sqlDB, err := sql.Open(driverName, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
|
||||
// 4. Wrap with GORM using existing connection
|
||||
// 5. Wrap with GORM using existing connection
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{
|
||||
@@ -66,16 +139,25 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
return nil, fmt.Errorf("open gorm: %w", err)
|
||||
}
|
||||
|
||||
// 5. Configure connection pool (same settings as current implementation)
|
||||
// 6. Configure connection pool.
|
||||
maxConns := cfg.MaxConns
|
||||
if maxConns <= 0 {
|
||||
maxConns = 4
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxConns)
|
||||
sqlDB.SetMaxIdleConns(maxConns)
|
||||
sqlDB.SetConnMaxLifetime(0) // Never expire (SQLite connections are cheap)
|
||||
// Finite recycling (issue #49): previously SetConnMaxLifetime(0) meant connections
|
||||
// NEVER recycled, so a long-lived read connection could pin an old WAL read-mark for the
|
||||
// whole process lifetime and block TRUNCATE checkpoints from reclaiming the -wal file.
|
||||
// Recycling is safe now because the ConnectHook reapplies every correctness pragma on
|
||||
// each new connection — a recycled connection comes back fully configured, not with
|
||||
// defaults. 1h lifetime bounds read-mark staleness without churning the pool; 30m idle
|
||||
// time reclaims connections that sit unused (e.g. between sessions) so the pool shrinks
|
||||
// back to one warm connection during quiet periods, dropping their WAL read-marks.
|
||||
sqlDB.SetConnMaxLifetime(1 * time.Hour)
|
||||
sqlDB.SetConnMaxIdleTime(30 * time.Minute)
|
||||
|
||||
// 6. Verify connection
|
||||
// 7. Verify connection
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
@@ -83,37 +165,18 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
store := &Store{
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
path: cfg.Path,
|
||||
metrics: NewPoolMetrics(100), // Track last 100 latency samples
|
||||
healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds
|
||||
}
|
||||
|
||||
// 7. Run migrations FIRST (before PRAGMA commands)
|
||||
// 8. Run migrations. All pragmas (correctness + best-effort) are applied per-connection
|
||||
// by the ConnectHook at open time, so there is no post-open pragma loop here anymore:
|
||||
// such a loop only ever reached one arbitrary pooled connection (issue #49, F6).
|
||||
if err := runMigrations(db, sqlDB); err != nil {
|
||||
return nil, fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
// 8. CRITICAL: Set WAL mode and other performance pragmas
|
||||
// Use raw sqlDB to avoid GORM transaction issues
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
|
||||
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
|
||||
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
|
||||
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
|
||||
"PRAGMA wal_autocheckpoint=1000", // Explicit default; checkpoint every 1000 WAL frames
|
||||
}
|
||||
for _, pragma := range pragmas {
|
||||
if _, err := sqlDB.Exec(pragma); err != nil {
|
||||
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
|
||||
}
|
||||
}
|
||||
// Set busy timeout to 5 seconds to handle concurrent writes
|
||||
// This allows SQLite to retry when database is locked instead of failing immediately
|
||||
if _, err := sqlDB.Exec("PRAGMA busy_timeout=5000"); err != nil {
|
||||
return nil, fmt.Errorf("set busy timeout: %w", err)
|
||||
}
|
||||
|
||||
// 9. Warm the connection pool
|
||||
store.WarmPool(maxConns)
|
||||
|
||||
@@ -148,11 +211,55 @@ func (s *Store) WarmPool(numConns int) {
|
||||
log.Debug().Int("connections", numConns).Msg("Connection pool warmed")
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
// Close checkpoints the WAL (TRUNCATE) before closing the connection. Checkpointing on
|
||||
// shutdown prevents the WAL file from persisting in a large, dirty state across restarts
|
||||
// and config-change reinitializations, which otherwise leaves a multi-megabyte -wal file
|
||||
// on disk (issue #49). The checkpoint is best-effort: a failure is logged, not fatal.
|
||||
func (s *Store) Close() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.Checkpoint(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("WAL checkpoint on close failed (non-fatal)")
|
||||
}
|
||||
return s.sqlDB.Close()
|
||||
}
|
||||
|
||||
// Checkpoint runs a TRUNCATE WAL checkpoint: it flushes WAL frames into the main
|
||||
// database file and shrinks the -wal file back to zero. Unlike a PASSIVE checkpoint
|
||||
// (which never truncates the file and is all SQLite's auto-checkpoint ever performs), a
|
||||
// TRUNCATE checkpoint reclaims disk and is the mechanism that bounds WAL growth.
|
||||
// It waits up to the connection busy_timeout for the write lock and returns an error
|
||||
// rather than blocking indefinitely.
|
||||
func (s *Store) Checkpoint(ctx context.Context) error {
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(TRUNCATE)"); err != nil {
|
||||
return fmt.Errorf("wal checkpoint (truncate): %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WALSize returns the size in bytes of the SQLite WAL sidecar file (<db>-wal), or 0 if
|
||||
// it does not exist or cannot be stat'd. Used to decide when a checkpoint is worthwhile.
|
||||
func (s *Store) WALSize() int64 {
|
||||
if s.path == "" {
|
||||
return 0
|
||||
}
|
||||
info, err := os.Stat(s.path + "-wal")
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
|
||||
// CheckpointIfLarge performs a TRUNCATE checkpoint only when the WAL file has grown to or
|
||||
// beyond threshold bytes. Returns true if a checkpoint was performed. This keeps the
|
||||
// periodic checkpoint cheap: it does no work while the WAL is small.
|
||||
func (s *Store) CheckpointIfLarge(ctx context.Context, threshold int64) (bool, error) {
|
||||
if s.WALSize() < threshold {
|
||||
return false, nil
|
||||
}
|
||||
return true, s.Checkpoint(ctx)
|
||||
}
|
||||
|
||||
// Ping verifies the database connection is alive.
|
||||
func (s *Store) Ping() error {
|
||||
return s.sqlDB.Ping()
|
||||
@@ -193,8 +300,9 @@ func (s *Store) Optimize(ctx context.Context) error {
|
||||
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
|
||||
}
|
||||
|
||||
// Passive WAL checkpoint — doesn't block readers/writers
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(PASSIVE)"); err != nil {
|
||||
// TRUNCATE WAL checkpoint — reclaims the -wal file during low-activity optimization.
|
||||
// (PASSIVE never shrinks the file, so it cannot bound WAL growth — see issue #49.)
|
||||
if err := s.Checkpoint(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("WAL checkpoint failed (non-fatal)")
|
||||
}
|
||||
|
||||
@@ -519,7 +627,6 @@ func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, quer
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// TransactionWithTimeout wraps a transaction function with timeout handling.
|
||||
// The transaction is automatically rolled back if the context times out.
|
||||
func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error {
|
||||
|
||||
@@ -5,9 +5,12 @@ package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
@@ -239,3 +242,321 @@ func TestOptimize_RespectsContextCancellation(t *testing.T) {
|
||||
t.Error("expected error with cancelled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// growWAL inserts sizeable rows to push the SQLite WAL well past a few hundred KB so
|
||||
// checkpoint behaviour can be observed. Returns the WAL file size after the inserts.
|
||||
func growWAL(t *testing.T, store *Store, rows int) int64 {
|
||||
t.Helper()
|
||||
bigTitle := strings.Repeat("x", 2048)
|
||||
for i := 0; i < rows; i++ {
|
||||
_, err := store.GetRawDB().Exec(
|
||||
"INSERT INTO observations (sdk_session_id, title, scope, project, type, created_at, created_at_epoch) "+
|
||||
"VALUES (?, ?, 'project', '/tmp/test', 'decision', '2026-01-01T00:00:00Z', 1735689600)",
|
||||
"sess", bigTitle)
|
||||
if err != nil {
|
||||
t.Fatalf("insert row %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
return store.WALSize()
|
||||
}
|
||||
|
||||
func countObservations(t *testing.T, store *Store) int64 {
|
||||
t.Helper()
|
||||
var n int64
|
||||
if err := store.GetRawDB().QueryRow("SELECT COUNT(*) FROM observations").Scan(&n); err != nil {
|
||||
t.Fatalf("count observations: %v", err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// TestCheckpoint_TruncateShrinksWAL verifies Checkpoint() performs a TRUNCATE checkpoint
|
||||
// that actually reclaims the -wal file. This is the load-bearing fix for issue #49: a
|
||||
// PASSIVE checkpoint drains frames but never shrinks the file, so reverting Checkpoint to
|
||||
// PASSIVE would leave the WAL grown and fail this test.
|
||||
func TestCheckpoint_TruncateShrinksWAL(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_checkpoint_truncate_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
walBefore := growWAL(t, store, 1000)
|
||||
if walBefore < 64*1024 {
|
||||
t.Fatalf("expected WAL to grow above 64KiB, got %d bytes", walBefore)
|
||||
}
|
||||
|
||||
if err := store.Checkpoint(context.Background()); err != nil {
|
||||
t.Fatalf("Checkpoint failed: %v", err)
|
||||
}
|
||||
|
||||
walAfter := store.WALSize()
|
||||
if walAfter >= walBefore {
|
||||
t.Errorf("expected WAL to shrink after TRUNCATE checkpoint: before=%d after=%d", walBefore, walAfter)
|
||||
}
|
||||
if walAfter > 64*1024 {
|
||||
t.Errorf("expected WAL truncated to near-zero, got %d bytes", walAfter)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointIfLarge_GatesOnThreshold verifies the size-gated periodic checkpoint used
|
||||
// by the worker's walCheckpointLoop: a no-op below the threshold, a truncating checkpoint
|
||||
// at/above it.
|
||||
func TestCheckpointIfLarge_GatesOnThreshold(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_checkpoint_gated_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Below an enormous threshold -> no checkpoint.
|
||||
done, err := store.CheckpointIfLarge(context.Background(), 1<<30) // 1 GiB
|
||||
if err != nil {
|
||||
t.Fatalf("CheckpointIfLarge (small) failed: %v", err)
|
||||
}
|
||||
if done {
|
||||
t.Errorf("expected no checkpoint below threshold, but one was performed")
|
||||
}
|
||||
|
||||
// Grow the WAL, then a low threshold triggers a truncating checkpoint.
|
||||
walBefore := growWAL(t, store, 1000)
|
||||
if walBefore < 64*1024 {
|
||||
t.Fatalf("expected WAL to grow above 64KiB, got %d bytes", walBefore)
|
||||
}
|
||||
|
||||
done, err = store.CheckpointIfLarge(context.Background(), 64*1024)
|
||||
if err != nil {
|
||||
t.Fatalf("CheckpointIfLarge (large) failed: %v", err)
|
||||
}
|
||||
if !done {
|
||||
t.Errorf("expected checkpoint above threshold, but none was performed")
|
||||
}
|
||||
if walAfter := store.WALSize(); walAfter >= walBefore {
|
||||
t.Errorf("expected WAL to shrink after gated checkpoint: before=%d after=%d", walBefore, walAfter)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose_CheckpointsWAL verifies Close() reclaims the WAL and leaves the data intact on
|
||||
// the next open (issue #49: shutdown must not leave a large dirty WAL on disk).
|
||||
func TestClose_CheckpointsWAL(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_close_checkpoint_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
if walBefore := growWAL(t, store, 800); walBefore == 0 {
|
||||
t.Fatalf("expected WAL to grow before close, got 0")
|
||||
}
|
||||
count := countObservations(t, store)
|
||||
if count < 800 {
|
||||
t.Fatalf("expected >=800 observations before close, got %d", count)
|
||||
}
|
||||
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// The -wal file must not persist large on disk after a clean shutdown.
|
||||
if info, statErr := os.Stat(dbPath + "-wal"); statErr == nil && info.Size() > 64*1024 {
|
||||
t.Errorf("expected WAL reclaimed on close, -wal still %d bytes", info.Size())
|
||||
}
|
||||
|
||||
// Reopen and verify data survived the checkpoint.
|
||||
store2, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("reopen NewStore failed: %v", err)
|
||||
}
|
||||
defer store2.Close()
|
||||
|
||||
if count2 := countObservations(t, store2); count2 != count {
|
||||
t.Errorf("expected %d observations after reopen, got %d", count, count2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBusyTimeoutAppliedToAllConnections verifies the issue #49 DSN fix: busy_timeout is
|
||||
// applied to EVERY pooled connection (not just one arbitrary connection as happened when
|
||||
// it was set via a single post-open sqlDB.Exec).
|
||||
func TestBusyTimeoutAppliedToAllConnections(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_busy_timeout_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
const maxConns = 4
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: maxConns, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Pin all connections concurrently so each distinct connection is inspected, then
|
||||
// assert every one reports busy_timeout=5000.
|
||||
raw := store.GetRawDB()
|
||||
conns := make([]*sql.Conn, 0, maxConns)
|
||||
defer func() {
|
||||
for _, c := range conns {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
for i := 0; i < maxConns; i++ {
|
||||
c, err := raw.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire conn %d: %v", i, err)
|
||||
}
|
||||
conns = append(conns, c)
|
||||
}
|
||||
for i, c := range conns {
|
||||
var timeout int
|
||||
if err := c.QueryRowContext(context.Background(), "PRAGMA busy_timeout").Scan(&timeout); err != nil {
|
||||
t.Fatalf("query busy_timeout on conn %d: %v", i, err)
|
||||
}
|
||||
if timeout != 5000 {
|
||||
t.Errorf("conn %d: expected busy_timeout=5000, got %d", i, timeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllPragmasAppliedToAllConnections verifies the issue #49 (F6) ConnectHook fix: not
|
||||
// just busy_timeout but the full pragma set — including the best-effort pragmas that used
|
||||
// to be set via a single post-open sqlDB.Exec (journal_size_limit, temp_store,
|
||||
// wal_autocheckpoint) — is applied to EVERY pooled connection. It pins all connections so
|
||||
// each distinct connection is inspected, then asserts each reports the expected value.
|
||||
func TestAllPragmasAppliedToAllConnections(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_all_pragmas_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
const maxConns = 4
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: maxConns, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Expected per-connection pragma values. temp_store=MEMORY reports as 2; the others are
|
||||
// numeric. foreign_keys/journal_mode/synchronous are covered by TestNewStore and the
|
||||
// busy_timeout test; here we focus on the previously single-connection pragmas.
|
||||
checks := []struct {
|
||||
name string
|
||||
want int64
|
||||
}{
|
||||
{"busy_timeout", 5000},
|
||||
{"journal_size_limit", 8388608},
|
||||
{"temp_store", 2}, // 2 == MEMORY
|
||||
{"wal_autocheckpoint", 1000},
|
||||
{"foreign_keys", 1},
|
||||
}
|
||||
|
||||
raw := store.GetRawDB()
|
||||
conns := make([]*sql.Conn, 0, maxConns)
|
||||
defer func() {
|
||||
for _, c := range conns {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
for i := 0; i < maxConns; i++ {
|
||||
c, err := raw.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire conn %d: %v", i, err)
|
||||
}
|
||||
conns = append(conns, c)
|
||||
}
|
||||
|
||||
for i, c := range conns {
|
||||
for _, chk := range checks {
|
||||
var got int64
|
||||
query := "PRAGMA " + chk.name
|
||||
if err := c.QueryRowContext(context.Background(), query).Scan(&got); err != nil {
|
||||
t.Fatalf("conn %d: query %q: %v", i, chk.name, err)
|
||||
}
|
||||
if got != chk.want {
|
||||
t.Errorf("conn %d: %s = %d, want %d", i, chk.name, got, chk.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecycledConnectionRetainsPragmas verifies that recycling a connection (which now
|
||||
// happens because SetConnMaxLifetime is finite, not 0) does NOT drop the correctness
|
||||
// pragmas: the ConnectHook reapplies them on every new connection. We force recycling by
|
||||
// setting a near-zero max lifetime so the next acquisition opens a fresh connection, then
|
||||
// assert the new connection still reports the safe values rather than SQLite defaults
|
||||
// (busy_timeout would default to 0 and journal_mode to "delete" without the hook).
|
||||
func TestRecycledConnectionRetainsPragmas(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_recycle_pragmas_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent})
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
raw := store.GetRawDB()
|
||||
|
||||
// Force aggressive recycling: any connection older than 1ns is expired on next use, so
|
||||
// database/sql opens a brand-new connection (running the ConnectHook again).
|
||||
raw.SetConnMaxLifetime(time.Nanosecond)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Acquire a connection that is necessarily freshly opened (all prior ones are expired),
|
||||
// and verify the hook reapplied the correctness pragmas.
|
||||
conn, err := raw.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("acquire recycled conn: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var busyTimeout int
|
||||
if err := conn.QueryRowContext(context.Background(), "PRAGMA busy_timeout").Scan(&busyTimeout); err != nil {
|
||||
t.Fatalf("query busy_timeout: %v", err)
|
||||
}
|
||||
if busyTimeout != 5000 {
|
||||
t.Errorf("recycled conn: busy_timeout = %d, want 5000 (hook did not reapply)", busyTimeout)
|
||||
}
|
||||
|
||||
var journalMode string
|
||||
if err := conn.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&journalMode); err != nil {
|
||||
t.Fatalf("query journal_mode: %v", err)
|
||||
}
|
||||
if journalMode != "wal" {
|
||||
t.Errorf("recycled conn: journal_mode = %q, want \"wal\" (hook did not reapply)", journalMode)
|
||||
}
|
||||
|
||||
var foreignKeys int
|
||||
if err := conn.QueryRowContext(context.Background(), "PRAGMA foreign_keys").Scan(&foreignKeys); err != nil {
|
||||
t.Fatalf("query foreign_keys: %v", err)
|
||||
}
|
||||
if foreignKeys != 1 {
|
||||
t.Errorf("recycled conn: foreign_keys = %d, want 1 (hook did not reapply)", foreignKeys)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +83,20 @@ func (s *Service) Start(ctx context.Context) {
|
||||
Bool("cleanup_stale", s.config.CleanupStaleObservations).
|
||||
Msg("Starting maintenance scheduler")
|
||||
|
||||
// Initial run after 5 minutes (allow system to stabilize)
|
||||
time.Sleep(5 * time.Minute)
|
||||
// Initial run after 5 minutes (allow system to stabilize).
|
||||
// Use a cancellable timer so shutdown (ctx cancel / Stop) is not blocked for up to 5m.
|
||||
initialDelay := time.NewTimer(5 * time.Minute)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
initialDelay.Stop()
|
||||
s.log.Info().Msg("Maintenance shutting down before initial run (context cancellation)")
|
||||
return
|
||||
case <-s.stopCh:
|
||||
initialDelay.Stop()
|
||||
s.log.Info().Msg("Maintenance shutting down before initial run (stop signal)")
|
||||
return
|
||||
case <-initialDelay.C:
|
||||
}
|
||||
s.runMaintenance(ctx)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
@@ -288,7 +300,13 @@ func (s *Service) Stats() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// RunNow triggers an immediate maintenance run.
|
||||
// RunNow triggers an immediate maintenance run in the background.
|
||||
func (s *Service) RunNow(ctx context.Context) {
|
||||
go s.runMaintenance(ctx)
|
||||
}
|
||||
|
||||
// RunNowSync triggers an immediate maintenance run and blocks until it completes.
|
||||
// Use this when the caller needs to report a synchronous result (e.g. an HTTP handler).
|
||||
func (s *Service) RunNowSync(ctx context.Context) {
|
||||
s.runMaintenance(ctx)
|
||||
}
|
||||
|
||||
@@ -534,7 +534,7 @@ func (s *Server) handleToolsList(req *Request) *Response {
|
||||
},
|
||||
{
|
||||
Name: "trigger_maintenance",
|
||||
Description: "Trigger an immediate maintenance run (cleanup old observations, optimize database).",
|
||||
Description: "Trigger an immediate database maintenance run: optimize/checkpoint the database, clean up old prompts, apply any configured observation retention/stale cleanup, and recalculate importance scores.",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
@@ -917,7 +917,7 @@ func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage
|
||||
"project": s.project,
|
||||
})
|
||||
case "trigger_maintenance":
|
||||
return s.proxyPostRaw(ctx, "/api/scoring/recalculate", nil)
|
||||
return s.proxyPostRaw(ctx, "/api/maintenance/run", nil)
|
||||
case "analyze_observation_importance":
|
||||
return s.handleAnalyzeImportanceProxy(ctx, args)
|
||||
case "analyze_search_patterns":
|
||||
|
||||
@@ -314,6 +314,49 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
|
||||
writeJSON(w, map[string]string{"status": "recalculation triggered"})
|
||||
}
|
||||
|
||||
// handleRunMaintenance triggers an immediate, synchronous database maintenance run
|
||||
// (Optimize/TRUNCATE checkpoint + prompt cleanup + any enabled retention/stale cleanup)
|
||||
// and also kicks off an importance-score recalculation in the background so the behavior
|
||||
// of the previous trigger_maintenance tool is preserved (issue #49).
|
||||
func (s *Service) handleRunMaintenance(w http.ResponseWriter, r *http.Request) {
|
||||
// initMu.RLock held by requireReady middleware
|
||||
maintSvc := s.maintenanceSvc
|
||||
recalculator := s.recalculator
|
||||
|
||||
if maintSvc == nil {
|
||||
http.Error(w, "maintenance service not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Run maintenance synchronously with an independent, bounded context so the caller
|
||||
// receives a real completion status. Use context.Background so an HTTP client timeout
|
||||
// does not abort an in-progress DB maintenance pass.
|
||||
mctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
maintSvc.RunNowSync(mctx)
|
||||
|
||||
// Preserve prior trigger_maintenance behavior: also recalculate importance scores.
|
||||
recalcTriggered := false
|
||||
if recalculator != nil {
|
||||
recalcTriggered = true
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
if err := recalculator.RecalculateNow(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Background recalculation during maintenance failed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"status": "maintenance completed",
|
||||
"recalc_triggered": recalcTriggered,
|
||||
"maintenance_stats": maintSvc.Stats(),
|
||||
})
|
||||
}
|
||||
|
||||
// parseIntParam parses an integer query parameter with a default value.
|
||||
func parseIntParam(r *http.Request, name string, defaultVal int) int {
|
||||
if val := r.URL.Query().Get(name); val != "" {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/maintenance"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/pattern"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||
@@ -44,6 +45,14 @@ const (
|
||||
// QueueProcessInterval is how often the background queue processor runs.
|
||||
QueueProcessInterval = 2 * time.Second
|
||||
|
||||
// WALCheckpointInterval is how often the worker checks whether the SQLite WAL needs a
|
||||
// TRUNCATE checkpoint to reclaim disk and prevent unbounded growth (issue #49).
|
||||
WALCheckpointInterval = 60 * time.Second
|
||||
|
||||
// WALCheckpointThreshold is the WAL file size at or above which the periodic check
|
||||
// performs a TRUNCATE checkpoint. Keeps the steady-state WAL bounded to a few MB.
|
||||
WALCheckpointThreshold = 4 << 20 // 4 MiB
|
||||
|
||||
// reinitializationDrainDelay is the delay after marking the service as not ready
|
||||
// to allow in-flight requests to complete before reinitializing.
|
||||
reinitializationDrainDelay = 200 * time.Millisecond
|
||||
@@ -121,6 +130,7 @@ type Service struct {
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
patternDetector *pattern.Detector
|
||||
maintenanceSvc *maintenance.Service
|
||||
sessionManager *session.Manager
|
||||
sseBroadcaster *sse.Broadcaster
|
||||
processor *sdk.Processor
|
||||
@@ -570,6 +580,34 @@ func (s *Service) initializeAsync() {
|
||||
go s.processQueue()
|
||||
}
|
||||
|
||||
// Start periodic WAL checkpoint loop to bound SQLite WAL file growth (issue #49).
|
||||
s.wg.Add(1)
|
||||
go s.walCheckpointLoop()
|
||||
|
||||
// Start the scheduled maintenance service (issue #49: was dead code, never instantiated).
|
||||
// vectorCleanupFn mirrors the observation store's cleanup hook so age/stale deletions done
|
||||
// directly via GORM still remove their vectors from sqlite-vec.
|
||||
var vectorCleanupFn func(ctx context.Context, deletedIDs []int64)
|
||||
if vectorSync != nil {
|
||||
vectorCleanupFn = func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := retryWithBackoff(ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error {
|
||||
return vectorSync.DeleteObservations(ctx, deletedIDs)
|
||||
}); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec during maintenance")
|
||||
}
|
||||
}
|
||||
}
|
||||
maintSvc := maintenance.NewService(store, observationStore, summaryStore, promptStore, vectorCleanupFn, s.config, log.Logger)
|
||||
s.initMu.Lock()
|
||||
s.maintenanceSvc = maintSvc
|
||||
s.initMu.Unlock()
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
maintSvc.Start(s.ctx)
|
||||
}()
|
||||
log.Info().Msg("Maintenance scheduler started")
|
||||
|
||||
// Start file watchers for auto-recreation on deletion
|
||||
s.startWatchers()
|
||||
|
||||
@@ -1290,6 +1328,9 @@ func (s *Service) setupRoutes() {
|
||||
r.Put("/api/scoring/concepts/{concept}", s.handleUpdateConceptWeight)
|
||||
r.Post("/api/scoring/recalculate", s.handleTriggerRecalculation)
|
||||
|
||||
// Maintenance: run an immediate synchronous DB maintenance pass (issue #49)
|
||||
r.Post("/api/maintenance/run", s.handleRunMaintenance)
|
||||
|
||||
// Context injection
|
||||
r.Get("/api/context/count", s.handleContextCount)
|
||||
r.Get("/api/context/inject", s.handleContextInject)
|
||||
@@ -1621,6 +1662,52 @@ func (s *Service) processQueue() {
|
||||
}
|
||||
}
|
||||
|
||||
// walCheckpointLoop periodically checkpoints the SQLite WAL so it cannot grow unbounded
|
||||
// during long-lived sessions. SQLite's internal auto-checkpoint is PASSIVE and never
|
||||
// shrinks the -wal file; under sustained writes with overlapping readers it can leave the
|
||||
// WAL growing without limit (issue #49). This loop performs a TRUNCATE checkpoint whenever
|
||||
// the WAL has grown to WALCheckpointThreshold, and does nothing while it is small.
|
||||
func (s *Service) walCheckpointLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// Tunable via config; fall back to the package constants when unset/<=0 (issue #49).
|
||||
interval := WALCheckpointInterval
|
||||
if s.config != nil && s.config.WALCheckpointIntervalSeconds > 0 {
|
||||
interval = time.Duration(s.config.WALCheckpointIntervalSeconds) * time.Second
|
||||
}
|
||||
threshold := int64(WALCheckpointThreshold)
|
||||
if s.config != nil && s.config.WALCheckpointThresholdBytes > 0 {
|
||||
threshold = s.config.WALCheckpointThresholdBytes
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.initMu.RLock()
|
||||
store := s.store
|
||||
s.initMu.RUnlock()
|
||||
if store == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 15*time.Second)
|
||||
done, err := store.CheckpointIfLarge(ctx, threshold)
|
||||
cancel()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warn().Err(err).Msg("Periodic WAL checkpoint failed (non-fatal)")
|
||||
case done:
|
||||
log.Debug().Msg("Periodic WAL checkpoint (TRUNCATE) completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processAllSessions processes pending messages for all active sessions.
|
||||
// Messages are processed in parallel using goroutines, with concurrency
|
||||
// limited by a channel-based semaphore.
|
||||
@@ -1748,6 +1835,9 @@ func (s *Service) Shutdown(ctx context.Context) error {
|
||||
if s.patternDetector != nil {
|
||||
s.patternDetector.Stop()
|
||||
}
|
||||
if s.maintenanceSvc != nil {
|
||||
s.maintenanceSvc.Stop()
|
||||
}
|
||||
|
||||
// Phase 4: Shutdown sessions (flush pending work)
|
||||
log.Debug().Msg("Phase 4: Shutting down sessions...")
|
||||
|
||||
@@ -564,6 +564,44 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{
|
||||
return nil
|
||||
}
|
||||
|
||||
// POSTWithContextResult sends a POST request using the provided context and
|
||||
// decodes the JSON response body, mirroring POST but honoring ctx for
|
||||
// cancellation/deadline. Used on the prompt critical path so a wedged worker
|
||||
// aborts at the hook deadline instead of blocking for the full client timeout.
|
||||
// A non-JSON body is returned as (nil, nil), matching POST's behavior.
|
||||
func POSTWithContextResult(ctx context.Context, port int, path string, body interface{}) (map[string]interface{}, error) {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
fmt.Sprintf("http://127.0.0.1:%d%s", port, path),
|
||||
bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := hookClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("request failed: %s", resp.Status)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
// Not all endpoints return JSON
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GET sends a GET request to the worker.
|
||||
func GET(port int, path string) (map[string]interface{}, error) {
|
||||
resp, err := hookClient.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path))
|
||||
@@ -584,6 +622,35 @@ func GET(port int, path string) (map[string]interface{}, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GETWithContext sends a GET request using the provided context and decodes the
|
||||
// JSON response body, mirroring GET but honoring ctx for cancellation/deadline.
|
||||
// Used on the prompt critical path so a wedged worker aborts at the hook
|
||||
// deadline instead of blocking for the full client timeout.
|
||||
func GETWithContext(ctx context.Context, port int, path string) (map[string]interface{}, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
fmt.Sprintf("http://127.0.0.1:%d%s", port, path), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := hookClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("request failed: %s", resp.Status)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// versionsCompatible checks if two versions are compatible for dev builds.
|
||||
// Returns true if both versions share the same base version (ignoring -dirty, -dev, commit suffixes).
|
||||
// This prevents unnecessary restarts during development.
|
||||
|
||||
@@ -952,6 +952,203 @@ func TestGET_Timeout(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestGETWithContext tests GETWithContext with a mock server.
|
||||
func TestGETWithContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
expectedResult map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful GET with JSON response",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"})
|
||||
},
|
||||
expectError: false,
|
||||
expectedResult: map[string]interface{}{"data": "test"},
|
||||
},
|
||||
{
|
||||
name: "GET with 404 error",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "GET with invalid JSON",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
|
||||
defer server.Close()
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := GETWithContext(context.Background(), port, "/test")
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedResult != nil {
|
||||
assert.Equal(t, tt.expectedResult["data"], result["data"])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGETWithContext_Timeout verifies the context deadline aborts a slow server
|
||||
// well before the hookClient timeout, so a wedged worker cannot stall the prompt.
|
||||
func TestGETWithContext_Timeout(t *testing.T) {
|
||||
// Server that blocks longer than the context deadline.
|
||||
blockUntil := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-blockUntil // never closed during the test -> server hangs
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(blockUntil)
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err = GETWithContext(ctx, port, "/test")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Error(t, err)
|
||||
// Should abort near the 100ms deadline, far below hookClient's 10s timeout.
|
||||
assert.Less(t, elapsed, 2*time.Second, "context deadline must abort the request quickly")
|
||||
}
|
||||
|
||||
// TestGETWithContext_CancelledContext verifies an already-cancelled context
|
||||
// returns immediately without making a real request.
|
||||
func TestGETWithContext_CancelledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
start := time.Now()
|
||||
_, err := GETWithContext(ctx, 99994, "/test")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Less(t, elapsed, 1*time.Second, "cancelled context should return immediately")
|
||||
}
|
||||
|
||||
// TestPOSTWithContextResult tests POSTWithContextResult with a mock server.
|
||||
func TestPOSTWithContextResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
body interface{}
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
expectedResult map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful POST with JSON response",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
|
||||
},
|
||||
body: map[string]string{"key": "value"},
|
||||
expectError: false,
|
||||
expectedResult: map[string]interface{}{"status": "ok"},
|
||||
},
|
||||
{
|
||||
name: "POST with 400 error",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
},
|
||||
body: map[string]string{"key": "value"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "POST with non-JSON response returns nil",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not json"))
|
||||
},
|
||||
body: map[string]string{"key": "value"},
|
||||
expectError: false,
|
||||
expectedResult: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
|
||||
defer server.Close()
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := POSTWithContextResult(context.Background(), port, "/test", tt.body)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedResult != nil {
|
||||
assert.Equal(t, tt.expectedResult["status"], result["status"])
|
||||
} else {
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPOSTWithContextResult_MarshalError tests POSTWithContextResult with an unmarshalable body.
|
||||
func TestPOSTWithContextResult_MarshalError(t *testing.T) {
|
||||
badValue := make(chan int)
|
||||
_, err := POSTWithContextResult(context.Background(), 99999, "/test", badValue)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestPOSTWithContextResult_Timeout verifies the context deadline aborts a slow
|
||||
// server before the hookClient timeout.
|
||||
func TestPOSTWithContextResult_Timeout(t *testing.T) {
|
||||
blockUntil := make(chan struct{})
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-blockUntil
|
||||
}))
|
||||
defer server.Close()
|
||||
defer close(blockUntil)
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err = POSTWithContextResult(ctx, port, "/test", map[string]string{"k": "v"})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Less(t, elapsed, 2*time.Second, "context deadline must abort the request quickly")
|
||||
}
|
||||
|
||||
// TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout.
|
||||
func TestIsWorkerRunning_Timeout(t *testing.T) {
|
||||
// Non-existent port should quickly return false
|
||||
|
||||
Reference in New Issue
Block a user