diff --git a/cmd/hooks/user-prompt/main.go b/cmd/hooks/user-prompt/main.go index bc8736c..1197ca3 100644 --- a/cmd/hooks/user-prompt/main.go +++ b/cmd/hooks/user-prompt/main.go @@ -91,18 +91,21 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { observationCount int ) - // Start search in background + // Start search in background. Pass the deadline context so a wedged worker + // aborts the request at the deadline instead of blocking for the full + // hookClient timeout (10s). Errors are ignored -- fail open with no memory. wg.Add(1) go func() { defer wg.Done() - searchResult, _ = hooks.GET(ctx.Port, searchURL) + searchResult, _ = hooks.GETWithContext(deadline, ctx.Port, searchURL) }() - // Start session init in parallel (with observationCount=0; approximate is fine) + // Start session init in parallel (with observationCount=0; approximate is fine). + // Deadline context guards this call too. wg.Add(1) go func() { defer wg.Done() - initResult, initErr = hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{ + initResult, initErr = hooks.POSTWithContextResult(deadline, ctx.Port, "/api/sessions/init", map[string]interface{}{ "claudeSessionId": ctx.SessionID, "project": ctx.Project, "prompt": input.Prompt, @@ -113,7 +116,8 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { // Wait for both to complete wg.Wait() - // Check deadline after network calls + // Check deadline after network calls -- if exceeded, fail open (proceed with + // no injected memory) rather than blocking or erroring the user's prompt. if deadline.Err() != nil { return "", nil } @@ -173,9 +177,11 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { contextToInject = contextBuilder } - // Check session init result + // Check session init result. A session-init failure must never block the + // prompt: degrade gracefully and still inject any memory we found. 
if initErr != nil { - return "", initErr + fmt.Fprintf(os.Stderr, "[user-prompt] Session init failed: %v\n", initErr) + return contextToInject, nil } if initResult == nil { return contextToInject, nil // Non-JSON response from worker, skip session init @@ -201,13 +207,15 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber) - // Start SDK agent (depends on session init result, so kept sequential) - _, err := hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{ + // Start SDK agent (depends on session init result, so kept sequential). + // Deadline-guarded so a wedged worker cannot stall past the hook budget. + // Failure here must not block the prompt: degrade gracefully, still inject memory. + if _, err := hooks.POSTWithContextResult(deadline, ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{ "userPrompt": input.Prompt, "promptNumber": promptNumber, - }) - if err != nil { - return "", err + }); err != nil { + fmt.Fprintf(os.Stderr, "[user-prompt] SDK agent init failed: %v\n", err) + return contextToInject, nil } // Return context if we found relevant observations diff --git a/internal/config/config.go b/internal/config/config.go index f2abde8..0a19000 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -39,44 +39,46 @@ var CriticalConcepts = []string{ // Config holds the application configuration. // Field order optimized for memory alignment (fieldalignment). 
type Config struct { - ContextFullField string `json:"context_full_field"` - DBPath string `json:"db_path"` - Model string `json:"model"` - ClaudeCodePath string `json:"claude_code_path"` - EmbeddingModel string `json:"embedding_model"` - VectorStorageStrategy string `json:"vector_storage_strategy"` - ContextObsConcepts []string `json:"context_obs_concepts"` - ContextObsTypes []string `json:"context_obs_types"` - ContextFullCount int `json:"context_full_count"` - GraphBranchFactor int `json:"graph_branch_factor"` - GraphEdgeWeight float64 `json:"graph_edge_weight"` - ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` - RerankingCandidates int `json:"reranking_candidates"` - WorkerPort int `json:"worker_port"` - DeduplicationThreshold float64 `json:"deduplication_threshold"` - RerankingMinImprovement float64 `json:"reranking_min_improvement"` - ContextObservations int `json:"context_observations"` - ContextMaxPromptResults int `json:"context_max_prompt_results"` - ContextSessionCount int `json:"context_session_count"` - MaxConns int `json:"max_conns"` - RerankingAlpha float64 `json:"reranking_alpha"` - GraphMaxHops int `json:"graph_max_hops"` - RerankingResults int `json:"reranking_results"` - GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"` - HubThreshold int `json:"hub_threshold"` - ObservationRetentionDays int `json:"observation_retention_days"` - MaintenanceIntervalHours int `json:"maintenance_interval_hours"` - ContextMaxTokensStartup int `json:"context_max_tokens_startup"` - ContextMaxTokensPrompt int `json:"context_max_tokens_prompt"` - ContextShowWorkTokens bool `json:"context_show_work_tokens"` - ContextShowReadTokens bool `json:"context_show_read_tokens"` - RerankingPureMode bool `json:"reranking_pure_mode"` - GraphEnabled bool `json:"graph_enabled"` - DeduplicationEnabled bool `json:"deduplication_enabled"` - MaintenanceEnabled bool `json:"maintenance_enabled"` - RerankingEnabled bool `json:"reranking_enabled"` - 
ContextShowLastSummary bool `json:"context_show_last_summary"` - CleanupStaleObservations bool `json:"cleanup_stale_observations"` + ContextFullField string `json:"context_full_field"` + DBPath string `json:"db_path"` + Model string `json:"model"` + ClaudeCodePath string `json:"claude_code_path"` + EmbeddingModel string `json:"embedding_model"` + VectorStorageStrategy string `json:"vector_storage_strategy"` + ContextObsConcepts []string `json:"context_obs_concepts"` + ContextObsTypes []string `json:"context_obs_types"` + ContextFullCount int `json:"context_full_count"` + GraphBranchFactor int `json:"graph_branch_factor"` + GraphEdgeWeight float64 `json:"graph_edge_weight"` + ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` + RerankingCandidates int `json:"reranking_candidates"` + WorkerPort int `json:"worker_port"` + DeduplicationThreshold float64 `json:"deduplication_threshold"` + RerankingMinImprovement float64 `json:"reranking_min_improvement"` + ContextObservations int `json:"context_observations"` + ContextMaxPromptResults int `json:"context_max_prompt_results"` + ContextSessionCount int `json:"context_session_count"` + MaxConns int `json:"max_conns"` + RerankingAlpha float64 `json:"reranking_alpha"` + GraphMaxHops int `json:"graph_max_hops"` + RerankingResults int `json:"reranking_results"` + GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"` + HubThreshold int `json:"hub_threshold"` + ObservationRetentionDays int `json:"observation_retention_days"` + MaintenanceIntervalHours int `json:"maintenance_interval_hours"` + WALCheckpointIntervalSeconds int `json:"wal_checkpoint_interval_seconds"` + WALCheckpointThresholdBytes int64 `json:"wal_checkpoint_threshold_bytes"` + ContextMaxTokensStartup int `json:"context_max_tokens_startup"` + ContextMaxTokensPrompt int `json:"context_max_tokens_prompt"` + ContextShowWorkTokens bool `json:"context_show_work_tokens"` + ContextShowReadTokens bool `json:"context_show_read_tokens"` + 
RerankingPureMode bool `json:"reranking_pure_mode"` + GraphEnabled bool `json:"graph_enabled"` + DeduplicationEnabled bool `json:"deduplication_enabled"` + MaintenanceEnabled bool `json:"maintenance_enabled"` + RerankingEnabled bool `json:"reranking_enabled"` + ContextShowLastSummary bool `json:"context_show_last_summary"` + CleanupStaleObservations bool `json:"cleanup_stale_observations"` } var ( @@ -181,6 +183,10 @@ func Default() *Config { MaintenanceIntervalHours: 6, // Run every 6 hours ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all) CleanupStaleObservations: false, // Don't auto-cleanup stale observations + // WAL checkpoint loop tunables (issue #49). Defaults mirror the worker constants: + // check the WAL every 60s and TRUNCATE-checkpoint once it reaches 4 MiB. + WALCheckpointIntervalSeconds: 60, + WALCheckpointThresholdBytes: 4 << 20, // 4 MiB } } @@ -284,6 +290,13 @@ func Load() (*Config, error) { if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_TOKENS_PROMPT"].(float64); ok && v > 0 { cfg.ContextMaxTokensPrompt = int(v) } + // WAL checkpoint loop tunables (issue #49) + if v, ok := settings["CLAUDE_MNEMONIC_WAL_CHECKPOINT_INTERVAL_SECONDS"].(float64); ok && v > 0 { + cfg.WALCheckpointIntervalSeconds = int(v) + } + if v, ok := settings["CLAUDE_MNEMONIC_WAL_CHECKPOINT_THRESHOLD_BYTES"].(float64); ok && v > 0 { + cfg.WALCheckpointThresholdBytes = int64(v) + } // Deduplication settings if v, ok := settings["CLAUDE_MNEMONIC_DEDUP_ENABLED"].(bool); ok { cfg.DeduplicationEnabled = v diff --git a/internal/db/gorm/store.go b/internal/db/gorm/store.go index e3511f9..11130a8 100644 --- a/internal/db/gorm/store.go +++ b/internal/db/gorm/store.go @@ -5,18 +5,80 @@ import ( "context" "database/sql" "fmt" + "os" "slices" "sync" "time" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" - _ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support + sqlite3 "github.com/mattn/go-sqlite3" // SQLite driver with FTS5 support 
"github.com/rs/zerolog/log" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) +// driverName is the name of the custom mattn/go-sqlite3 driver registered with a +// ConnectHook that applies ALL connection pragmas (correctness + best-effort) to EVERY +// pooled connection at open time. The stock "sqlite3" driver has no hook, so pragmas set +// via a single post-open Exec only reach one arbitrary pooled connection (issue #49, F6). +const driverName = "sqlite3_mnemonic" + +// registerDriverOnce guards driver registration so it runs exactly once per process. +// database/sql panics with "sql: Register called twice" on a duplicate name, and NewStore +// may be called multiple times (e.g. after a config-change reinitialization). +var registerDriverOnce sync.Once + +// correctnessPragmas MUST succeed on every connection: getting any of them wrong changes +// transactional/locking semantics, not just performance. A failure here aborts the open. +var correctnessPragmas = []string{ + "PRAGMA foreign_keys=ON", + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA busy_timeout=5000", + "PRAGMA cache_size=-64000", +} + +// bestEffortPragmas are per-connection or database-wide optimizations. A failure is logged +// and tolerated: the connection is still correct, just less tuned. (page_size only takes +// effect on an empty database / next VACUUM, but applying it per-connection is harmless.) +var bestEffortPragmas = []string{ + "PRAGMA temp_store=MEMORY", // Store temp tables in memory + "PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O + "PRAGMA page_size=4096", // 4KB pages (optimal for most systems) + "PRAGMA wal_autocheckpoint=1000", // Auto-checkpoint (PASSIVE) every 1000 WAL frames + "PRAGMA journal_size_limit=8388608", // Backstop: cap -wal at 8MiB (issue #49) +} + +// connectHook applies all pragmas to a freshly opened connection. 
mattn/go-sqlite3 calls +// it at the very end of Open, after DSN params and extensions, so it is authoritative. +func connectHook(c *sqlite3.SQLiteConn) error { + for _, pragma := range correctnessPragmas { + if _, err := c.Exec(pragma, nil); err != nil { + return fmt.Errorf("apply correctness pragma %q: %w", pragma, err) + } + } + for _, pragma := range bestEffortPragmas { + if _, err := c.Exec(pragma, nil); err != nil { + log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)") + } + } + return nil +} + +// registerDriver registers the custom driver once. sqlite_vec.Auto() registers the vec +// extension globally via sqlite3_auto_extension, which applies to connections from any +// sqlite3-based driver, so the new driver still gets vec + FTS5. +func registerDriver() { + registerDriverOnce.Do(func() { + sql.Register(driverName, &sqlite3.SQLiteDriver{ + ConnectHook: func(c *sqlite3.SQLiteConn) error { + return connectHook(c) + }, + }) + }) +} + // Store represents the GORM database connection with sqlite-vec support. type Store struct { healthCacheTime time.Time @@ -24,6 +86,7 @@ type Store struct { sqlDB *sql.DB metrics *PoolMetrics cachedHealth *HealthInfo + path string healthCacheTTL time.Duration healthCacheMu sync.RWMutex } @@ -36,22 +99,32 @@ type Config struct { } // NewStore creates a new Store with WAL mode enabled and sqlite-vec registered. -// CRITICAL: WAL mode and foreign keys are enabled via pragmas for concurrent reads. +// CRITICAL: all connection pragmas (WAL, foreign_keys, busy_timeout, etc.) are applied to +// EVERY pooled connection via a driver ConnectHook (see registerDriver), so the pool is +// uniformly configured and connections may be recycled safely (issue #49, F6). func NewStore(cfg Config) (*Store, error) { - // 1. Register sqlite-vec extension (must be done before opening database) + // 1. Register sqlite-vec extension (must be done before opening database). 
+ // sqlite_vec.Auto() uses sqlite3_auto_extension, which is global to all sqlite3-based + // drivers, so connections from our custom driver also get the vec virtual table. sqlite_vec.Auto() - // 2. Build connection string (foreign keys enabled in DSN) - // Use sqlite3 driver (mattn/go-sqlite3) which has FTS5 support + // 2. Register the custom driver whose ConnectHook applies ALL pragmas to EVERY pooled + // connection (issue #49, F6). Without this, pragmas set via a single post-open + // sqlDB.Exec reach only one arbitrary pooled connection. The hook is authoritative. + registerDriver() + + // 3. Build a minimal DSN. _foreign_keys is kept as belt-and-suspenders (the hook sets + // it too); all other pragmas are applied per-connection by the ConnectHook, so they no + // longer need to live in the DSN. dsn := cfg.Path + "?_foreign_keys=ON" - // 3. Open raw database connection with mattn/go-sqlite3 (has FTS5 support) - sqlDB, err := sql.Open("sqlite3", dsn) + // 4. Open raw database connection with the custom driver (FTS5 + per-connection pragmas). + sqlDB, err := sql.Open(driverName, dsn) if err != nil { return nil, fmt.Errorf("open database: %w", err) } - // 4. Wrap with GORM using existing connection + // 5. Wrap with GORM using existing connection db, err := gorm.Open(sqlite.Dialector{ Conn: sqlDB, }, &gorm.Config{ @@ -66,16 +139,25 @@ func NewStore(cfg Config) (*Store, error) { return nil, fmt.Errorf("open gorm: %w", err) } - // 5. Configure connection pool (same settings as current implementation) + // 6. Configure connection pool. 
maxConns := cfg.MaxConns if maxConns <= 0 { maxConns = 4 } sqlDB.SetMaxOpenConns(maxConns) sqlDB.SetMaxIdleConns(maxConns) - sqlDB.SetConnMaxLifetime(0) // Never expire (SQLite connections are cheap) + // Finite recycling (issue #49): previously SetConnMaxLifetime(0) meant connections + // NEVER recycled, so a long-lived read connection could pin an old WAL read-mark for the + // whole process lifetime and block TRUNCATE checkpoints from reclaiming the -wal file. + // Recycling is safe now because the ConnectHook reapplies every correctness pragma on + // each new connection — a recycled connection comes back fully configured, not with + // defaults. 1h lifetime bounds read-mark staleness without churning the pool; 30m idle + // time closes connections that sit unused (e.g. between sessions); database/sql keeps no + // idle minimum, so the pool may drain to zero in quiet periods, dropping their WAL read-marks. + sqlDB.SetConnMaxLifetime(1 * time.Hour) + sqlDB.SetConnMaxIdleTime(30 * time.Minute) - // 6. Verify connection + // 7. Verify connection if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("ping database: %w", err) } @@ -83,37 +165,18 @@ func NewStore(cfg Config) (*Store, error) { store := &Store{ DB: db, sqlDB: sqlDB, + path: cfg.Path, metrics: NewPoolMetrics(100), // Track last 100 latency samples healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds } - // 7. Run migrations FIRST (before PRAGMA commands) + // 8. Run migrations. All pragmas (correctness + best-effort) are applied per-connection + // by the ConnectHook at open time, so there is no post-open pragma loop here anymore: + // such a loop only ever reached one arbitrary pooled connection (issue #49, F6). if err := runMigrations(db, sqlDB); err != nil { return nil, fmt.Errorf("run migrations: %w", err) } - // 8. 
CRITICAL: Set WAL mode and other performance pragmas - // Use raw sqlDB to avoid GORM transaction issues - pragmas := []string{ - "PRAGMA journal_mode=WAL", - "PRAGMA synchronous=NORMAL", - "PRAGMA cache_size=-64000", // 64MB cache (negative = KB) - "PRAGMA temp_store=MEMORY", // Store temp tables in memory - "PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O - "PRAGMA page_size=4096", // 4KB pages (optimal for most systems) - "PRAGMA wal_autocheckpoint=1000", // Explicit default; checkpoint every 1000 WAL frames - } - for _, pragma := range pragmas { - if _, err := sqlDB.Exec(pragma); err != nil { - log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)") - } - } - // Set busy timeout to 5 seconds to handle concurrent writes - // This allows SQLite to retry when database is locked instead of failing immediately - if _, err := sqlDB.Exec("PRAGMA busy_timeout=5000"); err != nil { - return nil, fmt.Errorf("set busy timeout: %w", err) - } - // 9. Warm the connection pool store.WarmPool(maxConns) @@ -148,11 +211,55 @@ func (s *Store) WarmPool(numConns int) { log.Debug().Int("connections", numConns).Msg("Connection pool warmed") } -// Close closes the database connection. +// Close checkpoints the WAL (TRUNCATE) before closing the connection. Checkpointing on +// shutdown prevents the WAL file from persisting in a large, dirty state across restarts +// and config-change reinitializations, which otherwise leaves a multi-megabyte -wal file +// on disk (issue #49). The checkpoint is best-effort: a failure is logged, not fatal. func (s *Store) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := s.Checkpoint(ctx); err != nil { + log.Warn().Err(err).Msg("WAL checkpoint on close failed (non-fatal)") + } return s.sqlDB.Close() } +// Checkpoint runs a TRUNCATE WAL checkpoint: it flushes WAL frames into the main +// database file and shrinks the -wal file back to zero. 
Unlike a PASSIVE checkpoint +// (which never truncates the file and is all SQLite's auto-checkpoint ever performs), a +// TRUNCATE checkpoint reclaims disk and is the mechanism that bounds WAL growth. +// It waits up to the connection busy_timeout for the write lock and returns an error +// rather than blocking indefinitely. +func (s *Store) Checkpoint(ctx context.Context) error { + if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(TRUNCATE)"); err != nil { + return fmt.Errorf("wal checkpoint (truncate): %w", err) + } + return nil +} + +// WALSize returns the size in bytes of the SQLite WAL sidecar file (-wal), or 0 if +// it does not exist or cannot be stat'd. Used to decide when a checkpoint is worthwhile. +func (s *Store) WALSize() int64 { + if s.path == "" { + return 0 + } + info, err := os.Stat(s.path + "-wal") + if err != nil { + return 0 + } + return info.Size() +} + +// CheckpointIfLarge performs a TRUNCATE checkpoint only when the WAL file has grown to or +// beyond threshold bytes. Returns true if a checkpoint was performed. This keeps the +// periodic checkpoint cheap: it does no work while the WAL is small. +func (s *Store) CheckpointIfLarge(ctx context.Context, threshold int64) (bool, error) { + if s.WALSize() < threshold { + return false, nil + } + return true, s.Checkpoint(ctx) +} + // Ping verifies the database connection is alive. func (s *Store) Ping() error { return s.sqlDB.Ping() @@ -193,8 +300,9 @@ func (s *Store) Optimize(ctx context.Context) error { log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)") } - // Passive WAL checkpoint — doesn't block readers/writers - if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(PASSIVE)"); err != nil { + // TRUNCATE WAL checkpoint — reclaims the -wal file during low-activity optimization. + // (PASSIVE never shrinks the file, so it cannot bound WAL growth — see issue #49.) 
+ if err := s.Checkpoint(ctx); err != nil { log.Warn().Err(err).Msg("WAL checkpoint failed (non-fatal)") } @@ -519,7 +627,6 @@ func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, quer return nil } - // TransactionWithTimeout wraps a transaction function with timeout handling. // The transaction is automatically rolled back if the context times out. func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error { diff --git a/internal/db/gorm/store_test.go b/internal/db/gorm/store_test.go index 9dfe80f..8bc7353 100644 --- a/internal/db/gorm/store_test.go +++ b/internal/db/gorm/store_test.go @@ -5,9 +5,12 @@ package gorm import ( "context" + "database/sql" "os" "path/filepath" + "strings" "testing" + "time" "gorm.io/gorm/logger" ) @@ -239,3 +242,321 @@ func TestOptimize_RespectsContextCancellation(t *testing.T) { t.Error("expected error with cancelled context, got nil") } } + +// growWAL inserts sizeable rows to push the SQLite WAL well past a few hundred KB so +// checkpoint behaviour can be observed. Returns the WAL file size after the inserts. 
+func growWAL(t *testing.T, store *Store, rows int) int64 { + t.Helper() + bigTitle := strings.Repeat("x", 2048) + for i := 0; i < rows; i++ { + _, err := store.GetRawDB().Exec( + "INSERT INTO observations (sdk_session_id, title, scope, project, type, created_at, created_at_epoch) "+ + "VALUES (?, ?, 'project', '/tmp/test', 'decision', '2026-01-01T00:00:00Z', 1735689600)", + "sess", bigTitle) + if err != nil { + t.Fatalf("insert row %d: %v", i, err) + } + } + return store.WALSize() +} + +func countObservations(t *testing.T, store *Store) int64 { + t.Helper() + var n int64 + if err := store.GetRawDB().QueryRow("SELECT COUNT(*) FROM observations").Scan(&n); err != nil { + t.Fatalf("count observations: %v", err) + } + return n +} + +// TestCheckpoint_TruncateShrinksWAL verifies Checkpoint() performs a TRUNCATE checkpoint +// that actually reclaims the -wal file. This is the load-bearing fix for issue #49: a +// PASSIVE checkpoint drains frames but never shrinks the file, so reverting Checkpoint to +// PASSIVE would leave the WAL grown and fail this test. 
+func TestCheckpoint_TruncateShrinksWAL(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_checkpoint_truncate_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + walBefore := growWAL(t, store, 1000) + if walBefore < 64*1024 { + t.Fatalf("expected WAL to grow above 64KiB, got %d bytes", walBefore) + } + + if err := store.Checkpoint(context.Background()); err != nil { + t.Fatalf("Checkpoint failed: %v", err) + } + + walAfter := store.WALSize() + if walAfter >= walBefore { + t.Errorf("expected WAL to shrink after TRUNCATE checkpoint: before=%d after=%d", walBefore, walAfter) + } + if walAfter > 64*1024 { + t.Errorf("expected WAL truncated to near-zero, got %d bytes", walAfter) + } +} + +// TestCheckpointIfLarge_GatesOnThreshold verifies the size-gated periodic checkpoint used +// by the worker's walCheckpointLoop: a no-op below the threshold, a truncating checkpoint +// at/above it. +func TestCheckpointIfLarge_GatesOnThreshold(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_checkpoint_gated_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Below an enormous threshold -> no checkpoint. + done, err := store.CheckpointIfLarge(context.Background(), 1<<30) // 1 GiB + if err != nil { + t.Fatalf("CheckpointIfLarge (small) failed: %v", err) + } + if done { + t.Errorf("expected no checkpoint below threshold, but one was performed") + } + + // Grow the WAL, then a low threshold triggers a truncating checkpoint. 
+ walBefore := growWAL(t, store, 1000) + if walBefore < 64*1024 { + t.Fatalf("expected WAL to grow above 64KiB, got %d bytes", walBefore) + } + + done, err = store.CheckpointIfLarge(context.Background(), 64*1024) + if err != nil { + t.Fatalf("CheckpointIfLarge (large) failed: %v", err) + } + if !done { + t.Errorf("expected checkpoint above threshold, but none was performed") + } + if walAfter := store.WALSize(); walAfter >= walBefore { + t.Errorf("expected WAL to shrink after gated checkpoint: before=%d after=%d", walBefore, walAfter) + } +} + +// TestClose_CheckpointsWAL verifies Close() reclaims the WAL and leaves the data intact on +// the next open (issue #49: shutdown must not leave a large dirty WAL on disk). +func TestClose_CheckpointsWAL(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_close_checkpoint_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + + if walBefore := growWAL(t, store, 800); walBefore == 0 { + t.Fatalf("expected WAL to grow before close, got 0") + } + count := countObservations(t, store) + if count < 800 { + t.Fatalf("expected >=800 observations before close, got %d", count) + } + + if err := store.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // The -wal file must not persist large on disk after a clean shutdown. + if info, statErr := os.Stat(dbPath + "-wal"); statErr == nil && info.Size() > 64*1024 { + t.Errorf("expected WAL reclaimed on close, -wal still %d bytes", info.Size()) + } + + // Reopen and verify data survived the checkpoint. 
+ store2, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("reopen NewStore failed: %v", err) + } + defer store2.Close() + + if count2 := countObservations(t, store2); count2 != count { + t.Errorf("expected %d observations after reopen, got %d", count, count2) + } +} + +// TestBusyTimeoutAppliedToAllConnections verifies the issue #49 ConnectHook fix: busy_timeout is +// applied to EVERY pooled connection (not just one arbitrary connection as happened when +// it was set via a single post-open sqlDB.Exec). +func TestBusyTimeoutAppliedToAllConnections(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_busy_timeout_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + const maxConns = 4 + store, err := NewStore(Config{Path: dbPath, MaxConns: maxConns, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Pin all connections concurrently so each distinct connection is inspected, then + // assert every one reports busy_timeout=5000. 
+ raw := store.GetRawDB() + conns := make([]*sql.Conn, 0, maxConns) + defer func() { + for _, c := range conns { + _ = c.Close() + } + }() + for i := 0; i < maxConns; i++ { + c, err := raw.Conn(context.Background()) + if err != nil { + t.Fatalf("acquire conn %d: %v", i, err) + } + conns = append(conns, c) + } + for i, c := range conns { + var timeout int + if err := c.QueryRowContext(context.Background(), "PRAGMA busy_timeout").Scan(&timeout); err != nil { + t.Fatalf("query busy_timeout on conn %d: %v", i, err) + } + if timeout != 5000 { + t.Errorf("conn %d: expected busy_timeout=5000, got %d", i, timeout) + } + } +} + +// TestAllPragmasAppliedToAllConnections verifies the issue #49 (F6) ConnectHook fix: not +// just busy_timeout but the full pragma set — including the best-effort pragmas that used +// to be set via a single post-open sqlDB.Exec (journal_size_limit, temp_store, +// wal_autocheckpoint) — is applied to EVERY pooled connection. It pins all connections so +// each distinct connection is inspected, then asserts each reports the expected value. +func TestAllPragmasAppliedToAllConnections(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_all_pragmas_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + const maxConns = 4 + store, err := NewStore(Config{Path: dbPath, MaxConns: maxConns, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Expected per-connection pragma values. temp_store=MEMORY reports as 2; the others are + // numeric. foreign_keys/journal_mode/synchronous are covered by TestNewStore and the + // busy_timeout test; here we focus on the previously single-connection pragmas. 
+ checks := []struct { + name string + want int64 + }{ + {"busy_timeout", 5000}, + {"journal_size_limit", 8388608}, + {"temp_store", 2}, // 2 == MEMORY + {"wal_autocheckpoint", 1000}, + {"foreign_keys", 1}, + } + + raw := store.GetRawDB() + conns := make([]*sql.Conn, 0, maxConns) + defer func() { + for _, c := range conns { + _ = c.Close() + } + }() + for i := 0; i < maxConns; i++ { + c, err := raw.Conn(context.Background()) + if err != nil { + t.Fatalf("acquire conn %d: %v", i, err) + } + conns = append(conns, c) + } + + for i, c := range conns { + for _, chk := range checks { + var got int64 + query := "PRAGMA " + chk.name + if err := c.QueryRowContext(context.Background(), query).Scan(&got); err != nil { + t.Fatalf("conn %d: query %q: %v", i, chk.name, err) + } + if got != chk.want { + t.Errorf("conn %d: %s = %d, want %d", i, chk.name, got, chk.want) + } + } + } +} + +// TestRecycledConnectionRetainsPragmas verifies that recycling a connection (which now +// happens because SetConnMaxLifetime is finite, not 0) does NOT drop the correctness +// pragmas: the ConnectHook reapplies them on every new connection. We force recycling by +// setting a near-zero max lifetime so the next acquisition opens a fresh connection, then +// assert the new connection reports the configured values. NOTE(review): these asserts may pass +// without the hook (WAL persists in the file, the DSN sets foreign_keys, the driver may default busy_timeout) — confirm. 
+func TestRecycledConnectionRetainsPragmas(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_recycle_pragmas_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{Path: dbPath, MaxConns: 2, LogLevel: logger.Silent}) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + raw := store.GetRawDB() + + // Force aggressive recycling: any connection older than 1ns is expired on next use, so + // database/sql opens a brand-new connection (running the ConnectHook again). + raw.SetConnMaxLifetime(time.Nanosecond) + time.Sleep(5 * time.Millisecond) + + // Acquire a connection that is necessarily freshly opened (all prior ones are expired), + // and verify the hook reapplied the correctness pragmas. + conn, err := raw.Conn(context.Background()) + if err != nil { + t.Fatalf("acquire recycled conn: %v", err) + } + defer conn.Close() + + var busyTimeout int + if err := conn.QueryRowContext(context.Background(), "PRAGMA busy_timeout").Scan(&busyTimeout); err != nil { + t.Fatalf("query busy_timeout: %v", err) + } + if busyTimeout != 5000 { + t.Errorf("recycled conn: busy_timeout = %d, want 5000 (hook did not reapply)", busyTimeout) + } + + var journalMode string + if err := conn.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&journalMode); err != nil { + t.Fatalf("query journal_mode: %v", err) + } + if journalMode != "wal" { + t.Errorf("recycled conn: journal_mode = %q, want \"wal\" (hook did not reapply)", journalMode) + } + + var foreignKeys int + if err := conn.QueryRowContext(context.Background(), "PRAGMA foreign_keys").Scan(&foreignKeys); err != nil { + t.Fatalf("query foreign_keys: %v", err) + } + if foreignKeys != 1 { + t.Errorf("recycled conn: foreign_keys = %d, want 1 (hook did not reapply)", foreignKeys) + } +} diff --git a/internal/maintenance/service.go b/internal/maintenance/service.go index 
9be756d..74da8d0 100644
--- a/internal/maintenance/service.go
+++ b/internal/maintenance/service.go
@@ -83,8 +83,20 @@ func (s *Service) Start(ctx context.Context) {
 		Bool("cleanup_stale", s.config.CleanupStaleObservations).
 		Msg("Starting maintenance scheduler")
 
-	// Initial run after 5 minutes (allow system to stabilize)
-	time.Sleep(5 * time.Minute)
+	// Initial run after 5 minutes (allow system to stabilize).
+	// Wait on a timer inside select so ctx cancellation or a stop signal ends the wait at once.
+	initialDelay := time.NewTimer(5 * time.Minute)
+	select {
+	case <-ctx.Done():
+		initialDelay.Stop()
+		s.log.Info().Msg("Maintenance shutting down before initial run (context cancellation)")
+		return
+	case <-s.stopCh:
+		initialDelay.Stop()
+		s.log.Info().Msg("Maintenance shutting down before initial run (stop signal)")
+		return
+	case <-initialDelay.C:
+	}
 	s.runMaintenance(ctx)
 
 	ticker := time.NewTicker(interval)
@@ -288,7 +300,13 @@ func (s *Service) Stats() map[string]any {
 	}
 }
 
-// RunNow triggers an immediate maintenance run.
+// RunNow starts a maintenance run in a new goroutine and returns without waiting for it.
 func (s *Service) RunNow(ctx context.Context) {
 	go s.runMaintenance(ctx)
 }
+
+// RunNowSync triggers an immediate maintenance run and blocks until it completes.
+// Use this when the caller needs to report a synchronous result (e.g. an HTTP handler).
+func (s *Service) RunNowSync(ctx context.Context) {
+	s.runMaintenance(ctx)
+}
diff --git a/internal/mcp/server.go b/internal/mcp/server.go
index b7c105b..e26ba4d 100644
--- a/internal/mcp/server.go
+++ b/internal/mcp/server.go
@@ -534,7 +534,7 @@ func (s *Server) handleToolsList(req *Request) *Response {
 		},
 		{
 			Name:        "trigger_maintenance",
-			Description: "Trigger an immediate maintenance run (cleanup old observations, optimize database).",
+			Description: "Trigger an immediate database maintenance run: optimize/checkpoint the database, clean up old prompts, apply any configured observation retention/stale cleanup, and recalculate importance scores.",
 			InputSchema: map[string]any{
 				"type":       "object",
 				"properties": map[string]any{},
@@ -917,7 +917,7 @@ func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage
 			"project": s.project,
 		})
 	case "trigger_maintenance":
-		return s.proxyPostRaw(ctx, "/api/scoring/recalculate", nil)
+		return s.proxyPostRaw(ctx, "/api/maintenance/run", nil)
 	case "analyze_observation_importance":
 		return s.handleAnalyzeImportanceProxy(ctx, args)
 	case "analyze_search_patterns":
diff --git a/internal/worker/handlers_scoring.go b/internal/worker/handlers_scoring.go
index 0090056..259d8a5 100644
--- a/internal/worker/handlers_scoring.go
+++ b/internal/worker/handlers_scoring.go
@@ -314,6 +314,62 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
 	writeJSON(w, map[string]string{"status": "recalculation triggered"})
 }
 
+// handleRunMaintenance triggers an immediate, synchronous database maintenance run
+// (Optimize/TRUNCATE checkpoint + prompt cleanup + any enabled retention/stale cleanup)
+// and also kicks off an importance-score recalculation in the background so the behavior
+// of the previous trigger_maintenance tool is preserved (issue #49).
+//
+// It responds 503 when s.maintenanceSvc is nil. Otherwise the JSON response carries
+// "status" ("maintenance completed", or "maintenance timed out" when the 60s budget
+// expired before the run returned), "recalc_triggered" and "maintenance_stats".
+func (s *Service) handleRunMaintenance(w http.ResponseWriter, r *http.Request) {
+	// initMu.RLock held by requireReady middleware
+	maintSvc := s.maintenanceSvc
+	recalculator := s.recalculator
+
+	if maintSvc == nil {
+		http.Error(w, "maintenance service not available", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Run maintenance synchronously with an independent, bounded context so the caller
+	// receives a real completion status. Use context.Background so an HTTP client timeout
+	// does not abort an in-progress DB maintenance pass.
+	mctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+	maintSvc.RunNowSync(mctx)
+
+	// RunNowSync returns nothing, so an expired budget is the only failure visible here.
+	// Report it rather than claiming a completion that may not have happened.
+	status := "maintenance completed"
+	if mctx.Err() != nil {
+		status = "maintenance timed out"
+	}
+
+	// Preserve prior trigger_maintenance behavior: also recalculate importance scores.
+	recalcTriggered := false
+	if recalculator != nil {
+		recalcTriggered = true
+		s.wg.Add(1)
+		go func() {
+			defer s.wg.Done()
+			// Derive from s.ctx so cancelling the service context stops the recalculation
+			// instead of leaving s.wg waiting out the full timeout.
+			ctx, cancel := context.WithTimeout(s.ctx, 60*time.Second)
+			defer cancel()
+			if err := recalculator.RecalculateNow(ctx); err != nil {
+				log.Error().Err(err).Msg("Background recalculation during maintenance failed")
+			}
+		}()
+	}
+
+	writeJSON(w, map[string]any{
+		"status":            status,
+		"recalc_triggered":  recalcTriggered,
+		"maintenance_stats": maintSvc.Stats(),
+	})
+}
+
 // parseIntParam parses an integer query parameter with a default value.
func parseIntParam(r *http.Request, name string, defaultVal int) int { if val := r.URL.Query().Get(name); val != "" { diff --git a/internal/worker/service.go b/internal/worker/service.go index c854d25..875b0d7 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -16,6 +16,7 @@ import ( "github.com/lukaszraczylo/claude-mnemonic/internal/config" "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" + "github.com/lukaszraczylo/claude-mnemonic/internal/maintenance" "github.com/lukaszraczylo/claude-mnemonic/internal/pattern" "github.com/lukaszraczylo/claude-mnemonic/internal/reranking" "github.com/lukaszraczylo/claude-mnemonic/internal/scoring" @@ -44,6 +45,14 @@ const ( // QueueProcessInterval is how often the background queue processor runs. QueueProcessInterval = 2 * time.Second + // WALCheckpointInterval is how often the worker checks whether the SQLite WAL needs a + // TRUNCATE checkpoint to reclaim disk and prevent unbounded growth (issue #49). + WALCheckpointInterval = 60 * time.Second + + // WALCheckpointThreshold is the WAL file size at or above which the periodic check + // performs a TRUNCATE checkpoint. Keeps the steady-state WAL bounded to a few MB. + WALCheckpointThreshold = 4 << 20 // 4 MiB + // reinitializationDrainDelay is the delay after marking the service as not ready // to allow in-flight requests to complete before reinitializing. reinitializationDrainDelay = 200 * time.Millisecond @@ -121,6 +130,7 @@ type Service struct { patternStore *gorm.PatternStore relationStore *gorm.RelationStore patternDetector *pattern.Detector + maintenanceSvc *maintenance.Service sessionManager *session.Manager sseBroadcaster *sse.Broadcaster processor *sdk.Processor @@ -570,6 +580,34 @@ func (s *Service) initializeAsync() { go s.processQueue() } + // Start periodic WAL checkpoint loop to bound SQLite WAL file growth (issue #49). 
+ s.wg.Add(1) + go s.walCheckpointLoop() + + // Start the scheduled maintenance service (issue #49: was dead code, never instantiated). + // vectorCleanupFn mirrors the observation store's cleanup hook so age/stale deletions done + // directly via GORM still remove their vectors from sqlite-vec. + var vectorCleanupFn func(ctx context.Context, deletedIDs []int64) + if vectorSync != nil { + vectorCleanupFn = func(ctx context.Context, deletedIDs []int64) { + if err := retryWithBackoff(ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.DeleteObservations(ctx, deletedIDs) + }); err != nil { + log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec during maintenance") + } + } + } + maintSvc := maintenance.NewService(store, observationStore, summaryStore, promptStore, vectorCleanupFn, s.config, log.Logger) + s.initMu.Lock() + s.maintenanceSvc = maintSvc + s.initMu.Unlock() + s.wg.Add(1) + go func() { + defer s.wg.Done() + maintSvc.Start(s.ctx) + }() + log.Info().Msg("Maintenance scheduler started") + // Start file watchers for auto-recreation on deletion s.startWatchers() @@ -1290,6 +1328,9 @@ func (s *Service) setupRoutes() { r.Put("/api/scoring/concepts/{concept}", s.handleUpdateConceptWeight) r.Post("/api/scoring/recalculate", s.handleTriggerRecalculation) + // Maintenance: run an immediate synchronous DB maintenance pass (issue #49) + r.Post("/api/maintenance/run", s.handleRunMaintenance) + // Context injection r.Get("/api/context/count", s.handleContextCount) r.Get("/api/context/inject", s.handleContextInject) @@ -1621,6 +1662,52 @@ func (s *Service) processQueue() { } } +// walCheckpointLoop periodically checkpoints the SQLite WAL so it cannot grow unbounded +// during long-lived sessions. SQLite's internal auto-checkpoint is PASSIVE and never +// shrinks the -wal file; under sustained writes with overlapping readers it can leave the +// WAL growing without limit (issue #49). 
This loop performs a TRUNCATE checkpoint whenever
+// the WAL has grown to WALCheckpointThreshold, and does nothing while it is small.
+func (s *Service) walCheckpointLoop() {
+	defer s.wg.Done() // pairs with the s.wg.Add(1) done by whoever starts this goroutine
+
+	// Tunable via config; fall back to the package constants when unset/<=0 (issue #49).
+	interval := WALCheckpointInterval
+	if s.config != nil && s.config.WALCheckpointIntervalSeconds > 0 {
+		interval = time.Duration(s.config.WALCheckpointIntervalSeconds) * time.Second
+	}
+	threshold := int64(WALCheckpointThreshold)
+	if s.config != nil && s.config.WALCheckpointThresholdBytes > 0 {
+		threshold = s.config.WALCheckpointThresholdBytes
+	}
+
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+		case <-ticker.C:
+			s.initMu.RLock() // read s.store under the read lock; the checkpoint below runs unlocked
+			store := s.store
+			s.initMu.RUnlock()
+			if store == nil {
+				continue // no store on this tick; check again on the next one
+			}
+
+			ctx, cancel := context.WithTimeout(s.ctx, 15*time.Second)
+			done, err := store.CheckpointIfLarge(ctx, threshold)
+			cancel() // called directly: a defer would not run until the loop exits
+			switch {
+			case err != nil:
+				log.Warn().Err(err).Msg("Periodic WAL checkpoint failed (non-fatal)")
+			case done:
+				log.Debug().Msg("Periodic WAL checkpoint (TRUNCATE) completed")
+			}
+		}
+	}
+}
+
 // processAllSessions processes pending messages for all active sessions.
 // Messages are processed in parallel using goroutines, with concurrency
 // limited by a channel-based semaphore.
@@ -1748,6 +1835,9 @@ func (s *Service) Shutdown(ctx context.Context) error {
 	if s.patternDetector != nil {
 		s.patternDetector.Stop()
 	}
+	if s.maintenanceSvc != nil {
+		s.maintenanceSvc.Stop()
+	}
 
 	// Phase 4: Shutdown sessions (flush pending work)
 	log.Debug().Msg("Phase 4: Shutting down sessions...")
diff --git a/pkg/hooks/worker.go b/pkg/hooks/worker.go
index 0d92147..374a188 100644
--- a/pkg/hooks/worker.go
+++ b/pkg/hooks/worker.go
@@ -564,6 +564,53 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{
 	return nil
 }
 
+// POSTWithContextResult sends a POST request using the provided context and
+// decodes the JSON response body, mirroring POST but honoring ctx for
+// cancellation/deadline. Used on the prompt critical path so a wedged worker
+// aborts at the hook deadline instead of blocking for the full client timeout.
+//
+// It returns an error when body cannot be marshaled to JSON, the request
+// cannot be built or sent, the response status is >= 400, or ctx is done by
+// the time the response body fails to decode. Any other undecodable body is
+// returned as (nil, nil), matching POST's behavior.
+func POSTWithContextResult(ctx context.Context, port int, path string, body interface{}) (map[string]interface{}, error) {
+	jsonBody, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+		fmt.Sprintf("http://127.0.0.1:%d%s", port, path),
+		bytes.NewReader(jsonBody))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := hookClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	if resp.StatusCode >= 400 {
+		return nil, fmt.Errorf("request failed: %s", resp.Status)
+	}
+
+	var result map[string]interface{}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		// A cancelled or expired ctx aborts the body read mid-stream. Surface that
+		// instead of reporting it as a successful non-JSON response.
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return nil, ctxErr
+		}
+		// Not all endpoints return JSON
+		return nil, nil
+	}
+
+	return result, nil
+}
+
 // GET sends a GET request to the worker.
func GET(port int, path string) (map[string]interface{}, error) { resp, err := hookClient.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path)) @@ -584,6 +622,35 @@ func GET(port int, path string) (map[string]interface{}, error) { return result, nil } +// GETWithContext sends a GET request using the provided context and decodes the +// JSON response body, mirroring GET but honoring ctx for cancellation/deadline. +// Used on the prompt critical path so a wedged worker aborts at the hook +// deadline instead of blocking for the full client timeout. +func GETWithContext(ctx context.Context, port int, path string) (map[string]interface{}, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + fmt.Sprintf("http://127.0.0.1:%d%s", port, path), nil) + if err != nil { + return nil, err + } + + resp, err := hookClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("request failed: %s", resp.Status) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + return result, nil +} + // versionsCompatible checks if two versions are compatible for dev builds. // Returns true if both versions share the same base version (ignoring -dirty, -dev, commit suffixes). // This prevents unnecessary restarts during development. diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index 8adee07..0b6065c 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -952,6 +952,203 @@ func TestGET_Timeout(t *testing.T) { require.Error(t, err) } +// TestGETWithContext tests GETWithContext with a mock server. 
+func TestGETWithContext(t *testing.T) {
+	tests := []struct {
+		serverHandler  func(w http.ResponseWriter, r *http.Request)
+		expectedResult map[string]interface{}
+		name           string
+		expectError    bool
+	}{
+		{
+			name: "successful GET with JSON response",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				assert.Equal(t, http.MethodGet, r.Method)
+				w.WriteHeader(http.StatusOK)
+				_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"})
+			},
+			expectError:    false,
+			expectedResult: map[string]interface{}{"data": "test"},
+		},
+		{
+			name: "GET with 404 error",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusNotFound)
+			},
+			expectError: true,
+		},
+		{
+			name: "GET with invalid JSON",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte("not valid json"))
+			},
+			expectError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
+			defer server.Close()
+
+			var port int
+			_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+			require.NoError(t, err)
+
+			result, err := GETWithContext(context.Background(), port, "/test")
+
+			if tt.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				if tt.expectedResult != nil {
+					assert.Equal(t, tt.expectedResult["data"], result["data"])
+				}
+			}
+		})
+	}
+}
+
+// TestGETWithContext_Timeout verifies the context deadline aborts a slow server
+// well before the hookClient timeout, so a wedged worker cannot stall the prompt.
+func TestGETWithContext_Timeout(t *testing.T) {
+	// Server that blocks longer than the context deadline.
+	blockUntil := make(chan struct{})
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		<-blockUntil // closed only by the deferred close below -> handler hangs for the whole test
+	}))
+	defer server.Close()
+	defer close(blockUntil) // deferred last so it runs first, releasing the handler before server.Close
+
+	var port int
+	_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	_, err = GETWithContext(ctx, port, "/test")
+	elapsed := time.Since(start)
+
+	require.ErrorIs(t, err, context.DeadlineExceeded) // the deadline, not some other failure, must end the call
+	// Should abort near the 100ms deadline, far below hookClient's 10s timeout.
+	assert.Less(t, elapsed, 2*time.Second, "context deadline must abort the request quickly")
+}
+
+// TestGETWithContext_CancelledContext verifies an already-cancelled context
+// returns immediately without making a real request.
+func TestGETWithContext_CancelledContext(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel immediately
+
+	start := time.Now()
+	_, err := GETWithContext(ctx, 99994, "/test")
+	elapsed := time.Since(start)
+
+	require.ErrorIs(t, err, context.Canceled) // proves the cancellation, not the bogus port, failed the call
+	assert.Less(t, elapsed, 1*time.Second, "cancelled context should return immediately")
+}
+
+// TestPOSTWithContextResult tests POSTWithContextResult with a mock server.
+func TestPOSTWithContextResult(t *testing.T) {
+	tests := []struct {
+		body           interface{}
+		serverHandler  func(w http.ResponseWriter, r *http.Request)
+		expectedResult map[string]interface{}
+		name           string
+		expectError    bool
+	}{
+		{
+			name: "successful POST with JSON response",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				assert.Equal(t, http.MethodPost, r.Method)
+				assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
+				w.WriteHeader(http.StatusOK)
+				_ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
+			},
+			body:           map[string]string{"key": "value"},
+			expectError:    false,
+			expectedResult: map[string]interface{}{"status": "ok"},
+		},
+		{
+			name: "POST with 400 error",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusBadRequest)
+			},
+			body:        map[string]string{"key": "value"},
+			expectError: true,
+		},
+		{
+			name: "POST with non-JSON response returns nil",
+			serverHandler: func(w http.ResponseWriter, r *http.Request) {
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte("not json"))
+			},
+			body:           map[string]string{"key": "value"},
+			expectError:    false,
+			expectedResult: nil, // a non-JSON 200 body must come back as (nil, nil)
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
+			defer server.Close()
+
+			var port int
+			_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+			require.NoError(t, err)
+
+			result, err := POSTWithContextResult(context.Background(), port, "/test", tt.body)
+
+			if tt.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				if tt.expectedResult != nil {
+					assert.Equal(t, tt.expectedResult["status"], result["status"])
+				} else {
+					assert.Nil(t, result)
+				}
+			}
+		})
+	}
+}
+
+// TestPOSTWithContextResult_MarshalError tests POSTWithContextResult with a body that cannot be JSON-encoded.
+func TestPOSTWithContextResult_MarshalError(t *testing.T) {
+	badValue := make(chan int) // encoding/json cannot marshal channels
+	_, err := POSTWithContextResult(context.Background(), 99999, "/test", badValue)
+	require.ErrorAs(t, err, new(*json.UnsupportedTypeError)) // must fail in Marshal, not on the bogus port
+}
+
+// TestPOSTWithContextResult_Timeout verifies the context deadline aborts a slow
+// server before the hookClient timeout.
+func TestPOSTWithContextResult_Timeout(t *testing.T) {
+	blockUntil := make(chan struct{})
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		<-blockUntil
+	}))
+	defer server.Close()
+	defer close(blockUntil) // deferred last so it runs first, releasing the handler before server.Close
+
+	var port int
+	_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	_, err = POSTWithContextResult(ctx, port, "/test", map[string]string{"k": "v"})
+	elapsed := time.Since(start)
+
+	require.ErrorIs(t, err, context.DeadlineExceeded) // the deadline, not some other failure, must end the call
+	assert.Less(t, elapsed, 2*time.Second, "context deadline must abort the request quickly")
+}
+
 // TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout.
 func TestIsWorkerRunning_Timeout(t *testing.T) {
 	// Non-existent port should quickly return false