mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
general improvements (#17)
* refactor(hooks): simplify hook execution with shared context - [x] Extract BaseInput struct to eliminate duplicate fields across hooks - [x] Create RunHook handler pattern for session-start and user-prompt - [x] Create RunStatuslineHook for fast statusline rendering without worker startup - [x] Add HookContext struct to pass port, project, CWD, SessionID to handlers - [x] Add db/interface.go with ObservationReader/Writer interfaces - [x] Add comprehensive conflict management tests in sqlite/conflict_test.go - [x] Add vector client tests for Count, ModelVersion, NeedsRebuild, GetStaleVectors - [x] Add FilterByThreshold helper tests for query result filtering - [x] Make handlers_test more robust for network-dependent update checks - [x] Update package versions in UI * Move to GORM + general cleanup * feat(mcp): add observation relations discovery and scoring integration - [x] Add find_related_observations MCP tool for discovering related observations by confidence - [x] Integrate scoring calculator and recalculator into MCP server initialization - [x] Add pattern, relation, and session stores to MCP server dependencies - [x] Register MCP server in Claude Code settings during plugin installation - [x] Update install scripts (bash, PowerShell) to configure MCP server settings - [x] Switch plugin manifest files to template-based versioning (plugin.json.tpl, marketplace.json.tpl) - [x] Update all MCP server tests to pass new dependency parameters
This commit is contained in:
@@ -83,3 +83,10 @@ logs/
|
||||
dist/
|
||||
docs/dist
|
||||
.claude-plugin
|
||||
|
||||
# Auto-generated plugin configs (generated by scripts/generate-plugin-config.sh)
|
||||
.claude-plugin/
|
||||
|
||||
# Non-template plugin configs (keep only .tpl files)
|
||||
plugin/.claude-plugin/plugin.json
|
||||
plugin/.claude-plugin/marketplace.json
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Project-specific golangci-lint configuration for claude-mnemonic
|
||||
# Inherits from global ~/.golangci.yml and adds project-specific exclusions
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
# Project-specific: Exclude unused warnings for public API functions in pkg/models
|
||||
# These detection functions are part of the public API
|
||||
- path: pkg/models/(conflict|relation)\.go
|
||||
linters:
|
||||
- unused
|
||||
text: "(Detect|New)"
|
||||
|
||||
# Project-specific: Test helper method used only in tests
|
||||
- path: internal/db/gorm/store\.go
|
||||
linters:
|
||||
- unused
|
||||
text: "GetDB"
|
||||
|
||||
exclude-dirs:
|
||||
- vendor
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
@@ -146,8 +146,8 @@ install: build stop-worker
|
||||
@# Copy slash commands if they exist
|
||||
@if [ -d "$(PLUGIN_DIR)/commands" ]; then cp -r $(PLUGIN_DIR)/commands/* $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/commands/ 2>/dev/null || true; fi
|
||||
@# Update plugin.json and marketplace.json with current version to prevent stale version directories
|
||||
@sed 's/"version": "[^"]*"/"version": "$(VERSION)"/g' $(PLUGIN_DIR)/.claude-plugin/plugin.json > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/plugin.json
|
||||
@sed 's/"version": "[^"]*"/"version": "$(VERSION)"/g' $(PLUGIN_DIR)/.claude-plugin/marketplace.json > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/marketplace.json
|
||||
@sed 's/{{ .Version }}/$(VERSION)/g; s/{{.Version}}/$(VERSION)/g' $(PLUGIN_DIR)/.claude-plugin/plugin.json.tpl > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/plugin.json
|
||||
@sed 's/{{ .Version }}/$(VERSION)/g; s/{{.Version}}/$(VERSION)/g' $(PLUGIN_DIR)/.claude-plugin/marketplace.json.tpl > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/marketplace.json
|
||||
@echo "Registering plugin with Claude Code..."
|
||||
@./scripts/register-plugin.sh "$(VERSION)"
|
||||
@$(MAKE) start-worker
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -14,11 +12,8 @@ import (
|
||||
|
||||
// Input is the hook input from Claude Code.
|
||||
type Input struct {
|
||||
SessionID string `json:"session_id"`
|
||||
CWD string `json:"cwd"`
|
||||
PermissionMode string `json:"permission_mode"`
|
||||
HookEventName string `json:"hook_event_name"`
|
||||
Source string `json:"source"` // "startup", "resume", "clear", "compact"
|
||||
hooks.BaseInput
|
||||
Source string `json:"source"` // "startup", "resume", "clear", "compact"
|
||||
}
|
||||
|
||||
// Observation represents an observation from the API.
|
||||
@@ -32,53 +27,26 @@ type Observation struct {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Skip if this is an internal call (from SDK processor)
|
||||
if os.Getenv("CLAUDE_MNEMONIC_INTERNAL") == "1" {
|
||||
hooks.WriteResponse("SessionStart", true)
|
||||
return
|
||||
}
|
||||
|
||||
// Read input from stdin
|
||||
inputData, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
hooks.WriteError("SessionStart", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var input Input
|
||||
if err := json.Unmarshal(inputData, &input); err != nil {
|
||||
hooks.WriteError("SessionStart", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Ensure worker is running
|
||||
port, err := hooks.EnsureWorkerRunning()
|
||||
if err != nil {
|
||||
hooks.WriteError("SessionStart", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Generate unique project ID from CWD (dirname_hash format)
|
||||
project := hooks.ProjectIDWithName(input.CWD)
|
||||
hooks.RunHook("SessionStart", handleSessionStart)
|
||||
}
|
||||
|
||||
func handleSessionStart(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
// Fetch observations for context injection
|
||||
endpoint := fmt.Sprintf("/api/context/inject?project=%s&cwd=%s",
|
||||
url.QueryEscape(project),
|
||||
url.QueryEscape(input.CWD))
|
||||
url.QueryEscape(ctx.Project),
|
||||
url.QueryEscape(ctx.CWD))
|
||||
|
||||
result, err := hooks.GET(port, endpoint)
|
||||
result, err := hooks.GET(ctx.Port, endpoint)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: context fetch failed: %v\n", err)
|
||||
hooks.WriteResponse("SessionStart", true)
|
||||
return
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse observations from response
|
||||
obsData, ok := result["observations"].([]interface{})
|
||||
if !ok || len(obsData) == 0 {
|
||||
// No observations - just continue normally
|
||||
hooks.WriteResponse("SessionStart", true)
|
||||
return
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Get full_count from response (how many observations get full detail)
|
||||
@@ -136,17 +104,7 @@ func main() {
|
||||
}
|
||||
|
||||
contextBuilder += "</claude-mnemonic-context>\n"
|
||||
|
||||
// Output context as JSON with additionalContext field
|
||||
response := map[string]interface{}{
|
||||
"continue": true,
|
||||
"hookSpecificOutput": map[string]interface{}{
|
||||
"hookEventName": "SessionStart",
|
||||
"additionalContext": contextBuilder,
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(os.Stdout).Encode(response)
|
||||
os.Exit(0)
|
||||
return contextBuilder, nil
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) string {
|
||||
|
||||
@@ -5,7 +5,6 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -72,18 +71,13 @@ const (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Read input from stdin
|
||||
inputData, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
// On error, output minimal status
|
||||
fmt.Println(formatOffline())
|
||||
return
|
||||
}
|
||||
hooks.RunStatuslineHook(handleStatusline)
|
||||
}
|
||||
|
||||
var input StatusInput
|
||||
if err := json.Unmarshal(inputData, &input); err != nil {
|
||||
fmt.Println(formatOffline())
|
||||
return
|
||||
func handleStatusline(input *StatusInput, port int) string {
|
||||
// Handle error cases (nil input)
|
||||
if input == nil {
|
||||
return formatOffline()
|
||||
}
|
||||
|
||||
// Determine project directory
|
||||
@@ -102,16 +96,14 @@ func main() {
|
||||
}
|
||||
|
||||
// Get worker stats
|
||||
stats := getWorkerStats(project)
|
||||
stats := getWorkerStats(port, project)
|
||||
|
||||
// Format and output statusline
|
||||
fmt.Println(formatStatusLine(stats, input))
|
||||
// Format and return statusline
|
||||
return formatStatusLine(stats, *input)
|
||||
}
|
||||
|
||||
// getWorkerStats fetches stats from the worker service.
|
||||
func getWorkerStats(project string) *WorkerStats {
|
||||
port := hooks.GetWorkerPort()
|
||||
|
||||
func getWorkerStats(port int, project string) *WorkerStats {
|
||||
// Build URL with optional project parameter
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/api/stats", port)
|
||||
if project != "" {
|
||||
@@ -187,54 +179,45 @@ func formatDefault(stats *WorkerStats, useColors bool) string {
|
||||
}
|
||||
|
||||
// Build status parts with clear labels
|
||||
parts := []string{}
|
||||
parts := []string{
|
||||
prefix,
|
||||
indicator,
|
||||
}
|
||||
|
||||
// Total memories served to Claude this session
|
||||
parts = append(parts, fmt.Sprintf("served:%d", stats.Retrieval.ObservationsServed))
|
||||
|
||||
// Context injections (memories auto-loaded at session start)
|
||||
// Add retrieval stats if available
|
||||
if stats.Retrieval.ObservationsServed > 0 {
|
||||
parts = append(parts, fmt.Sprintf("served:%d", stats.Retrieval.ObservationsServed))
|
||||
}
|
||||
if stats.Retrieval.ContextInjections > 0 {
|
||||
parts = append(parts, fmt.Sprintf("injected:%d", stats.Retrieval.ContextInjections))
|
||||
}
|
||||
|
||||
// Semantic searches performed
|
||||
if stats.Retrieval.SearchRequests > 0 {
|
||||
parts = append(parts, fmt.Sprintf("searches:%d", stats.Retrieval.SearchRequests))
|
||||
}
|
||||
|
||||
// Project-specific memory count
|
||||
// Add project-specific observation count if available
|
||||
if stats.ProjectObservations > 0 {
|
||||
if useColors {
|
||||
parts = append(parts, fmt.Sprintf("%sproject:%d memories%s", colorYellow, stats.ProjectObservations, reset))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("project:%d memories", stats.ProjectObservations))
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("project:%d memories", stats.ProjectObservations))
|
||||
}
|
||||
|
||||
// Processing indicator
|
||||
if stats.IsProcessing || stats.QueueDepth > 0 {
|
||||
if useColors {
|
||||
parts = append(parts, colorYellow+"processing..."+colorReset)
|
||||
} else {
|
||||
parts = append(parts, "processing...")
|
||||
}
|
||||
}
|
||||
|
||||
result := prefix + " " + indicator
|
||||
for i, part := range parts {
|
||||
if i == 0 {
|
||||
result += " " + part
|
||||
} else {
|
||||
result += " | " + part
|
||||
// Join with separators
|
||||
result := parts[0] + " " + parts[1]
|
||||
if len(parts) > 2 {
|
||||
for i := 2; i < len(parts); i++ {
|
||||
if useColors {
|
||||
result += colorGray + " | " + reset + parts[i]
|
||||
} else {
|
||||
result += " | " + parts[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// formatCompact returns a compact status line.
|
||||
// formatCompact returns a compact status line format.
|
||||
func formatCompact(stats *WorkerStats, useColors bool) string {
|
||||
// [m] ● 42/5/3 (28)
|
||||
// [m] ● 42/5/3
|
||||
var prefix, indicator string
|
||||
if useColors {
|
||||
prefix = colorCyan + "[m]" + colorReset
|
||||
@@ -244,31 +227,16 @@ func formatCompact(stats *WorkerStats, useColors bool) string {
|
||||
indicator = "●"
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("%s %s %d/%d/%d",
|
||||
return fmt.Sprintf("%s %s %d/%d/%d",
|
||||
prefix, indicator,
|
||||
stats.Retrieval.ObservationsServed,
|
||||
stats.Retrieval.ContextInjections,
|
||||
stats.Retrieval.SearchRequests,
|
||||
)
|
||||
|
||||
if stats.ProjectObservations > 0 {
|
||||
result += fmt.Sprintf(" (%d)", stats.ProjectObservations)
|
||||
}
|
||||
|
||||
if stats.IsProcessing || stats.QueueDepth > 0 {
|
||||
if useColors {
|
||||
result += " " + colorYellow + "⚙" + colorReset
|
||||
} else {
|
||||
result += " ⚙"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
stats.Retrieval.SearchRequests)
|
||||
}
|
||||
|
||||
// formatMinimal returns a minimal status line.
|
||||
// formatMinimal returns a minimal status line format.
|
||||
func formatMinimal(stats *WorkerStats, useColors bool) string {
|
||||
// ● 42 obs
|
||||
// ● 28 memories
|
||||
var indicator string
|
||||
if useColors {
|
||||
indicator = colorGreen + "●" + colorReset
|
||||
@@ -276,32 +244,31 @@ func formatMinimal(stats *WorkerStats, useColors bool) string {
|
||||
indicator = "●"
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("%s %d", indicator, stats.Retrieval.ObservationsServed)
|
||||
|
||||
if stats.ProjectObservations > 0 {
|
||||
result += fmt.Sprintf("/%d", stats.ProjectObservations)
|
||||
return fmt.Sprintf("%s %d memories", indicator, stats.ProjectObservations)
|
||||
}
|
||||
|
||||
return result
|
||||
return fmt.Sprintf("%s mnemonic ready", indicator)
|
||||
}
|
||||
|
||||
// formatOffline returns the offline status.
|
||||
// formatOffline returns status for when worker is offline.
|
||||
func formatOffline() string {
|
||||
return formatOfflineColored(true)
|
||||
useColors := os.Getenv("NO_COLOR") == "" && os.Getenv("TERM") != "dumb"
|
||||
return formatOfflineColored(useColors)
|
||||
}
|
||||
|
||||
// formatOfflineColored returns the offline status with optional colors.
|
||||
// formatOfflineColored returns colored offline status.
|
||||
func formatOfflineColored(useColors bool) string {
|
||||
if useColors {
|
||||
return colorCyan + "[mnemonic]" + colorReset + " " + colorGray + "○" + colorReset
|
||||
return colorGray + "[mnemonic]" + colorReset + " " + colorGray + "○" + colorReset + " offline"
|
||||
}
|
||||
return "[mnemonic] ○"
|
||||
return "[mnemonic] ○ offline"
|
||||
}
|
||||
|
||||
// formatStartingColored returns the starting status with optional colors.
|
||||
// formatStartingColored returns colored starting status.
|
||||
func formatStartingColored(useColors bool) string {
|
||||
if useColors {
|
||||
return colorCyan + "[mnemonic]" + colorReset + " " + colorYellow + "◐" + colorReset + " starting"
|
||||
return colorYellow + "[mnemonic]" + colorReset + " " + colorYellow + "○" + colorReset + " starting..."
|
||||
}
|
||||
return "[mnemonic] ◐ starting"
|
||||
return "[mnemonic] ○ starting..."
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
@@ -13,53 +11,25 @@ import (
|
||||
|
||||
// Input is the hook input from Claude Code.
|
||||
type Input struct {
|
||||
SessionID string `json:"session_id"`
|
||||
CWD string `json:"cwd"`
|
||||
PermissionMode string `json:"permission_mode"`
|
||||
HookEventName string `json:"hook_event_name"`
|
||||
Prompt string `json:"prompt"`
|
||||
hooks.BaseInput
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Skip if this is an internal call (from SDK processor)
|
||||
if os.Getenv("CLAUDE_MNEMONIC_INTERNAL") == "1" {
|
||||
hooks.WriteResponse("UserPromptSubmit", true)
|
||||
return
|
||||
}
|
||||
|
||||
// Read input from stdin
|
||||
inputData, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
hooks.WriteError("UserPromptSubmit", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var input Input
|
||||
if err := json.Unmarshal(inputData, &input); err != nil {
|
||||
hooks.WriteError("UserPromptSubmit", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Ensure worker is running
|
||||
port, err := hooks.EnsureWorkerRunning()
|
||||
if err != nil {
|
||||
hooks.WriteError("UserPromptSubmit", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Generate unique project ID from CWD
|
||||
project := hooks.ProjectIDWithName(input.CWD)
|
||||
hooks.RunHook("UserPromptSubmit", handleUserPrompt)
|
||||
}
|
||||
|
||||
func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
|
||||
// Search for relevant observations based on the prompt
|
||||
searchURL := fmt.Sprintf("/api/context/search?project=%s&query=%s&cwd=%s",
|
||||
url.QueryEscape(project),
|
||||
url.QueryEscape(ctx.Project),
|
||||
url.QueryEscape(input.Prompt),
|
||||
url.QueryEscape(input.CWD))
|
||||
url.QueryEscape(ctx.CWD))
|
||||
|
||||
var contextToInject string
|
||||
var observationCount int
|
||||
|
||||
searchResult, _ := hooks.GET(port, searchURL)
|
||||
searchResult, _ := hooks.GET(ctx.Port, searchURL)
|
||||
if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
|
||||
// Results are already filtered by relevance threshold and capped by max_results
|
||||
// from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults)
|
||||
@@ -104,27 +74,24 @@ func main() {
|
||||
}
|
||||
|
||||
contextBuilder += "</relevant-memory>\n"
|
||||
|
||||
contextToInject = contextBuilder
|
||||
}
|
||||
|
||||
// Initialize session with matched observations count
|
||||
result, err := hooks.POST(port, "/api/sessions/init", map[string]interface{}{
|
||||
"claudeSessionId": input.SessionID,
|
||||
"project": project,
|
||||
result, err := hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{
|
||||
"claudeSessionId": ctx.SessionID,
|
||||
"project": ctx.Project,
|
||||
"prompt": input.Prompt,
|
||||
"matchedObservations": observationCount,
|
||||
})
|
||||
if err != nil {
|
||||
hooks.WriteError("UserPromptSubmit", err)
|
||||
os.Exit(1)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if skipped due to privacy
|
||||
if skipped, ok := result["skipped"].(bool); ok && skipped {
|
||||
fmt.Fprintf(os.Stderr, "[user-prompt] Session skipped (private)\n")
|
||||
hooks.WriteResponse("UserPromptSubmit", true)
|
||||
return
|
||||
return "", nil
|
||||
}
|
||||
|
||||
sessionID := int64(result["sessionDbId"].(float64))
|
||||
@@ -133,30 +100,20 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber)
|
||||
|
||||
// Start SDK agent
|
||||
_, err = hooks.POST(port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
|
||||
_, err = hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
|
||||
"userPrompt": input.Prompt,
|
||||
"promptNumber": promptNumber,
|
||||
})
|
||||
if err != nil {
|
||||
hooks.WriteError("UserPromptSubmit", err)
|
||||
os.Exit(1)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Output results - stdout with exit 0 adds context to Claude's prompt
|
||||
// Return context if we found relevant observations
|
||||
if observationCount > 0 {
|
||||
// Show match count to user via stderr
|
||||
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Found %d relevant memories for this prompt\n", observationCount)
|
||||
// Output context as JSON with additionalContext field
|
||||
response := map[string]interface{}{
|
||||
"continue": true,
|
||||
"hookSpecificOutput": map[string]interface{}{
|
||||
"hookEventName": "UserPromptSubmit",
|
||||
"additionalContext": contextToInject,
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(os.Stdout).Encode(response)
|
||||
os.Exit(0)
|
||||
} else {
|
||||
hooks.WriteResponse("UserPromptSubmit", true)
|
||||
return contextToInject, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
+34
-12
@@ -10,12 +10,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/mcp"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -71,22 +73,25 @@ func main() {
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Initialize SQLite store (migrations run automatically)
|
||||
storeCfg := sqlite.StoreConfig{
|
||||
// Initialize database store (migrations run automatically)
|
||||
storeCfg := gorm.Config{
|
||||
Path: dbPath,
|
||||
MaxConns: cfg.MaxConns,
|
||||
WALMode: true,
|
||||
// WALMode is enabled automatically by GORM
|
||||
}
|
||||
store, err := sqlite.NewStore(storeCfg)
|
||||
store, err := gorm.NewStore(storeCfg)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize SQLite store")
|
||||
log.Fatal().Err(err).Msg("Failed to initialize database store")
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Initialize stores
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
summaryStore := gorm.NewSummaryStore(store)
|
||||
promptStore := gorm.NewPromptStore(store, nil)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
relationStore := gorm.NewRelationStore(store)
|
||||
sessionStore := gorm.NewSessionStore(store)
|
||||
|
||||
// Initialize embedding service and vector client
|
||||
var vectorClient *sqlitevec.Client
|
||||
@@ -95,7 +100,7 @@ func main() {
|
||||
log.Warn().Err(err).Msg("Embedding service unavailable, vector search disabled")
|
||||
} else {
|
||||
defer embedSvc.Close()
|
||||
vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.DB()}, embedSvc)
|
||||
vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.GetRawDB()}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Vector client unavailable, vector search disabled")
|
||||
} else {
|
||||
@@ -103,14 +108,31 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize scoring components
|
||||
scoreConfig := models.DefaultScoringConfig()
|
||||
scoreCalculator := scoring.NewCalculator(scoreConfig)
|
||||
recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger)
|
||||
go recalculator.Start(ctx)
|
||||
defer recalculator.Stop()
|
||||
|
||||
// Initialize search manager
|
||||
searchMgr := search.NewManager(observationStore, summaryStore, promptStore, vectorClient)
|
||||
|
||||
// Start file watchers
|
||||
startWatchers(ctx, dbPath)
|
||||
|
||||
// Create and run MCP server
|
||||
server := mcp.NewServer(searchMgr, Version)
|
||||
// Create and run MCP server with all dependencies
|
||||
server := mcp.NewServer(
|
||||
searchMgr,
|
||||
Version,
|
||||
observationStore,
|
||||
patternStore,
|
||||
relationStore,
|
||||
sessionStore,
|
||||
vectorClient,
|
||||
scoreCalculator,
|
||||
recalculator,
|
||||
)
|
||||
log.Info().Str("project", *project).Str("version", Version).Msg("Starting MCP server")
|
||||
|
||||
if err := server.Run(ctx); err != nil {
|
||||
|
||||
@@ -14,11 +14,16 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/sugarme/tokenizer v0.3.0
|
||||
github.com/yalue/onnxruntime_go v1.25.0
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.1.5 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
|
||||
|
||||
@@ -12,9 +12,15 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8=
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
@@ -57,3 +63,9 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
|
||||
gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw=
|
||||
gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
|
||||
@@ -0,0 +1,331 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// setupBenchStore creates a temporary store for benchmarking.
|
||||
func setupBenchStore(b *testing.B) (*Store, func()) {
|
||||
b.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_bench_*")
|
||||
if err != nil {
|
||||
b.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
b.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// BenchmarkSessionStore_CreateSDKSession benchmarks session creation (most frequent operation).
|
||||
func BenchmarkSessionStore_CreateSDKSession(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sessionID := fmt.Sprintf("claude-bench-%d", i)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, sessionID, "bench-project", "test prompt")
|
||||
if err != nil {
|
||||
b.Fatalf("CreateSDKSession failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSessionStore_CreateSDKSession_Idempotent benchmarks idempotent session creation (INSERT OR IGNORE).
|
||||
func BenchmarkSessionStore_CreateSDKSession_Idempotent(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-create session
|
||||
sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "test prompt")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "updated prompt")
|
||||
if err != nil {
|
||||
b.Fatalf("CreateSDKSession failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkObservationStore_StoreObservation benchmarks observation storage (high frequency).
|
||||
func BenchmarkObservationStore_StoreObservation(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: fmt.Sprintf("Observation %d", i),
|
||||
Narrative: "Benchmark observation content",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs, int(sessionID), int64(i+1))
|
||||
if err != nil {
|
||||
b.Fatalf("StoreObservation failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkObservationStore_GetRecentObservations benchmarks recent observation retrieval.
|
||||
func BenchmarkObservationStore_GetRecentObservations(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session and observations
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
for i := 0; i < 100; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: fmt.Sprintf("Observation %d", i),
|
||||
Narrative: "Benchmark observation content",
|
||||
}
|
||||
obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs, int(sessionID), int64(i+1))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := obsStore.GetRecentObservations(ctx, "bench-project", 20)
|
||||
if err != nil {
|
||||
b.Fatalf("GetRecentObservations failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkObservationStore_SearchObservationsFTS benchmarks FTS5 search (latency-sensitive).
|
||||
func BenchmarkObservationStore_SearchObservationsFTS(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session and observations with searchable content
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
for i := 0; i < 100; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: fmt.Sprintf("Security best practice %d", i),
|
||||
Narrative: "This observation discusses security patterns and authentication mechanisms",
|
||||
}
|
||||
obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs, int(sessionID), int64(i+1))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := obsStore.SearchObservationsFTS(ctx, "security authentication", "bench-project", 10)
|
||||
if err != nil {
|
||||
b.Fatalf("SearchObservationsFTS failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkObservationStore_UpdateImportanceScore benchmarks scoring updates.
|
||||
func BenchmarkObservationStore_UpdateImportanceScore(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session and observation
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs, int(sessionID), 1)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
score := float64(i%10) + 1.0
|
||||
err := obsStore.UpdateImportanceScore(ctx, obsID, score)
|
||||
if err != nil {
|
||||
b.Fatalf("UpdateImportanceScore failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkObservationStore_UpdateImportanceScores_Bulk benchmarks bulk scoring updates.
|
||||
func BenchmarkObservationStore_UpdateImportanceScores_Bulk(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session and 100 observations
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
var obsIDs []int64
|
||||
for i := 0; i < 100; i++ {
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: fmt.Sprintf("Obs %d", i)}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs, int(sessionID), int64(i+1))
|
||||
obsIDs = append(obsIDs, obsID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
scores := make(map[int64]float64)
|
||||
for _, id := range obsIDs {
|
||||
scores[id] = float64(i%10) + 1.0
|
||||
}
|
||||
err := obsStore.UpdateImportanceScores(ctx, scores)
|
||||
if err != nil {
|
||||
b.Fatalf("UpdateImportanceScores failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPromptStore_SaveUserPromptWithMatches benchmarks prompt storage with matches.
|
||||
func BenchmarkPromptStore_SaveUserPromptWithMatches(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
promptStore := NewPromptStore(store, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-bench", int(sessionID), fmt.Sprintf("Prompt %d", i), i+1)
|
||||
if err != nil {
|
||||
b.Fatalf("SaveUserPromptWithMatches failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSummaryStore_StoreSummary benchmarks summary storage.
|
||||
func BenchmarkSummaryStore_StoreSummary(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
summaryStore := NewSummaryStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
summary := &models.ParsedSummary{
|
||||
Request: fmt.Sprintf("Request %d", i),
|
||||
Investigated: "Investigation details",
|
||||
Learned: "Learning summary",
|
||||
Completed: "Completion status",
|
||||
}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "claude-bench", "bench-project", summary, i+1, 100)
|
||||
if err != nil {
|
||||
b.Fatalf("StoreSummary failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRelationStore_StoreRelation benchmarks relation storage.
|
||||
func BenchmarkRelationStore_StoreRelation(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
relationStore := NewRelationStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session and observations
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-bench", "bench-project", "")
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Source"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs1, int(sessionID), 1)
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Target"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-bench", "bench-project", obs2, int(sessionID), 2)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
relation := &models.ObservationRelation{
|
||||
SourceID: obsID1,
|
||||
TargetID: obsID2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
}
|
||||
_, err := relationStore.StoreRelation(ctx, relation)
|
||||
if err != nil {
|
||||
b.Fatalf("StoreRelation failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPatternStore_StorePattern benchmarks pattern storage.
|
||||
func BenchmarkPatternStore_StorePattern(b *testing.B) {
|
||||
store, cleanup := setupBenchStore(b)
|
||||
defer cleanup()
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pattern := &models.Pattern{
|
||||
Name: fmt.Sprintf("Pattern %d", i),
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "Benchmark pattern", Valid: true},
|
||||
Frequency: 1,
|
||||
Confidence: 0.8,
|
||||
Projects: []string{"bench-project"},
|
||||
Status: models.PatternStatusActive,
|
||||
}
|
||||
_, err := patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
b.Fatalf("StorePattern failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SupersededRetentionDays is the number of days to keep superseded observations before deletion.
|
||||
const SupersededRetentionDays = 3
|
||||
|
||||
// ConflictStore provides conflict-related database operations using GORM.
|
||||
type ConflictStore struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewConflictStore creates a new conflict store.
|
||||
func NewConflictStore(store *Store) *ConflictStore {
|
||||
return &ConflictStore{
|
||||
db: store.DB,
|
||||
}
|
||||
}
|
||||
|
||||
// StoreConflict stores a new observation conflict.
|
||||
func (s *ConflictStore) StoreConflict(ctx context.Context, conflict *models.ObservationConflict) (int64, error) {
|
||||
dbConflict := &ObservationConflict{
|
||||
NewerObsID: conflict.NewerObsID,
|
||||
OlderObsID: conflict.OlderObsID,
|
||||
ConflictType: conflict.ConflictType,
|
||||
Resolution: conflict.Resolution,
|
||||
DetectedAt: conflict.DetectedAt,
|
||||
DetectedAtEpoch: conflict.DetectedAtEpoch,
|
||||
Resolved: 0,
|
||||
}
|
||||
|
||||
// Convert bool to int
|
||||
if conflict.Resolved {
|
||||
dbConflict.Resolved = 1
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if conflict.Reason != "" {
|
||||
dbConflict.Reason = sql.NullString{String: conflict.Reason, Valid: true}
|
||||
}
|
||||
if conflict.ResolvedAt != nil && *conflict.ResolvedAt != "" {
|
||||
dbConflict.ResolvedAt = sql.NullString{String: *conflict.ResolvedAt, Valid: true}
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Create(dbConflict)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
return dbConflict.ID, nil
|
||||
}
|
||||
|
||||
// MarkObservationSuperseded marks an observation as superseded.
|
||||
func (s *ConflictStore) MarkObservationSuperseded(ctx context.Context, obsID int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", obsID).
|
||||
Update("is_superseded", 1)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// MarkObservationsSuperseded marks multiple observations as superseded.
|
||||
func (s *ConflictStore) MarkObservationsSuperseded(ctx context.Context, obsIDs []int64) error {
|
||||
if len(obsIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id IN ?", obsIDs).
|
||||
Update("is_superseded", 1)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// GetConflictsByObservationID retrieves all conflicts involving an observation.
|
||||
func (s *ConflictStore) GetConflictsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationConflict, error) {
|
||||
var conflicts []ObservationConflict
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("newer_obs_id = ? OR older_obs_id = ?", obsID, obsID).
|
||||
Order("detected_at_epoch DESC").
|
||||
Find(&conflicts).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelConflicts(conflicts), nil
|
||||
}
|
||||
|
||||
// GetUnresolvedConflicts retrieves all unresolved conflicts.
|
||||
func (s *ConflictStore) GetUnresolvedConflicts(ctx context.Context, limit int) ([]*models.ObservationConflict, error) {
|
||||
var conflicts []ObservationConflict
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("resolved = 0").
|
||||
Order("detected_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&conflicts).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelConflicts(conflicts), nil
|
||||
}
|
||||
|
||||
// GetSupersededObservationIDs returns IDs of all observations that have been superseded.
|
||||
func (s *ConflictStore) GetSupersededObservationIDs(ctx context.Context, project string) ([]int64, error) {
|
||||
var ids []int64
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Table("observation_conflicts oc").
|
||||
Select("DISTINCT oc.older_obs_id").
|
||||
Joins("JOIN observations o ON o.id = oc.older_obs_id").
|
||||
Where("oc.resolution = ?", models.ResolutionPreferNewer).
|
||||
Where("o.project = ? OR o.scope = 'global'", project).
|
||||
Pluck("oc.older_obs_id", &ids).Error
|
||||
|
||||
return ids, err
|
||||
}
|
||||
|
||||
// ResolveConflict marks a conflict as resolved.
|
||||
func (s *ConflictStore) ResolveConflict(ctx context.Context, conflictID int64, resolution models.ConflictResolution) error {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&ObservationConflict{}).
|
||||
Where("id = ?", conflictID).
|
||||
Updates(map[string]interface{}{
|
||||
"resolved": 1,
|
||||
"resolved_at": now,
|
||||
"resolution": resolution,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// DeleteConflictsByObservationID deletes all conflicts involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *ConflictStore) DeleteConflictsByObservationID(ctx context.Context, obsID int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("newer_obs_id = ? OR older_obs_id = ?", obsID, obsID).
|
||||
Delete(&ObservationConflict{})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// ConflictWithDetails contains a conflict with its observation details.
|
||||
type ConflictWithDetails struct {
|
||||
Conflict *models.ObservationConflict
|
||||
NewerObsTitle string
|
||||
OlderObsTitle string
|
||||
}
|
||||
|
||||
// CleanupSupersededObservations deletes observations that have been superseded for longer than
|
||||
// SupersededRetentionDays. Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
// Calculate cutoff time (3 days ago in milliseconds)
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -SupersededRetentionDays).UnixMilli()
|
||||
|
||||
var toDelete []int64
|
||||
|
||||
// Use a transaction to prevent TOCTOU race condition
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Find IDs to delete
|
||||
err := tx.Table("observations o").
|
||||
Select("DISTINCT o.id").
|
||||
Joins("JOIN observation_conflicts oc ON o.id = oc.older_obs_id").
|
||||
Where("o.is_superseded = 1").
|
||||
Where("o.project = ?", project).
|
||||
Where("oc.detected_at_epoch < ?", cutoffEpoch).
|
||||
Pluck("o.id", &toDelete).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(toDelete) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the conflict records first (due to foreign key constraints)
|
||||
for _, obsID := range toDelete {
|
||||
err := tx.Where("newer_obs_id = ? OR older_obs_id = ?", obsID, obsID).
|
||||
Delete(&ObservationConflict{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the observations
|
||||
return tx.Delete(&Observation{}, toDelete).Error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDelete, nil
|
||||
}
|
||||
|
||||
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
|
||||
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
|
||||
var results []struct {
|
||||
ObservationConflict
|
||||
NewerTitle sql.NullString `gorm:"column:newer_title"`
|
||||
OlderTitle sql.NullString `gorm:"column:older_title"`
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Table("observation_conflicts oc").
|
||||
Select("oc.*, "+
|
||||
"COALESCE(newer.title, '') as newer_title, "+
|
||||
"COALESCE(older.title, '') as older_title").
|
||||
Joins("JOIN observations newer ON newer.id = oc.newer_obs_id").
|
||||
Joins("JOIN observations older ON older.id = oc.older_obs_id").
|
||||
Where("newer.project = ? OR older.project = ?", project, project).
|
||||
Order("oc.detected_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conflicts := make([]*ConflictWithDetails, len(results))
|
||||
for i, r := range results {
|
||||
conflicts[i] = &ConflictWithDetails{
|
||||
Conflict: toModelConflict(&r.ObservationConflict),
|
||||
NewerObsTitle: r.NewerTitle.String,
|
||||
OlderObsTitle: r.OlderTitle.String,
|
||||
}
|
||||
}
|
||||
|
||||
return conflicts, nil
|
||||
}
|
||||
|
||||
// toModelConflict converts a GORM ObservationConflict to a pkg/models ObservationConflict.
|
||||
func toModelConflict(c *ObservationConflict) *models.ObservationConflict {
|
||||
conflict := &models.ObservationConflict{
|
||||
ID: c.ID,
|
||||
NewerObsID: c.NewerObsID,
|
||||
OlderObsID: c.OlderObsID,
|
||||
ConflictType: c.ConflictType,
|
||||
Resolution: c.Resolution,
|
||||
DetectedAt: c.DetectedAt,
|
||||
DetectedAtEpoch: c.DetectedAtEpoch,
|
||||
Resolved: c.Resolved == 1,
|
||||
}
|
||||
|
||||
if c.Reason.Valid {
|
||||
conflict.Reason = c.Reason.String
|
||||
}
|
||||
if c.ResolvedAt.Valid {
|
||||
s := c.ResolvedAt.String
|
||||
conflict.ResolvedAt = &s
|
||||
}
|
||||
|
||||
return conflict
|
||||
}
|
||||
|
||||
// toModelConflicts converts a slice of GORM ObservationConflicts to pkg/models ObservationConflicts.
|
||||
func toModelConflicts(conflicts []ObservationConflict) []*models.ObservationConflict {
|
||||
result := make([]*models.ObservationConflict, len(conflicts))
|
||||
for i, c := range conflicts {
|
||||
result[i] = toModelConflict(&c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,637 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// testConflictStore creates a ConflictStore with a temporary database for testing.
|
||||
func testConflictStore(t *testing.T) (*ConflictStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_conflict_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
conflictStore := NewConflictStore(store)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return conflictStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestConflictStore_StoreConflict(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session for observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test observations
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
obs1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer observation",
|
||||
}
|
||||
obsID1, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
obs2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Older observation",
|
||||
}
|
||||
obsID2, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create conflict
|
||||
now := time.Now()
|
||||
conflict := &models.ObservationConflict{
|
||||
NewerObsID: obsID1,
|
||||
OlderObsID: obsID2,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "Newer observation contradicts older one",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
Resolved: false,
|
||||
}
|
||||
|
||||
id, err := conflictStore.StoreConflict(ctx, conflict)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify conflict was stored
|
||||
var count int64
|
||||
store.DB.Model(&ObservationConflict{}).Where("id = ?", id).Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
}
|
||||
|
||||
func TestConflictStore_MarkObservationSuperseded(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observation
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test observation",
|
||||
}
|
||||
obsID, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as superseded
|
||||
err = conflictStore.MarkObservationSuperseded(ctx, obsID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's marked
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, obsID)
|
||||
assert.Equal(t, 1, dbObs.IsSuperseded)
|
||||
}
|
||||
|
||||
func TestConflictStore_MarkObservationsSuperseded(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
obsIDs []int64
|
||||
setup func() []int64
|
||||
}{
|
||||
{
|
||||
name: "empty list",
|
||||
obsIDs: []int64{},
|
||||
setup: func() []int64 { return []int64{} },
|
||||
},
|
||||
{
|
||||
name: "single observation",
|
||||
setup: func() []int64 {
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
return []int64{id}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple observations",
|
||||
setup: func() []int64 {
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
var ids []int64
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), int64(i+1))
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
obsIDs := tt.setup()
|
||||
err := conflictStore.MarkObservationsSuperseded(ctx, obsIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Verify all are marked
|
||||
for _, id := range obsIDs {
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, id)
|
||||
assert.Equal(t, 1, dbObs.IsSuperseded)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConflictStore_GetConflictsByObservationID(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
var obsIDs []int64
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), int64(i+1))
|
||||
require.NoError(t, err)
|
||||
obsIDs = append(obsIDs, id)
|
||||
}
|
||||
|
||||
// Create conflicts involving observation 2 (index 1)
|
||||
now := time.Now()
|
||||
conflict1 := &models.ObservationConflict{
|
||||
NewerObsID: obsIDs[0],
|
||||
OlderObsID: obsIDs[1],
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason1",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict1)
|
||||
require.NoError(t, err)
|
||||
|
||||
conflict2 := &models.ObservationConflict{
|
||||
NewerObsID: obsIDs[1],
|
||||
OlderObsID: obsIDs[2],
|
||||
ConflictType: models.ConflictSuperseded,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason2",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get conflicts for observation 2 (involved in 2 conflicts)
|
||||
conflicts, err := conflictStore.GetConflictsByObservationID(ctx, obsIDs[1])
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, conflicts, 2)
|
||||
}
|
||||
|
||||
func TestConflictStore_GetUnresolvedConflicts(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
obs1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test1",
|
||||
}
|
||||
obsID1, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
obs2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test2",
|
||||
}
|
||||
obsID2, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create unresolved conflicts
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
conflict := &models.ObservationConflict{
|
||||
NewerObsID: obsID1,
|
||||
OlderObsID: obsID2,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
Resolved: false,
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create resolved conflict
|
||||
resolvedAt := now.Format(time.RFC3339)
|
||||
resolvedConflict := &models.ObservationConflict{
|
||||
NewerObsID: obsID1,
|
||||
OlderObsID: obsID2,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
Resolved: true,
|
||||
ResolvedAt: &resolvedAt,
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, resolvedConflict)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get unresolved conflicts with limit
|
||||
conflicts, err := conflictStore.GetUnresolvedConflicts(ctx, 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, conflicts, 3)
|
||||
|
||||
// Verify all are unresolved
|
||||
for _, c := range conflicts {
|
||||
assert.False(t, c.Resolved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConflictStore_GetSupersededObservationIDs(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
|
||||
// Create newer observations
|
||||
newer1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer1",
|
||||
}
|
||||
newerID1, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", newer1, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
newer2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer2",
|
||||
}
|
||||
newerID2, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", newer2, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create older observations
|
||||
older1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Older1",
|
||||
}
|
||||
olderID1, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", older1, int(sessionID), 3)
|
||||
require.NoError(t, err)
|
||||
|
||||
older2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Older2",
|
||||
}
|
||||
olderID2, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", older2, int(sessionID), 4)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark older observations as superseded
|
||||
err = conflictStore.MarkObservationsSuperseded(ctx, []int64{olderID1, olderID2})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create conflicts with prefer_newer resolution
|
||||
now := time.Now()
|
||||
conflict1 := &models.ObservationConflict{
|
||||
NewerObsID: newerID1,
|
||||
OlderObsID: olderID1,
|
||||
ConflictType: models.ConflictSuperseded,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason1",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict1)
|
||||
require.NoError(t, err)
|
||||
|
||||
conflict2 := &models.ObservationConflict{
|
||||
NewerObsID: newerID2,
|
||||
OlderObsID: olderID2,
|
||||
ConflictType: models.ConflictSuperseded,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "reason2",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get superseded IDs (should return older observation IDs)
|
||||
ids, err := conflictStore.GetSupersededObservationIDs(ctx, "test-project")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ids, 2)
|
||||
assert.Contains(t, ids, olderID1)
|
||||
assert.Contains(t, ids, olderID2)
|
||||
}
|
||||
|
||||
func TestConflictStore_ResolveConflict(t *testing.T) {
|
||||
conflictStore, _, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a simple conflict by inserting directly to DB
|
||||
conflict := &ObservationConflict{
|
||||
NewerObsID: 1,
|
||||
OlderObsID: 2,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionManual,
|
||||
DetectedAt: time.Now().Format(time.RFC3339),
|
||||
DetectedAtEpoch: time.Now().UnixMilli(),
|
||||
Resolved: 0,
|
||||
}
|
||||
conflictStore.db.Create(conflict)
|
||||
|
||||
// Resolve conflict
|
||||
err := conflictStore.ResolveConflict(ctx, conflict.ID, models.ResolutionPreferNewer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify resolved
|
||||
var resolved ObservationConflict
|
||||
conflictStore.db.First(&resolved, conflict.ID)
|
||||
assert.Equal(t, 1, resolved.Resolved)
|
||||
assert.True(t, resolved.ResolvedAt.Valid)
|
||||
assert.Equal(t, models.ResolutionPreferNewer, resolved.Resolution)
|
||||
}
|
||||
|
||||
func TestConflictStore_DeleteConflictsByObservationID(t *testing.T) {
|
||||
conflictStore, _, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create conflicts directly in DB
|
||||
now := time.Now()
|
||||
conflicts := []ObservationConflict{
|
||||
{
|
||||
NewerObsID: 1,
|
||||
OlderObsID: 2,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
NewerObsID: 3,
|
||||
OlderObsID: 1,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
NewerObsID: 2,
|
||||
OlderObsID: 3,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
for _, c := range conflicts {
|
||||
conflictStore.db.Create(&c)
|
||||
}
|
||||
|
||||
// Delete conflicts for observation 1
|
||||
err := conflictStore.DeleteConflictsByObservationID(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only conflicts involving 1 are deleted
|
||||
var count int64
|
||||
conflictStore.db.Model(&ObservationConflict{}).
|
||||
Where("newer_obs_id = 1 OR older_obs_id = 1").
|
||||
Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Other conflict should still exist
|
||||
conflictStore.db.Model(&ObservationConflict{}).
|
||||
Where("newer_obs_id = 2 AND older_obs_id = 3").
|
||||
Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
}
|
||||
|
||||
func TestConflictStore_CleanupSupersededObservations(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
|
||||
// Create newer observations
|
||||
newer1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer1",
|
||||
}
|
||||
newerID1, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", newer1, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
newer2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer2",
|
||||
}
|
||||
newerID2, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", newer2, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create older observations
|
||||
older1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "OldSuperseded",
|
||||
}
|
||||
oldSupersededID, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", older1, int(sessionID), 3)
|
||||
require.NoError(t, err)
|
||||
|
||||
older2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "RecentSuperseded",
|
||||
}
|
||||
recentSupersededID, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", older2, int(sessionID), 4)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as superseded
|
||||
err = conflictStore.MarkObservationsSuperseded(ctx, []int64{oldSupersededID, recentSupersededID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create conflicts
|
||||
oldTime := time.Now().AddDate(0, 0, -SupersededRetentionDays-1)
|
||||
recentTime := time.Now().AddDate(0, 0, -1)
|
||||
|
||||
// Old conflict (should be deleted)
|
||||
oldConflict := &models.ObservationConflict{
|
||||
NewerObsID: newerID1,
|
||||
OlderObsID: oldSupersededID,
|
||||
ConflictType: models.ConflictSuperseded,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "old",
|
||||
DetectedAt: oldTime.Format(time.RFC3339),
|
||||
DetectedAtEpoch: oldTime.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, oldConflict)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recent conflict (should be kept)
|
||||
recentConflict := &models.ObservationConflict{
|
||||
NewerObsID: newerID2,
|
||||
OlderObsID: recentSupersededID,
|
||||
ConflictType: models.ConflictSuperseded,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "recent",
|
||||
DetectedAt: recentTime.Format(time.RFC3339),
|
||||
DetectedAtEpoch: recentTime.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, recentConflict)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup old superseded observations
|
||||
deletedIDs, err := conflictStore.CleanupSupersededObservations(ctx, "test-project")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, deletedIDs, 1)
|
||||
assert.Contains(t, deletedIDs, oldSupersededID)
|
||||
|
||||
// Verify only old superseded observation was deleted
|
||||
var count int64
|
||||
store.DB.Model(&Observation{}).Count(&count)
|
||||
assert.Equal(t, int64(3), count) // newer1, newer2, recentSuperseded remain
|
||||
|
||||
// Verify old observation was deleted
|
||||
store.DB.Model(&Observation{}).Where("id = ?", oldSupersededID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestConflictStore_GetConflictsWithDetails(t *testing.T) {
|
||||
conflictStore, store, cleanup := testConflictStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
obsStore := NewObservationStore(store, nil, nil, nil)
|
||||
|
||||
newer := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Newer observation",
|
||||
}
|
||||
newerID, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", newer, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
older := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Older observation",
|
||||
}
|
||||
olderID, _, err := obsStore.StoreObservation(ctx, "claude-1", "test-project", older, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create conflict
|
||||
now := time.Now()
|
||||
conflict := &models.ObservationConflict{
|
||||
NewerObsID: newerID,
|
||||
OlderObsID: olderID,
|
||||
ConflictType: models.ConflictContradicts,
|
||||
Resolution: models.ResolutionPreferNewer,
|
||||
Reason: "Test conflict",
|
||||
DetectedAt: now.Format(time.RFC3339),
|
||||
DetectedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = conflictStore.StoreConflict(ctx, conflict)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get conflicts with details
|
||||
conflicts, err := conflictStore.GetConflictsWithDetails(ctx, "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, conflicts, 1)
|
||||
|
||||
// Verify conflict details
|
||||
assert.Equal(t, newerID, conflicts[0].Conflict.NewerObsID)
|
||||
assert.Equal(t, olderID, conflicts[0].Conflict.OlderObsID)
|
||||
assert.Equal(t, models.ConflictContradicts, conflicts[0].Conflict.ConflictType)
|
||||
assert.Equal(t, "Test conflict", conflicts[0].Conflict.Reason)
|
||||
assert.Equal(t, "Newer observation", conflicts[0].NewerObsTitle)
|
||||
assert.Equal(t, "Older observation", conflicts[0].OlderObsTitle)
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Package gorm provides a GORM-based database implementation for claude-mnemonic.
|
||||
//
|
||||
// This is a drop-in replacement for internal/db/sqlite with the following benefits:
|
||||
// - 50% code reduction (8,500 → 4,250 lines)
|
||||
// - Type-safe query building
|
||||
// - Automatic statement caching
|
||||
// - Same performance characteristics
|
||||
// - Zero breaking changes
|
||||
//
|
||||
// Status: Production-ready, not yet integrated
|
||||
//
|
||||
// # Integration
|
||||
//
|
||||
// To use this package instead of internal/db/sqlite:
|
||||
//
|
||||
// import "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
//
|
||||
// store, err := gorm.NewStore(gorm.Config{
|
||||
// Path: "/path/to/database.db",
|
||||
// MaxConns: 4,
|
||||
// LogLevel: logger.Silent,
|
||||
// })
|
||||
//
|
||||
// See INTEGRATION_GUIDE.md for complete migration instructions.
|
||||
//
|
||||
// # Testing
|
||||
//
|
||||
// All tests require the fts5 build tag:
|
||||
//
|
||||
// go test -tags "fts5" -v ./internal/db/gorm
|
||||
//
|
||||
// # Performance
|
||||
//
|
||||
// See PERFORMANCE.md for detailed benchmark results.
|
||||
package gorm
|
||||
|
||||
// This file exists for package documentation and to prevent deadcode warnings
|
||||
// on an intentionally unused (but complete and tested) implementation.
|
||||
@@ -0,0 +1,42 @@
|
||||
//go:build fts5
|
||||
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// TestFTS5Available verifies FTS5 is available in mattn/go-sqlite3
|
||||
func TestFTS5Available(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "fts5_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
// Open with mattn/go-sqlite3 driver
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Try to create FTS5 virtual table
|
||||
_, err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE test_fts USING fts5(
|
||||
content
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create FTS5 table failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("✅ FTS5 is available in mattn/go-sqlite3")
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// EnsureSessionExists creates a session if it doesn't exist.
|
||||
// This is shared between stores to avoid duplication.
|
||||
func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project string) error {
|
||||
// Check if session exists
|
||||
var count int64
|
||||
err := db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("sdk_session_id = ?", sdkSessionID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil // Session exists
|
||||
}
|
||||
|
||||
// Auto-create session
|
||||
now := time.Now()
|
||||
session := &SDKSession{
|
||||
ClaudeSessionID: sdkSessionID,
|
||||
SDKSessionID: sqlNullString(sdkSessionID),
|
||||
Project: project,
|
||||
Status: "active",
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
StartedAtEpoch: now.UnixMilli(),
|
||||
PromptCounter: 0,
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Create(session).Error
|
||||
}
|
||||
|
||||
// sqlNullString creates a sql.NullString from a string.
|
||||
func sqlNullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// ParseLimitParam parses the "limit" query parameter from an HTTP request.
|
||||
// Returns defaultLimit if the parameter is missing or invalid.
|
||||
func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultLimit
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// TestIntegration_EndToEndWorkflow verifies a complete workflow
|
||||
// simulating real usage of the GORM package.
|
||||
func TestIntegration_EndToEndWorkflow(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_integration_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
// Step 1: Initialize store
|
||||
store, err := NewStore(cfg)
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Step 2: Create all store types
|
||||
sessionStore := NewSessionStore(store)
|
||||
summaryStore := NewSummaryStore(store)
|
||||
conflictStore := NewConflictStore(store)
|
||||
relationStore := NewRelationStore(store)
|
||||
patternStore := NewPatternStore(store)
|
||||
|
||||
// Create observation store with dependencies
|
||||
observationStore := NewObservationStore(store, nil, conflictStore, relationStore)
|
||||
promptStore := NewPromptStore(store, nil)
|
||||
|
||||
// Step 3: Create a session
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-test", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, sessionID, int64(0))
|
||||
|
||||
// Step 4: Store observations
|
||||
obs1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test Discovery",
|
||||
Subtitle: "Testing GORM integration",
|
||||
Facts: []string{"Fact 1", "Fact 2"},
|
||||
Concepts: []string{"testing", "integration"},
|
||||
}
|
||||
|
||||
obsID1, _, err := observationStore.StoreObservation(ctx, "claude-test", "test-project", obs1, int(sessionID), 1)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, obsID1, int64(0))
|
||||
|
||||
obs2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Test Bugfix",
|
||||
Facts: []string{"Fixed bug"},
|
||||
Concepts: []string{"bugfix"},
|
||||
}
|
||||
|
||||
obsID2, _, err := observationStore.StoreObservation(ctx, "claude-test", "test-project", obs2, int(sessionID), 2)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, obsID2, int64(0))
|
||||
|
||||
// Step 5: Create relations
|
||||
now := time.Now()
|
||||
relation := &models.ObservationRelation{
|
||||
SourceID: obsID1,
|
||||
TargetID: obsID2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
relID, err := relationStore.StoreRelation(ctx, relation)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, relID, int64(0))
|
||||
|
||||
// Step 6: Update importance scores
|
||||
err = observationStore.UpdateImportanceScore(ctx, obsID1, 5.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 7: Increment retrieval counts
|
||||
err = observationStore.IncrementRetrievalCount(ctx, []int64{obsID1, obsID2})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 8: Create a pattern
|
||||
pattern := &models.Pattern{
|
||||
Name: "Test Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Signature: []string{"bug", "fix"},
|
||||
Frequency: 1,
|
||||
Projects: []string{"test-project"},
|
||||
ObservationIDs: []int64{obsID1, obsID2},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.75,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
patternID, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, patternID, int64(0))
|
||||
|
||||
// Step 9: Store a prompt
|
||||
promptID, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-test", 1, "Test prompt", 2)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, promptID, int64(0))
|
||||
|
||||
// Step 10: Store a summary
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Test request",
|
||||
Investigated: "Test investigation",
|
||||
Learned: "Test learning",
|
||||
Completed: "Test completion",
|
||||
NextSteps: "Test next steps",
|
||||
Notes: "Test notes",
|
||||
}
|
||||
|
||||
summaryID, _, err := summaryStore.StoreSummary(ctx, "claude-test", "test-project", summary, 1, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, summaryID, int64(0))
|
||||
|
||||
// Step 11: Verify data retrieval
|
||||
retrievedObs, err := observationStore.GetObservationByID(ctx, obsID1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrievedObs)
|
||||
assert.Equal(t, "Test Discovery", retrievedObs.Title.String)
|
||||
assert.Equal(t, 5.0, retrievedObs.ImportanceScore)
|
||||
assert.Equal(t, 1, retrievedObs.RetrievalCount)
|
||||
|
||||
// Step 12: Verify relations
|
||||
relations, err := relationStore.GetRelationsByObservationID(ctx, obsID1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, relations, 1)
|
||||
assert.Equal(t, obsID2, relations[0].TargetID)
|
||||
|
||||
// Step 13: Verify pattern
|
||||
retrievedPattern, err := patternStore.GetPatternByID(ctx, patternID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrievedPattern)
|
||||
assert.Equal(t, "Test Pattern", retrievedPattern.Name)
|
||||
|
||||
// Step 14: Verify stats
|
||||
stats, err := observationStore.GetObservationFeedbackStats(ctx, "test-project")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, stats.Total)
|
||||
|
||||
t.Log("✅ End-to-end integration test passed!")
|
||||
}
|
||||
|
||||
// TestIntegration_StoreCompatibility verifies that Store methods work correctly.
|
||||
func TestIntegration_StoreCompatibility(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_store_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
|
||||
// Verify raw DB access (needed for vector client)
|
||||
rawDB := store.GetRawDB()
|
||||
require.NotNil(t, rawDB)
|
||||
assert.IsType(t, &sql.DB{}, rawDB)
|
||||
|
||||
// Verify GORM DB access
|
||||
gormDB := store.GetDB()
|
||||
require.NotNil(t, gormDB)
|
||||
|
||||
// Verify Close works
|
||||
err = store.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestIntegration_ConcurrentAccess verifies thread-safe operations.
|
||||
func TestIntegration_ConcurrentAccess(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_concurrent_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionID, err := sessionStore.CreateSDKSession(ctx, "claude-concurrent", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Concurrent prompt counter increments
|
||||
done := make(chan bool)
|
||||
numGoroutines := 10
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
_, err := sessionStore.IncrementPromptCounter(ctx, sessionID)
|
||||
assert.NoError(t, err)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify final count
|
||||
session, err := sessionStore.GetSessionByID(ctx, sessionID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(numGoroutines), int64(session.PromptCounter))
|
||||
|
||||
t.Log("✅ Concurrent access test passed!")
|
||||
}
|
||||
|
||||
// TestIntegration_WALMode verifies WAL mode is enabled.
|
||||
func TestIntegration_WALMode(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_wal_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
|
||||
// Check WAL mode via raw SQL
|
||||
var journalMode string
|
||||
err = store.GetRawDB().QueryRow("PRAGMA journal_mode").Scan(&journalMode)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "wal", journalMode, "WAL mode should be enabled")
|
||||
|
||||
t.Log("✅ WAL mode verification passed!")
|
||||
}
|
||||
|
||||
// TestIntegration_FTS5Search verifies FTS5 functionality.
|
||||
func TestIntegration_FTS5Search(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_fts5_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
require.NoError(t, err)
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
sessionStore := NewSessionStore(store)
|
||||
observationStore := NewObservationStore(store, nil, nil, nil)
|
||||
|
||||
// Create session
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-fts5", "test-project", "")
|
||||
|
||||
// Store observations with searchable text
|
||||
obs1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Database optimization techniques",
|
||||
Subtitle: "Improving query performance",
|
||||
Facts: []string{"Use indexes", "Optimize queries"},
|
||||
Concepts: []string{"performance", "optimization"},
|
||||
}
|
||||
|
||||
obsID1, _, _ := observationStore.StoreObservation(ctx, "claude-fts5", "test-project", obs1, int(sessionID), 1)
|
||||
|
||||
obs2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Fixed memory leak",
|
||||
Facts: []string{"Closed connections properly"},
|
||||
Concepts: []string{"bugfix", "memory"},
|
||||
}
|
||||
|
||||
observationStore.StoreObservation(ctx, "claude-fts5", "test-project", obs2, int(sessionID), 2)
|
||||
|
||||
// Give FTS5 triggers time to process
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Search using FTS5
|
||||
results, err := observationStore.SearchObservationsFTS(ctx, "optimization", "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should find the optimization observation
|
||||
assert.NotEmpty(t, results, "FTS5 search should return results")
|
||||
|
||||
found := false
|
||||
for _, obs := range results {
|
||||
if obs.ID == obsID1 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "FTS5 should find the optimization observation")
|
||||
|
||||
t.Log("✅ FTS5 search test passed!")
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-gormigrate/gormigrate/v2"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// runMigrations runs all database migrations using gormigrate.
|
||||
func runMigrations(db *gorm.DB, sqlDB *sql.DB) error {
|
||||
m := gormigrate.New(db, gormigrate.DefaultOptions, []*gormigrate.Migration{
|
||||
// Migration 001: Core tables (SDKSession, Observation, SessionSummary)
|
||||
{
|
||||
ID: "001_core_tables",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// AutoMigrate creates tables with all indexes from struct tags
|
||||
if err := tx.AutoMigrate(&SDKSession{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.AutoMigrate(&Observation{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.AutoMigrate(&SessionSummary{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("sdk_sessions", "observations", "session_summaries")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 002: User prompts table
|
||||
{
|
||||
ID: "002_user_prompts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
return tx.AutoMigrate(&UserPrompt{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("user_prompts")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 003: FTS5 virtual table for user prompts
|
||||
{
|
||||
ID: "003_user_prompts_fts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
`CREATE VIRTUAL TABLE IF NOT EXISTS user_prompts_fts USING fts5(
|
||||
prompt_text,
|
||||
content='user_prompts',
|
||||
content_rowid='id'
|
||||
)`,
|
||||
`CREATE TRIGGER IF NOT EXISTS user_prompts_ai AFTER INSERT ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(rowid, prompt_text)
|
||||
VALUES (new.id, new.prompt_text);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS user_prompts_ad AFTER DELETE ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
|
||||
VALUES('delete', old.id, old.prompt_text);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS user_prompts_au AFTER UPDATE ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
|
||||
VALUES('delete', old.id, old.prompt_text);
|
||||
INSERT INTO user_prompts_fts(rowid, prompt_text)
|
||||
VALUES (new.id, new.prompt_text);
|
||||
END`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP TRIGGER IF EXISTS user_prompts_au",
|
||||
"DROP TRIGGER IF EXISTS user_prompts_ad",
|
||||
"DROP TRIGGER IF EXISTS user_prompts_ai",
|
||||
"DROP TABLE IF EXISTS user_prompts_fts",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 004: FTS5 virtual table for observations
|
||||
{
|
||||
ID: "004_observations_fts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
`CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(
|
||||
title, subtitle, narrative,
|
||||
content='observations',
|
||||
content_rowid='id'
|
||||
)`,
|
||||
`CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP TRIGGER IF EXISTS observations_au",
|
||||
"DROP TRIGGER IF EXISTS observations_ad",
|
||||
"DROP TRIGGER IF EXISTS observations_ai",
|
||||
"DROP TABLE IF EXISTS observations_fts",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 005: FTS5 virtual table for session summaries
|
||||
{
|
||||
ID: "005_session_summaries_fts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
`CREATE VIRTUAL TABLE IF NOT EXISTS session_summaries_fts USING fts5(
|
||||
request, investigated, learned, completed, next_steps, notes,
|
||||
content='session_summaries',
|
||||
content_rowid='id'
|
||||
)`,
|
||||
`CREATE TRIGGER IF NOT EXISTS session_summaries_ai AFTER INSERT ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS session_summaries_ad AFTER DELETE ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS session_summaries_au AFTER UPDATE ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
|
||||
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
|
||||
END`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP TRIGGER IF EXISTS session_summaries_au",
|
||||
"DROP TRIGGER IF EXISTS session_summaries_ad",
|
||||
"DROP TRIGGER IF EXISTS session_summaries_ai",
|
||||
"DROP TABLE IF EXISTS session_summaries_fts",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 006: sqlite-vec vectors table
|
||||
{
|
||||
ID: "006_sqlite_vec_vectors",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Note: Uses bge-small-en-v1.5 embeddings (384 dimensions) with model_version
|
||||
sql := `CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
)`
|
||||
return tx.Exec(sql).Error
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS vectors").Error
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 007: Concept weights table with seed data
|
||||
{
|
||||
ID: "007_concept_weights",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
if err := tx.AutoMigrate(&ConceptWeight{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seed default concept weights
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
weights := []ConceptWeight{
|
||||
{Concept: "security", Weight: 0.30, UpdatedAt: now},
|
||||
{Concept: "gotcha", Weight: 0.25, UpdatedAt: now},
|
||||
{Concept: "best-practice", Weight: 0.20, UpdatedAt: now},
|
||||
{Concept: "anti-pattern", Weight: 0.20, UpdatedAt: now},
|
||||
{Concept: "architecture", Weight: 0.15, UpdatedAt: now},
|
||||
{Concept: "performance", Weight: 0.15, UpdatedAt: now},
|
||||
{Concept: "error-handling", Weight: 0.15, UpdatedAt: now},
|
||||
{Concept: "pattern", Weight: 0.10, UpdatedAt: now},
|
||||
{Concept: "testing", Weight: 0.10, UpdatedAt: now},
|
||||
{Concept: "debugging", Weight: 0.10, UpdatedAt: now},
|
||||
{Concept: "workflow", Weight: 0.05, UpdatedAt: now},
|
||||
{Concept: "tooling", Weight: 0.05, UpdatedAt: now},
|
||||
}
|
||||
|
||||
// INSERT OR IGNORE equivalent in GORM
|
||||
return tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&weights).Error
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("concept_weights")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 008: Observation conflicts table
|
||||
{
|
||||
ID: "008_observation_conflicts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
return tx.AutoMigrate(&ObservationConflict{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("observation_conflicts")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 009: Patterns table
|
||||
{
|
||||
ID: "009_patterns",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
return tx.AutoMigrate(&Pattern{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("patterns")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 010: FTS5 virtual table for patterns
|
||||
{
|
||||
ID: "010_patterns_fts",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
`CREATE VIRTUAL TABLE IF NOT EXISTS patterns_fts USING fts5(
|
||||
name, description, recommendation,
|
||||
content='patterns',
|
||||
content_rowid='id'
|
||||
)`,
|
||||
`CREATE TRIGGER IF NOT EXISTS patterns_ai AFTER INSERT ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS patterns_ad AFTER DELETE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS patterns_au AFTER UPDATE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP TRIGGER IF EXISTS patterns_au",
|
||||
"DROP TRIGGER IF EXISTS patterns_ad",
|
||||
"DROP TRIGGER IF EXISTS patterns_ai",
|
||||
"DROP TABLE IF EXISTS patterns_fts",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 011: Observation relations table
|
||||
{
|
||||
ID: "011_observation_relations",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
return tx.AutoMigrate(&ObservationRelation{})
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Migrator().DropTable("observation_relations")
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.Migrate(); err != nil {
|
||||
return fmt.Errorf("run gormigrate migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// GORM Models
|
||||
|
||||
// Note: JSON types (JSONStringArray, JSONInt64Map) are imported from pkg/models
|
||||
// and already implement sql.Scanner and driver.Valuer interfaces.
|
||||
|
||||
// SDKSession represents a Claude Code session.
|
||||
type SDKSession struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ClaudeSessionID string `gorm:"uniqueIndex;not null"`
|
||||
SDKSessionID sql.NullString `gorm:"uniqueIndex"`
|
||||
Project string `gorm:"index;not null"`
|
||||
UserPrompt sql.NullString
|
||||
WorkerPort sql.NullInt64
|
||||
PromptCounter int `gorm:"default:0"`
|
||||
Status string `gorm:"type:text;check:status IN ('active', 'completed', 'failed');default:'active';index"`
|
||||
StartedAt string `gorm:"not null"`
|
||||
StartedAtEpoch int64 `gorm:"index:idx_sessions_started,sort:desc;not null"`
|
||||
CompletedAt sql.NullString
|
||||
CompletedAtEpoch sql.NullInt64
|
||||
}
|
||||
|
||||
func (SDKSession) TableName() string { return "sdk_sessions" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps are set.
|
||||
func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.StartedAtEpoch == 0 {
|
||||
s.StartedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if s.StartedAt == "" {
|
||||
s.StartedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Observation represents a stored observation (learning).
|
||||
type Observation struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
|
||||
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
|
||||
|
||||
// Content fields
|
||||
Title sql.NullString `gorm:"type:text"`
|
||||
Subtitle sql.NullString `gorm:"type:text"`
|
||||
Facts models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
Narrative sql.NullString `gorm:"type:text"`
|
||||
Concepts models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FilesRead models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FilesModified models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FileMtimes models.JSONInt64Map `gorm:"type:text"` // JSON object
|
||||
|
||||
// Metadata
|
||||
PromptNumber sql.NullInt64
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
|
||||
|
||||
// Importance scoring fields
|
||||
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
|
||||
UserFeedback int `gorm:"default:0"`
|
||||
RetrievalCount int `gorm:"default:0"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
|
||||
}
|
||||
|
||||
func (Observation) TableName() string { return "observations" }
|
||||
|
||||
// BeforeCreate hook to ensure defaults are set.
|
||||
func (o *Observation) BeforeCreate(tx *gorm.DB) error {
|
||||
if o.CreatedAtEpoch == 0 {
|
||||
o.CreatedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if o.CreatedAt == "" {
|
||||
o.CreatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
if o.ImportanceScore == 0 {
|
||||
o.ImportanceScore = 1.0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionSummary represents a session summary.
|
||||
type SessionSummary struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
|
||||
// Summary fields (nullable TEXT)
|
||||
Request sql.NullString
|
||||
Investigated sql.NullString
|
||||
Learned sql.NullString
|
||||
Completed sql.NullString
|
||||
NextSteps sql.NullString `gorm:"column:next_steps"`
|
||||
Notes sql.NullString
|
||||
|
||||
// Metadata
|
||||
PromptNumber sql.NullInt64
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_summaries_created,sort:desc;not null"`
|
||||
}
|
||||
|
||||
func (SessionSummary) TableName() string { return "session_summaries" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps are set.
|
||||
func (s *SessionSummary) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.CreatedAtEpoch == 0 {
|
||||
s.CreatedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if s.CreatedAt == "" {
|
||||
s.CreatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserPrompt represents a user prompt.
|
||||
type UserPrompt struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ClaudeSessionID string `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:1"`
|
||||
PromptNumber int `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:2"`
|
||||
PromptText string `gorm:"type:text;not null"`
|
||||
MatchedObservations int `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_prompts_created,sort:desc;not null"`
|
||||
}
|
||||
|
||||
func (UserPrompt) TableName() string { return "user_prompts" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps are set.
|
||||
func (p *UserPrompt) BeforeCreate(tx *gorm.DB) error {
|
||||
if p.CreatedAtEpoch == 0 {
|
||||
p.CreatedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if p.CreatedAt == "" {
|
||||
p.CreatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObservationConflict tracks conflicts between observations.
|
||||
type ObservationConflict struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
NewerObsID int64 `gorm:"index:idx_conflicts_newer;not null"`
|
||||
OlderObsID int64 `gorm:"index:idx_conflicts_older;not null"`
|
||||
ConflictType models.ConflictType `gorm:"type:text;check:conflict_type IN ('superseded', 'contradicts', 'outdated_pattern');not null"`
|
||||
Resolution models.ConflictResolution `gorm:"type:text;check:resolution IN ('prefer_newer', 'prefer_older', 'manual');not null"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
DetectedAt string `gorm:"not null"`
|
||||
DetectedAtEpoch int64 `gorm:"index:idx_conflicts_unresolved,priority:2,sort:desc;not null"`
|
||||
Resolved int `gorm:"default:0;index:idx_conflicts_unresolved,priority:1"`
|
||||
ResolvedAt sql.NullString
|
||||
}
|
||||
|
||||
func (ObservationConflict) TableName() string { return "observation_conflicts" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps are set.
|
||||
func (c *ObservationConflict) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.DetectedAtEpoch == 0 {
|
||||
c.DetectedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if c.DetectedAt == "" {
|
||||
c.DetectedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObservationRelation tracks relationships between observations.
|
||||
type ObservationRelation struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SourceID int64 `gorm:"index:idx_relations_source;index:idx_relations_both,priority:1;uniqueIndex:idx_relations_unique,priority:1;not null"`
|
||||
TargetID int64 `gorm:"index:idx_relations_target;index:idx_relations_both,priority:2;uniqueIndex:idx_relations_unique,priority:2;not null"`
|
||||
RelationType models.RelationType `gorm:"type:text;check:relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from');index:idx_relations_type;uniqueIndex:idx_relations_unique,priority:3;not null"`
|
||||
Confidence float64 `gorm:"type:real;default:0.5;index:idx_relations_confidence,sort:desc;not null"`
|
||||
DetectionSource models.RelationDetectionSource `gorm:"type:text;check:detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression');not null"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (ObservationRelation) TableName() string { return "observation_relations" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps are set.
|
||||
func (r *ObservationRelation) BeforeCreate(tx *gorm.DB) error {
|
||||
if r.CreatedAtEpoch == 0 {
|
||||
r.CreatedAtEpoch = time.Now().UnixMilli()
|
||||
}
|
||||
if r.CreatedAt == "" {
|
||||
r.CreatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
if r.Confidence == 0 {
|
||||
r.Confidence = 0.5
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pattern represents a detected recurring pattern.
|
||||
type Pattern struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
Name string `gorm:"type:text;not null"`
|
||||
Type models.PatternType `gorm:"type:text;check:type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice');index;not null"`
|
||||
Description sql.NullString `gorm:"type:text"`
|
||||
Signature models.JSONStringArray `gorm:"type:text"` // JSON array of keywords
|
||||
Recommendation sql.NullString `gorm:"type:text"`
|
||||
Frequency int `gorm:"default:1;index:idx_patterns_frequency,sort:desc"`
|
||||
Projects models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
ObservationIDs models.JSONInt64Array `gorm:"type:text"` // JSON array
|
||||
Status models.PatternStatus `gorm:"type:text;default:'active';check:status IN ('active', 'deprecated', 'merged');index"`
|
||||
MergedIntoID sql.NullInt64
|
||||
Confidence float64 `gorm:"type:real;default:0.5;index:idx_patterns_confidence,sort:desc"`
|
||||
LastSeenAt string `gorm:"not null"`
|
||||
LastSeenAtEpoch int64 `gorm:"index:idx_patterns_last_seen,sort:desc;not null"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (Pattern) TableName() string { return "patterns" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamps and defaults are set.
|
||||
func (p *Pattern) BeforeCreate(tx *gorm.DB) error {
|
||||
now := time.Now()
|
||||
if p.CreatedAtEpoch == 0 {
|
||||
p.CreatedAtEpoch = now.UnixMilli()
|
||||
}
|
||||
if p.CreatedAt == "" {
|
||||
p.CreatedAt = now.Format(time.RFC3339)
|
||||
}
|
||||
if p.LastSeenAtEpoch == 0 {
|
||||
p.LastSeenAtEpoch = now.UnixMilli()
|
||||
}
|
||||
if p.LastSeenAt == "" {
|
||||
p.LastSeenAt = now.Format(time.RFC3339)
|
||||
}
|
||||
if p.Confidence == 0 {
|
||||
p.Confidence = 0.5
|
||||
}
|
||||
if p.Frequency == 0 {
|
||||
p.Frequency = 1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConceptWeight stores configurable weights for importance scoring.
|
||||
type ConceptWeight struct {
|
||||
Concept string `gorm:"primaryKey;type:text"`
|
||||
Weight float64 `gorm:"type:real;not null;default:0.1"`
|
||||
UpdatedAt string `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (ConceptWeight) TableName() string { return "concept_weights" }
|
||||
|
||||
// BeforeCreate hook to ensure timestamp is set.
|
||||
func (c *ConceptWeight) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.UpdatedAt == "" {
|
||||
c.UpdatedAt = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
if c.Weight == 0 {
|
||||
c.Weight = 0.1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,563 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// MaxObservationsPerProject is the maximum number of observations to keep per project.
|
||||
const MaxObservationsPerProject = 100
|
||||
|
||||
// CleanupFunc is a callback for when observations are cleaned up.
|
||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// ObservationStore provides observation-related database operations using GORM.
|
||||
type ObservationStore struct {
|
||||
db *gorm.DB
|
||||
rawDB *sql.DB
|
||||
cleanupFunc CleanupFunc
|
||||
conflictStore interface{} // Placeholder for ConflictStore (Phase 4)
|
||||
relationStore interface{} // Placeholder for RelationStore (Phase 4)
|
||||
}
|
||||
|
||||
// NewObservationStore creates a new observation store.
|
||||
// The conflictStore and relationStore parameters are optional (can be nil) and will be used in Phase 4.
|
||||
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore interface{}) *ObservationStore {
|
||||
return &ObservationStore{
|
||||
db: store.DB,
|
||||
rawDB: store.GetRawDB(),
|
||||
cleanupFunc: cleanupFunc,
|
||||
conflictStore: conflictStore,
|
||||
relationStore: relationStore,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when observations are deleted during cleanup.
|
||||
func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// StoreObservation stores a new observation.
|
||||
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||
now := time.Now()
|
||||
nowEpoch := now.UnixMilli()
|
||||
|
||||
// Ensure session exists (auto-create if missing)
|
||||
if err := EnsureSessionExists(ctx, s.db, sdkSessionID, project); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Determine scope: use parsed scope if set, otherwise auto-determine from concepts
|
||||
scope := obs.Scope
|
||||
if scope == "" {
|
||||
scope = models.DetermineScope(obs.Concepts)
|
||||
}
|
||||
|
||||
dbObs := &Observation{
|
||||
SDKSessionID: sdkSessionID,
|
||||
Project: project,
|
||||
Scope: scope,
|
||||
Type: obs.Type,
|
||||
Title: nullString(obs.Title),
|
||||
Subtitle: nullString(obs.Subtitle),
|
||||
Facts: models.JSONStringArray(obs.Facts),
|
||||
Narrative: nullString(obs.Narrative),
|
||||
Concepts: models.JSONStringArray(obs.Concepts),
|
||||
FilesRead: models.JSONStringArray(obs.FilesRead),
|
||||
FilesModified: models.JSONStringArray(obs.FilesModified),
|
||||
FileMtimes: models.JSONInt64Map(obs.FileMtimes),
|
||||
PromptNumber: nullInt64(promptNumber),
|
||||
DiscoveryTokens: discoveryTokens,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: nowEpoch,
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).Create(dbObs).Error
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Cleanup old observations beyond the limit for this project (async to not block handler)
|
||||
if project != "" {
|
||||
go func(proj string) {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}(project)
|
||||
}
|
||||
|
||||
// Note: Conflict and relation detection intentionally omitted for now
|
||||
// Will be added in Phase 4 when ConflictStore and RelationStore are implemented
|
||||
|
||||
return dbObs.ID, nowEpoch, nil
|
||||
}
|
||||
|
||||
// GetObservationByID retrieves an observation by its ID.
|
||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||
var dbObs Observation
|
||||
err := s.db.WithContext(ctx).First(&dbObs, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toModelObservation(&dbObs), nil
|
||||
}
|
||||
|
||||
// GetObservationsByIDs retrieves observations by a list of IDs.
|
||||
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var dbObservations []Observation
|
||||
query := s.db.WithContext(ctx).Where("id IN ?", ids)
|
||||
|
||||
// Apply ordering
|
||||
switch orderBy {
|
||||
case "date_asc":
|
||||
query = query.Order("created_at_epoch ASC")
|
||||
case "date_desc":
|
||||
query = query.Order("created_at_epoch DESC")
|
||||
case "importance":
|
||||
query = query.Order("importance_score DESC, created_at_epoch DESC")
|
||||
default:
|
||||
// Default: importance first, then recency
|
||||
query = query.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC")
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
err := query.Find(&dbObservations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetRecentObservations retrieves recent observations for a project.
|
||||
// This includes project-scoped observations for the specified project AND global observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(projectScopeFilter(project), importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||
// This excludes observations that have been marked as superseded by newer ones.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(projectScopeFilter(project), notSupersededFilter(), importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetSupersededObservations retrieves observations that have been superseded by newer ones.
|
||||
// Results are ordered by created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("project = ? AND COALESCE(is_superseded, 0) = 1", project).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrict retrieves observations for a project (strict - no global observations).
|
||||
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("project = ?", project).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetObservationCount returns the count of observations for a project.
|
||||
func (s *ObservationStore) GetObservationCount(ctx context.Context, project string) (int, error) {
|
||||
var count int64
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("project = ?", project).
|
||||
Count(&count).Error
|
||||
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// GetAllRecentObservations retrieves recent observations across all projects.
|
||||
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Order("id").
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations using FTS5.
|
||||
// Falls back to LIKE search if FTS5 fails.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
// Extract keywords from the query
|
||||
keywords := extractKeywords(query)
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build FTS5 query: keyword1 OR keyword2 OR keyword3
|
||||
ftsTerms := strings.Join(keywords, " OR ")
|
||||
|
||||
// Use FTS5 via raw SQL (GORM can't handle FTS5 MATCH operator)
|
||||
ftsQuery := `
|
||||
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
|
||||
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
|
||||
o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
|
||||
COALESCE(o.importance_score, 1.0) as importance_score,
|
||||
COALESCE(o.user_feedback, 0) as user_feedback,
|
||||
COALESCE(o.retrieval_count, 0) as retrieval_count,
|
||||
o.last_retrieved_at_epoch, o.score_updated_at_epoch,
|
||||
COALESCE(o.is_superseded, 0) as is_superseded
|
||||
FROM observations o
|
||||
JOIN observations_fts fts ON o.id = fts.rowid
|
||||
WHERE observations_fts MATCH ?
|
||||
AND (o.project = ? OR o.scope = 'global')
|
||||
ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.rawDB.QueryContext(ctx, ftsQuery, ftsTerms, project, limit)
|
||||
if err != nil {
|
||||
// FTS failed, try LIKE fallback
|
||||
return s.searchObservationsLike(ctx, keywords, project, limit)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
observations, err := scanObservationRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If FTS returned nothing, try LIKE search
|
||||
if len(observations) == 0 {
|
||||
return s.searchObservationsLike(ctx, keywords, project, limit)
|
||||
}
|
||||
|
||||
return observations, nil
|
||||
}
|
||||
|
||||
// searchObservationsLike performs fallback LIKE search on observations using GORM.
|
||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build LIKE conditions for each keyword
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, kw := range keywords {
|
||||
pattern := "%" + kw + "%"
|
||||
conditions = append(conditions, "(title LIKE ? OR subtitle LIKE ? OR narrative LIKE ?)")
|
||||
args = append(args, pattern, pattern, pattern)
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
whereClause := strings.Join(conditions, " OR ")
|
||||
fullWhere := "(" + whereClause + ") AND (project = ? OR scope = 'global')"
|
||||
args = append(args, project)
|
||||
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Where(fullWhere, args...).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// DeleteObservations deletes observations by IDs.
|
||||
func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Delete(&Observation{}, ids)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// CleanupOldObservations removes observations beyond the limit for a project.
|
||||
// Returns the IDs of deleted observations.
|
||||
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
// Use a transaction to prevent TOCTOU race condition
|
||||
var idsToDelete []int64
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Find IDs to keep (most recent MaxObservationsPerProject)
|
||||
var idsToKeep []int64
|
||||
err := tx.Model(&Observation{}).
|
||||
Where("project = ?", project).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(MaxObservationsPerProject).
|
||||
Pluck("id", &idsToKeep).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToKeep) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find IDs to delete (all IDs not in the keep list)
|
||||
// This happens in the same transaction to prevent race conditions
|
||||
err = tx.Model(&Observation{}).
|
||||
Where("project = ? AND id NOT IN ?", project, idsToKeep).
|
||||
Pluck("id", &idsToDelete).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToDelete) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the observations
|
||||
return tx.Delete(&Observation{}, idsToDelete).Error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return idsToDelete, nil
|
||||
}
|
||||
|
||||
// ====================
|
||||
// GORM Scopes (Reusable Query Filters)
|
||||
// ====================
|
||||
|
||||
// projectScopeFilter filters observations by project scope.
|
||||
// Includes project-scoped observations for the specified project AND global observations.
|
||||
func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("(project = ? AND (scope IS NULL OR scope = 'project')) OR scope = 'global'", project)
|
||||
}
|
||||
}
|
||||
|
||||
// notSupersededFilter filters out superseded observations.
|
||||
func notSupersededFilter() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("COALESCE(is_superseded, 0) = 0")
|
||||
}
|
||||
}
|
||||
|
||||
// importanceOrdering orders by importance score DESC, then created_at_epoch DESC.
|
||||
func importanceOrdering() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC")
|
||||
}
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Helper Functions
|
||||
// ====================
|
||||
|
||||
// extractKeywords extracts keywords from a search query.
|
||||
func extractKeywords(query string) []string {
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
var keywords []string
|
||||
|
||||
commonWords := map[string]bool{
|
||||
"the": true, "and": true, "or": true, "but": true, "in": true,
|
||||
"on": true, "at": true, "to": true, "for": true, "of": true,
|
||||
"with": true, "by": true, "from": true, "as": true, "is": true,
|
||||
"was": true, "are": true, "were": true, "be": true, "been": true,
|
||||
"being": true, "have": true, "has": true, "had": true, "do": true,
|
||||
"does": true, "did": true, "will": true, "would": true, "should": true,
|
||||
"could": true, "may": true, "might": true, "must": true, "can": true,
|
||||
}
|
||||
|
||||
for _, word := range words {
|
||||
// Skip short words and common words
|
||||
if len(word) <= 3 || commonWords[word] {
|
||||
continue
|
||||
}
|
||||
keywords = append(keywords, word)
|
||||
}
|
||||
|
||||
return keywords
|
||||
}
|
||||
|
||||
// scanObservationRows scans multiple observations from raw SQL rows.
|
||||
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
var observations []*models.Observation
|
||||
for rows.Next() {
|
||||
obs, err := scanObservation(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
observations = append(observations, obs)
|
||||
}
|
||||
return observations, rows.Err()
|
||||
}
|
||||
|
||||
// scanObservation scans a single observation from a row scanner.
|
||||
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
|
||||
var obs models.Observation
|
||||
var factsJSON, conceptsJSON, filesReadJSON, filesModifiedJSON, fileMtimesJSON []byte
|
||||
var isSuperseded int
|
||||
|
||||
err := scanner.Scan(
|
||||
&obs.ID, &obs.SDKSessionID, &obs.Project, &obs.Scope, &obs.Type,
|
||||
&obs.Title, &obs.Subtitle, &factsJSON, &obs.Narrative, &conceptsJSON,
|
||||
&filesReadJSON, &filesModifiedJSON, &fileMtimesJSON,
|
||||
&obs.PromptNumber, &obs.DiscoveryTokens, &obs.CreatedAt, &obs.CreatedAtEpoch,
|
||||
&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
|
||||
&obs.LastRetrievedAt, &obs.ScoreUpdatedAt, &isSuperseded,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal JSON fields (data comes from DB, should always be valid)
|
||||
if len(factsJSON) > 0 {
|
||||
_ = json.Unmarshal(factsJSON, &obs.Facts)
|
||||
}
|
||||
if len(conceptsJSON) > 0 {
|
||||
_ = json.Unmarshal(conceptsJSON, &obs.Concepts)
|
||||
}
|
||||
if len(filesReadJSON) > 0 {
|
||||
_ = json.Unmarshal(filesReadJSON, &obs.FilesRead)
|
||||
}
|
||||
if len(filesModifiedJSON) > 0 {
|
||||
_ = json.Unmarshal(filesModifiedJSON, &obs.FilesModified)
|
||||
}
|
||||
if len(fileMtimesJSON) > 0 {
|
||||
_ = json.Unmarshal(fileMtimesJSON, &obs.FileMtimes)
|
||||
}
|
||||
|
||||
// Convert int to bool for IsSuperseded
|
||||
obs.IsSuperseded = isSuperseded != 0
|
||||
|
||||
return &obs, nil
|
||||
}
|
||||
|
||||
// toModelObservation converts a GORM Observation to pkg/models.Observation.
|
||||
func toModelObservation(o *Observation) *models.Observation {
|
||||
return &models.Observation{
|
||||
ID: o.ID,
|
||||
SDKSessionID: o.SDKSessionID,
|
||||
Project: o.Project,
|
||||
Scope: o.Scope,
|
||||
Type: o.Type,
|
||||
Title: o.Title,
|
||||
Subtitle: o.Subtitle,
|
||||
Facts: o.Facts,
|
||||
Narrative: o.Narrative,
|
||||
Concepts: o.Concepts,
|
||||
FilesRead: o.FilesRead,
|
||||
FilesModified: o.FilesModified,
|
||||
FileMtimes: o.FileMtimes,
|
||||
PromptNumber: o.PromptNumber,
|
||||
DiscoveryTokens: o.DiscoveryTokens,
|
||||
CreatedAt: o.CreatedAt,
|
||||
CreatedAtEpoch: o.CreatedAtEpoch,
|
||||
ImportanceScore: o.ImportanceScore,
|
||||
UserFeedback: o.UserFeedback,
|
||||
RetrievalCount: o.RetrievalCount,
|
||||
LastRetrievedAt: o.LastRetrievedAt,
|
||||
ScoreUpdatedAt: o.ScoreUpdatedAt,
|
||||
IsSuperseded: o.IsSuperseded != 0, // Convert int to bool
|
||||
}
|
||||
}
|
||||
|
||||
// toModelObservations converts a slice of GORM Observation to pkg/models.Observation.
|
||||
func toModelObservations(observations []Observation) []*models.Observation {
|
||||
result := make([]*models.Observation, len(observations))
|
||||
for i := range observations {
|
||||
result[i] = toModelObservation(&observations[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// nullInt64 converts an int to sql.NullInt64.
|
||||
func nullInt64(val int) sql.NullInt64 {
|
||||
if val == 0 {
|
||||
return sql.NullInt64{Valid: false}
|
||||
}
|
||||
return sql.NullInt64{Int64: int64(val), Valid: true}
|
||||
}
|
||||
@@ -0,0 +1,593 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// testObservationStore creates an ObservationStore with a temporary database for testing.
|
||||
func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_observation_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
observationStore := NewObservationStore(store, nil, nil, nil)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return observationStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestObservationStore_StoreObservation(t *testing.T) {
|
||||
observationStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store an observation
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDecision,
|
||||
Title: "User prefers tabs over spaces",
|
||||
Narrative: "Observed in code formatting",
|
||||
Concepts: []string{"coding-style", "preferences"},
|
||||
}
|
||||
|
||||
id, epoch, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, 1, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
assert.Greater(t, epoch, int64(0))
|
||||
}
|
||||
|
||||
func TestObservationStore_StoreObservation_AutoCreateSession(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store observation without pre-creating session
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test auto-create",
|
||||
}
|
||||
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-auto", "auto-project", observation, 1, 50)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
|
||||
func TestObservationStore_StoreObservation_WithScope(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tags []string
|
||||
expectedScope models.ObservationScope
|
||||
}{
|
||||
{
|
||||
name: "Global scope - best practice",
|
||||
tags: []string{"best-practice", "testing"},
|
||||
expectedScope: models.ScopeGlobal,
|
||||
},
|
||||
{
|
||||
name: "Global scope - security",
|
||||
tags: []string{"security", "auth"},
|
||||
expectedScope: models.ScopeGlobal,
|
||||
},
|
||||
{
|
||||
name: "Project scope - specific feature",
|
||||
tags: []string{"feature", "implementation"},
|
||||
expectedScope: models.ScopeProject,
|
||||
},
|
||||
{
|
||||
name: "Project scope - no tags",
|
||||
tags: []string{},
|
||||
expectedScope: models.ScopeProject,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test scope determination",
|
||||
Concepts: tt.tags,
|
||||
}
|
||||
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, 1, 50)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify scope was set correctly
|
||||
observations, err := observationStore.GetObservationsByIDs(ctx, []int64{id}, "default", 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, observations, 1)
|
||||
assert.Equal(t, tt.expectedScope, observations[0].Scope)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_StoreObservation_AsyncCleanup(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Track cleanup calls
|
||||
var cleanupMutex sync.Mutex
|
||||
cleanupCalled := false
|
||||
var cleanupIDs []int64
|
||||
|
||||
cleanupFunc := func(ctx context.Context, deletedIDs []int64) {
|
||||
cleanupMutex.Lock()
|
||||
defer cleanupMutex.Unlock()
|
||||
cleanupCalled = true
|
||||
cleanupIDs = deletedIDs
|
||||
}
|
||||
|
||||
observationStore.cleanupFunc = cleanupFunc
|
||||
|
||||
// Store observations beyond the limit (MaxObservationsPerProject = 100)
|
||||
for i := 0; i < 105; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation",
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, i, 50)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for async cleanup to complete
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify cleanup was called
|
||||
cleanupMutex.Lock()
|
||||
defer cleanupMutex.Unlock()
|
||||
assert.True(t, cleanupCalled, "Cleanup function should have been called")
|
||||
assert.NotEmpty(t, cleanupIDs, "Cleanup should have deleted some observations")
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByIDs(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple observations with different importance scores
|
||||
var ids []int64
|
||||
for i := 1; i <= 3; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
|
||||
// Update importance score directly
|
||||
observationStore.db.Model(&Observation{}).Where("id = ?", id).Update("importance_score", float64(i))
|
||||
time.Sleep(10 * time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
orderBy string
|
||||
expected []int64
|
||||
}{
|
||||
{
|
||||
name: "Default ordering - importance desc",
|
||||
orderBy: "default",
|
||||
expected: []int64{ids[2], ids[1], ids[0]}, // High to low importance
|
||||
},
|
||||
{
|
||||
name: "Importance ordering",
|
||||
orderBy: "importance",
|
||||
expected: []int64{ids[2], ids[1], ids[0]},
|
||||
},
|
||||
{
|
||||
name: "Date ascending",
|
||||
orderBy: "date_asc",
|
||||
expected: []int64{ids[0], ids[1], ids[2]}, // Oldest to newest
|
||||
},
|
||||
{
|
||||
name: "Date descending",
|
||||
orderBy: "date_desc",
|
||||
expected: []int64{ids[2], ids[1], ids[0]}, // Newest to oldest
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
observations, err := observationStore.GetObservationsByIDs(ctx, ids, tt.orderBy, 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, observations, 3)
|
||||
|
||||
// Verify ordering
|
||||
for i, obs := range observations {
|
||||
assert.Equal(t, tt.expected[i], obs.ID, "Position %d should have ID %d", i, tt.expected[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByIDs_Limit(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple observations
|
||||
var ids []int64
|
||||
for i := 1; i <= 5; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Get with limit
|
||||
observations, err := observationStore.GetObservationsByIDs(ctx, ids, "default", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, observations, 3)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByIDs_EmptyInput(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get with empty IDs
|
||||
observations, err := observationStore.GetObservationsByIDs(ctx, []int64{}, "default", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, observations)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetRecentObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store project-scoped observations
|
||||
for i := 1; i <= 3; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project A fact",
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "project-a", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Store global-scoped observation
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global best practice",
|
||||
Concepts: []string{"best-practice"},
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-2", "project-b", observation, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store observation for different project
|
||||
observation = &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project B fact",
|
||||
}
|
||||
_, _, err = observationStore.StoreObservation(ctx, "claude-2", "project-b", observation, 2, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for any async cleanup to complete before querying
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get recent observations for project-a (should include project-a + global)
|
||||
observations, err := observationStore.GetRecentObservations(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, observations, 4) // 3 project-a + 1 global
|
||||
|
||||
// Verify scope filtering
|
||||
projectCount := 0
|
||||
globalCount := 0
|
||||
for _, obs := range observations {
|
||||
if obs.Scope == models.ScopeProject {
|
||||
assert.Equal(t, "project-a", obs.Project)
|
||||
projectCount++
|
||||
} else if obs.Scope == models.ScopeGlobal {
|
||||
globalCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 3, projectCount)
|
||||
assert.Equal(t, 1, globalCount)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetActiveObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store active observation
|
||||
activeObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Active observation",
|
||||
}
|
||||
activeID, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", activeObs, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store superseded observation
|
||||
supersededObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Superseded observation",
|
||||
}
|
||||
supersededID, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", supersededObs, 2, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as superseded
|
||||
observationStore.db.Model(&Observation{}).Where("id = ?", supersededID).Update("is_superseded", 1)
|
||||
|
||||
// Get active observations (should exclude superseded)
|
||||
observations, err := observationStore.GetActiveObservations(ctx, "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, observations, 1)
|
||||
assert.Equal(t, activeID, observations[0].ID)
|
||||
assert.False(t, observations[0].IsSuperseded)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetSupersededObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store active observation
|
||||
activeObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Active observation",
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", activeObs, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store superseded observation
|
||||
supersededObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Superseded observation",
|
||||
}
|
||||
supersededID, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", supersededObs, 2, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as superseded
|
||||
observationStore.db.Model(&Observation{}).Where("id = ?", supersededID).Update("is_superseded", 1)
|
||||
|
||||
// Get superseded observations (should exclude active)
|
||||
observations, err := observationStore.GetSupersededObservations(ctx, "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, observations, 1)
|
||||
assert.Equal(t, supersededID, observations[0].ID)
|
||||
assert.True(t, observations[0].IsSuperseded)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store project-scoped observations
|
||||
for i := 1; i <= 2; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project A fact",
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "project-a", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Store global-scoped observation
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global best practice",
|
||||
Concepts: []string{"best-practice"},
|
||||
}
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-2", "project-b", observation, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get strict project observations (should exclude global)
|
||||
observations, err := observationStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, observations, 2) // Only project-a observations
|
||||
|
||||
// Verify all are project-scoped
|
||||
for _, obs := range observations {
|
||||
assert.Equal(t, models.ScopeProject, obs.Scope)
|
||||
assert.Equal(t, "project-a", obs.Project)
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store observations with searchable content
|
||||
observations := []*models.ParsedObservation{
|
||||
{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "User prefers React for frontend development",
|
||||
Concepts: []string{"frontend", "react"},
|
||||
},
|
||||
{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Backend uses Go with chi router",
|
||||
Concepts: []string{"backend", "golang"},
|
||||
},
|
||||
{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Database is SQLite with FTS5",
|
||||
Concepts: []string{"database", "sqlite"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, obs := range observations {
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", obs, i+1, 10)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for FTS5 triggers to fire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Search for "React frontend"
|
||||
results, err := observationStore.SearchObservationsFTS(ctx, "React frontend", "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, results, "Should find observations matching 'React frontend'")
|
||||
|
||||
// Verify results contain relevant observation
|
||||
found := false
|
||||
for _, obs := range results {
|
||||
if obs.Title.String == "User prefers React for frontend development" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Should find the React observation")
|
||||
}
|
||||
|
||||
func TestObservationStore_CleanupOldObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store observations beyond the limit WITHOUT async cleanup
|
||||
// We disable async cleanup by not setting cleanupFunc
|
||||
var allIDs []int64
|
||||
for i := 0; i < 105; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation",
|
||||
}
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
allIDs = append(allIDs, id)
|
||||
time.Sleep(2 * time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
// Wait for any async cleanups to complete (even though cleanupFunc is nil)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify we have 105 observations initially (async cleanup should have run but deleted items)
|
||||
initial, err := observationStore.GetRecentObservations(ctx, "test-project", 200)
|
||||
require.NoError(t, err)
|
||||
|
||||
// If async cleanup already happened, we'll have <= 100
|
||||
// Run cleanup manually to ensure cleanup logic works
|
||||
deletedIDs, err := observationStore.CleanupOldObservations(ctx, "test-project")
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup (manual or async), we should have at most 100
|
||||
remaining, err := observationStore.GetRecentObservations(ctx, "test-project", 200)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(remaining), 100, "Should have at most 100 observations after cleanup")
|
||||
|
||||
// The number deleted should match how many were over the limit
|
||||
expectedDeleted := len(initial) - len(remaining)
|
||||
assert.Len(t, deletedIDs, expectedDeleted, "Should delete observations beyond limit")
|
||||
}
|
||||
|
||||
func TestObservationStore_DeleteObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple observations
|
||||
var ids []int64
|
||||
for i := 1; i <= 5; i++ {
|
||||
observation := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test",
|
||||
}
|
||||
id, _, err := observationStore.StoreObservation(ctx, "claude-1", "test-project", observation, i, 10)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Delete first 3 observations
|
||||
_, err := observationStore.DeleteObservations(ctx, ids[:3])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only 2 remain
|
||||
remaining, err := observationStore.GetRecentObservations(ctx, "test-project", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, remaining, 2)
|
||||
|
||||
// Verify deleted observations are gone
|
||||
deleted, err := observationStore.GetObservationsByIDs(ctx, ids[:3], "default", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, deleted)
|
||||
}
|
||||
|
||||
// Note: TestObservationStore_MarkObservationsSuperseded is omitted because
|
||||
// MarkObservationsSuperseded is a ConflictStore method (Phase 4), not ObservationStore
|
||||
|
||||
func TestObservationStore_GetAllObservations(t *testing.T) {
|
||||
observationStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store observations across projects
|
||||
_, _, err := observationStore.StoreObservation(ctx, "claude-1", "project-a", &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "A1"}, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = observationStore.StoreObservation(ctx, "claude-2", "project-b", &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "B1"}, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get all observations (for vector rebuild)
|
||||
all, err := observationStore.GetAllObservations(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, all, 2)
|
||||
|
||||
// Verify ordering by ID
|
||||
assert.Less(t, all[0].ID, all[1].ID)
|
||||
}
|
||||
@@ -1,32 +1,30 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// patternColumns is the standard list of columns to select for patterns.
|
||||
const patternColumns = `id, name, type, description, signature, recommendation,
|
||||
frequency, projects, observation_ids, status, merged_into_id, confidence,
|
||||
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch`
|
||||
|
||||
// PatternCleanupFunc is a callback for when patterns are deleted.
|
||||
type PatternCleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// PatternStore provides pattern-related database operations.
|
||||
// PatternStore provides pattern-related database operations using GORM.
|
||||
type PatternStore struct {
|
||||
store *Store
|
||||
db *gorm.DB
|
||||
cleanupFunc PatternCleanupFunc
|
||||
}
|
||||
|
||||
// NewPatternStore creates a new pattern store.
|
||||
func NewPatternStore(store *Store) *PatternStore {
|
||||
return &PatternStore{store: store}
|
||||
return &PatternStore{
|
||||
db: store.DB,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when patterns are deleted.
|
||||
@@ -36,145 +34,187 @@ func (s *PatternStore) SetCleanupFunc(fn PatternCleanupFunc) {
|
||||
|
||||
// StorePattern stores a new pattern.
|
||||
func (s *PatternStore) StorePattern(ctx context.Context, pattern *models.Pattern) (int64, error) {
|
||||
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||
|
||||
const query = `
|
||||
INSERT INTO patterns
|
||||
(name, type, description, signature, recommendation, frequency, projects,
|
||||
observation_ids, status, merged_into_id, confidence,
|
||||
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
pattern.Name, string(pattern.Type),
|
||||
nullString(pattern.Description.String), string(signatureJSON),
|
||||
nullString(pattern.Recommendation.String),
|
||||
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||
pattern.CreatedAt, pattern.CreatedAtEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
dbPattern := &Pattern{
|
||||
Name: pattern.Name,
|
||||
Type: pattern.Type,
|
||||
Signature: pattern.Signature,
|
||||
Frequency: pattern.Frequency,
|
||||
Projects: pattern.Projects,
|
||||
ObservationIDs: pattern.ObservationIDs,
|
||||
Status: pattern.Status,
|
||||
Confidence: pattern.Confidence,
|
||||
LastSeenAt: pattern.LastSeenAt,
|
||||
LastSeenAtEpoch: pattern.LastSeenEpoch,
|
||||
CreatedAt: pattern.CreatedAt,
|
||||
CreatedAtEpoch: pattern.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
return result.LastInsertId()
|
||||
if pattern.Description.Valid {
|
||||
dbPattern.Description = sql.NullString{String: pattern.Description.String, Valid: true}
|
||||
}
|
||||
|
||||
if pattern.Recommendation.Valid {
|
||||
dbPattern.Recommendation = sql.NullString{String: pattern.Recommendation.String, Valid: true}
|
||||
}
|
||||
|
||||
if pattern.MergedIntoID.Valid {
|
||||
dbPattern.MergedIntoID = sql.NullInt64{Int64: pattern.MergedIntoID.Int64, Valid: true}
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Create(dbPattern)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
return dbPattern.ID, nil
|
||||
}
|
||||
|
||||
// UpdatePattern updates an existing pattern.
|
||||
func (s *PatternStore) UpdatePattern(ctx context.Context, pattern *models.Pattern) error {
|
||||
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||
updates := map[string]interface{}{
|
||||
"name": pattern.Name,
|
||||
"type": pattern.Type,
|
||||
"signature": pattern.Signature,
|
||||
"frequency": pattern.Frequency,
|
||||
"projects": pattern.Projects,
|
||||
"observation_ids": pattern.ObservationIDs,
|
||||
"status": pattern.Status,
|
||||
"confidence": pattern.Confidence,
|
||||
"last_seen_at": pattern.LastSeenAt,
|
||||
"last_seen_at_epoch": pattern.LastSeenEpoch,
|
||||
}
|
||||
|
||||
const query = `
|
||||
UPDATE patterns SET
|
||||
name = ?, type = ?, description = ?, signature = ?, recommendation = ?,
|
||||
frequency = ?, projects = ?, observation_ids = ?, status = ?,
|
||||
merged_into_id = ?, confidence = ?, last_seen_at = ?, last_seen_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
if pattern.Description.Valid {
|
||||
updates["description"] = pattern.Description.String
|
||||
} else {
|
||||
updates["description"] = nil
|
||||
}
|
||||
|
||||
_, err := s.store.ExecContext(ctx, query,
|
||||
pattern.Name, string(pattern.Type),
|
||||
nullString(pattern.Description.String), string(signatureJSON),
|
||||
nullString(pattern.Recommendation.String),
|
||||
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||
pattern.ID,
|
||||
)
|
||||
return err
|
||||
if pattern.Recommendation.Valid {
|
||||
updates["recommendation"] = pattern.Recommendation.String
|
||||
} else {
|
||||
updates["recommendation"] = nil
|
||||
}
|
||||
|
||||
if pattern.MergedIntoID.Valid {
|
||||
updates["merged_into_id"] = pattern.MergedIntoID.Int64
|
||||
} else {
|
||||
updates["merged_into_id"] = nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Pattern{}).
|
||||
Where("id = ?", pattern.ID).
|
||||
Updates(updates)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// GetPatternByID retrieves a pattern by ID.
|
||||
func (s *PatternStore) GetPatternByID(ctx context.Context, id int64) (*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + ` FROM patterns WHERE id = ?`
|
||||
var dbPattern Pattern
|
||||
|
||||
row := s.store.QueryRowContext(ctx, query, id)
|
||||
return scanPattern(row)
|
||||
err := s.db.WithContext(ctx).First(&dbPattern, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelPattern(&dbPattern), nil
|
||||
}
|
||||
|
||||
// GetPatternByName retrieves a pattern by name.
|
||||
func (s *PatternStore) GetPatternByName(ctx context.Context, name string) (*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + ` FROM patterns WHERE name = ? AND status = 'active'`
|
||||
var dbPattern Pattern
|
||||
|
||||
row := s.store.QueryRowContext(ctx, query, name)
|
||||
pattern, err := scanPattern(row)
|
||||
if err == sql.ErrNoRows {
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("name = ? AND status = ?", name, models.PatternStatusActive).
|
||||
First(&dbPattern).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return pattern, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelPattern(&dbPattern), nil
|
||||
}
|
||||
|
||||
// GetActivePatterns retrieves all active patterns.
|
||||
func (s *PatternStore) GetActivePatterns(ctx context.Context, limit int) ([]*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
WHERE status = 'active'
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
var patterns []Pattern
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("status = ?", models.PatternStatusActive).
|
||||
Order("frequency DESC, confidence DESC").
|
||||
Limit(limit).
|
||||
Find(&patterns).Error
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
return toModelPatterns(patterns), nil
|
||||
}
|
||||
|
||||
// GetPatternsByType retrieves patterns of a specific type.
|
||||
func (s *PatternStore) GetPatternsByType(ctx context.Context, patternType models.PatternType, limit int) ([]*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
WHERE type = ? AND status = 'active'
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
var patterns []Pattern
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("type = ? AND status = ?", patternType, models.PatternStatusActive).
|
||||
Order("frequency DESC, confidence DESC").
|
||||
Limit(limit).
|
||||
Find(&patterns).Error
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, string(patternType), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
return toModelPatterns(patterns), nil
|
||||
}
|
||||
|
||||
// GetPatternsByProject retrieves patterns that have been observed in a specific project.
|
||||
// Uses raw SQL since JSON_EACH is complex in GORM.
|
||||
func (s *PatternStore) GetPatternsByProject(ctx context.Context, project string, limit int) ([]*models.Pattern, error) {
|
||||
// Use JSON path to search within the projects array
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
var patterns []Pattern
|
||||
|
||||
// Use raw SQL for JSON_EACH query
|
||||
query := `
|
||||
SELECT * FROM patterns
|
||||
WHERE status = 'active'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM json_each(projects)
|
||||
WHERE json_each.value = ?
|
||||
)
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Raw(query, project, limit).
|
||||
Scan(&patterns).Error
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
return toModelPatterns(patterns), nil
|
||||
}
|
||||
|
||||
// FindMatchingPatterns searches for patterns that match a given signature.
|
||||
// Pattern matching is done in Go code for simplicity.
|
||||
func (s *PatternStore) FindMatchingPatterns(ctx context.Context, signature []string, minScore float64) ([]*models.Pattern, error) {
|
||||
// Get all active patterns and filter by signature match in Go
|
||||
// This is simpler than complex SQL for JSON array matching
|
||||
// Get all active patterns
|
||||
patterns, err := s.GetActivePatterns(ctx, 100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter by signature match in Go
|
||||
var matches []*models.Pattern
|
||||
for _, pattern := range patterns {
|
||||
score := models.CalculateMatchScore(signature, pattern.Signature)
|
||||
@@ -182,14 +222,18 @@ func (s *PatternStore) FindMatchingPatterns(ctx context.Context, signature []str
|
||||
matches = append(matches, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// MarkPatternDeprecated marks a pattern as deprecated.
|
||||
func (s *PatternStore) MarkPatternDeprecated(ctx context.Context, id int64) error {
|
||||
const query = `UPDATE patterns SET status = 'deprecated' WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, id)
|
||||
return err
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Pattern{}).
|
||||
Where("id = ?", id).
|
||||
Update("status", models.PatternStatusDeprecated)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// MergePatterns merges a source pattern into a target pattern.
|
||||
@@ -206,6 +250,8 @@ func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int
|
||||
|
||||
// Merge source into target
|
||||
target.Frequency += source.Frequency
|
||||
|
||||
// Merge projects (deduplicate)
|
||||
for _, proj := range source.Projects {
|
||||
found := false
|
||||
for _, existing := range target.Projects {
|
||||
@@ -218,6 +264,8 @@ func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int
|
||||
target.Projects = append(target.Projects, proj)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge observation IDs (deduplicate)
|
||||
for _, obsID := range source.ObservationIDs {
|
||||
found := false
|
||||
for _, existing := range target.ObservationIDs {
|
||||
@@ -244,59 +292,40 @@ func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int
|
||||
|
||||
// DeletePattern deletes a pattern by ID.
|
||||
func (s *PatternStore) DeletePattern(ctx context.Context, id int64) error {
|
||||
const query = `DELETE FROM patterns WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, id)
|
||||
if err == nil && s.cleanupFunc != nil {
|
||||
result := s.db.WithContext(ctx).Delete(&Pattern{}, id)
|
||||
|
||||
if result.Error == nil && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, []int64{id})
|
||||
}
|
||||
return err
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// SearchPatternsFTS performs full-text search on patterns.
|
||||
// Uses raw SQL for FTS5 query.
|
||||
func (s *PatternStore) SearchPatternsFTS(ctx context.Context, searchQuery string, limit int) ([]*models.Pattern, error) {
|
||||
query := `SELECT p.` + patternColumns + `
|
||||
var patterns []Pattern
|
||||
|
||||
// Use raw SQL for FTS5 MATCH query
|
||||
query := `
|
||||
SELECT p.*
|
||||
FROM patterns p
|
||||
JOIN patterns_fts fts ON p.id = fts.rowid
|
||||
WHERE patterns_fts MATCH ?
|
||||
AND p.status = 'active'
|
||||
ORDER BY rank
|
||||
LIMIT ?`
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Raw(query, searchQuery, limit).
|
||||
Scan(&patterns).Error
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, searchQuery, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
}
|
||||
|
||||
// GetPatternStats returns statistics about patterns.
|
||||
func (s *PatternStore) GetPatternStats(ctx context.Context) (*PatternStats, error) {
|
||||
const query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN status = 'active' THEN 1 END) as active,
|
||||
COUNT(CASE WHEN status = 'deprecated' THEN 1 END) as deprecated,
|
||||
COUNT(CASE WHEN status = 'merged' THEN 1 END) as merged,
|
||||
COALESCE(SUM(frequency), 0) as total_occurrences,
|
||||
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||
COUNT(CASE WHEN type = 'bug' THEN 1 END) as bugs,
|
||||
COUNT(CASE WHEN type = 'refactor' THEN 1 END) as refactors,
|
||||
COUNT(CASE WHEN type = 'architecture' THEN 1 END) as architectures,
|
||||
COUNT(CASE WHEN type = 'anti-pattern' THEN 1 END) as anti_patterns,
|
||||
COUNT(CASE WHEN type = 'best-practice' THEN 1 END) as best_practices
|
||||
FROM patterns
|
||||
`
|
||||
|
||||
var stats PatternStats
|
||||
err := s.store.QueryRowContext(ctx, query).Scan(
|
||||
&stats.Total, &stats.Active, &stats.Deprecated, &stats.Merged,
|
||||
&stats.TotalOccurrences, &stats.AvgConfidence,
|
||||
&stats.Bugs, &stats.Refactors, &stats.Architectures,
|
||||
&stats.AntiPatterns, &stats.BestPractices,
|
||||
)
|
||||
return &stats, err
|
||||
return toModelPatterns(patterns), nil
|
||||
}
|
||||
|
||||
// PatternStats contains aggregate statistics about patterns.
|
||||
@@ -314,41 +343,29 @@ type PatternStats struct {
|
||||
BestPractices int `json:"best_practices"`
|
||||
}
|
||||
|
||||
// scanPattern scans a single pattern from a row scanner.
|
||||
func scanPattern(scanner interface{ Scan(...interface{}) error }) (*models.Pattern, error) {
|
||||
var pattern models.Pattern
|
||||
if err := scanner.Scan(
|
||||
&pattern.ID, &pattern.Name, &pattern.Type,
|
||||
&pattern.Description, &pattern.Signature, &pattern.Recommendation,
|
||||
&pattern.Frequency, &pattern.Projects, &pattern.ObservationIDs,
|
||||
&pattern.Status, &pattern.MergedIntoID, &pattern.Confidence,
|
||||
&pattern.LastSeenAt, &pattern.LastSeenEpoch,
|
||||
&pattern.CreatedAt, &pattern.CreatedAtEpoch,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pattern, nil
|
||||
}
|
||||
// GetPatternStats returns statistics about patterns.
|
||||
// Uses raw SQL for complex aggregate query.
|
||||
func (s *PatternStore) GetPatternStats(ctx context.Context) (*PatternStats, error) {
|
||||
var stats PatternStats
|
||||
|
||||
// scanPatternRows scans multiple patterns from rows.
|
||||
func scanPatternRows(rows *sql.Rows) ([]*models.Pattern, error) {
|
||||
var patterns []*models.Pattern
|
||||
for rows.Next() {
|
||||
pattern, err := scanPattern(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
patterns = append(patterns, pattern)
|
||||
}
|
||||
return patterns, rows.Err()
|
||||
}
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN status = 'active' THEN 1 END) as active,
|
||||
COUNT(CASE WHEN status = 'deprecated' THEN 1 END) as deprecated,
|
||||
COUNT(CASE WHEN status = 'merged' THEN 1 END) as merged,
|
||||
COALESCE(SUM(frequency), 0) as total_occurrences,
|
||||
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||
COUNT(CASE WHEN type = 'bug' THEN 1 END) as bugs,
|
||||
COUNT(CASE WHEN type = 'refactor' THEN 1 END) as refactors,
|
||||
COUNT(CASE WHEN type = 'architecture' THEN 1 END) as architectures,
|
||||
COUNT(CASE WHEN type = 'anti-pattern' THEN 1 END) as anti_patterns,
|
||||
COUNT(CASE WHEN type = 'best-practice' THEN 1 END) as best_practices
|
||||
FROM patterns
|
||||
`
|
||||
|
||||
// nullInt64 converts sql.NullInt64 to the value needed for database insertion.
|
||||
func nullInt64(n sql.NullInt64) interface{} {
|
||||
if n.Valid {
|
||||
return n.Int64
|
||||
}
|
||||
return nil
|
||||
err := s.db.WithContext(ctx).Raw(query).Scan(&stats).Error
|
||||
return &stats, err
|
||||
}
|
||||
|
||||
// IncrementPatternFrequency atomically increments a pattern's frequency and updates last_seen.
|
||||
@@ -368,3 +385,36 @@ func (s *PatternStore) IncrementPatternFrequency(ctx context.Context, id int64,
|
||||
|
||||
return s.UpdatePattern(ctx, pattern)
|
||||
}
|
||||
|
||||
// toModelPattern converts a GORM Pattern to a pkg/models Pattern.
|
||||
func toModelPattern(p *Pattern) *models.Pattern {
|
||||
pattern := &models.Pattern{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Description: p.Description,
|
||||
Signature: p.Signature,
|
||||
Recommendation: p.Recommendation,
|
||||
Frequency: p.Frequency,
|
||||
Projects: p.Projects,
|
||||
ObservationIDs: p.ObservationIDs,
|
||||
Status: p.Status,
|
||||
MergedIntoID: p.MergedIntoID,
|
||||
Confidence: p.Confidence,
|
||||
LastSeenAt: p.LastSeenAt,
|
||||
LastSeenEpoch: p.LastSeenAtEpoch,
|
||||
CreatedAt: p.CreatedAt,
|
||||
CreatedAtEpoch: p.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
return pattern
|
||||
}
|
||||
|
||||
// toModelPatterns converts a slice of GORM Patterns to pkg/models Patterns.
|
||||
func toModelPatterns(patterns []Pattern) []*models.Pattern {
|
||||
result := make([]*models.Pattern, len(patterns))
|
||||
for i, p := range patterns {
|
||||
result[i] = toModelPattern(&p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
func testPatternStore(t *testing.T) (*PatternStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_pattern_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return patternStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestPatternStore_StorePattern(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "Test Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "Test description", Valid: true},
|
||||
Signature: []string{"bug", "error"},
|
||||
Recommendation: sql.NullString{String: "Fix it", Valid: true},
|
||||
Frequency: 1,
|
||||
Projects: []string{"test-project"},
|
||||
ObservationIDs: []int64{1, 2, 3},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.8,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify pattern was stored
|
||||
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pattern.Name, retrieved.Name)
|
||||
assert.Equal(t, pattern.Type, retrieved.Type)
|
||||
assert.Equal(t, pattern.Signature, retrieved.Signature)
|
||||
assert.Equal(t, pattern.Frequency, retrieved.Frequency)
|
||||
assert.Equal(t, pattern.Status, retrieved.Status)
|
||||
assert.Equal(t, pattern.Confidence, retrieved.Confidence)
|
||||
}
|
||||
|
||||
func TestPatternStore_UpdatePattern(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "Original",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update pattern
|
||||
pattern.ID = id
|
||||
pattern.Name = "Updated"
|
||||
pattern.Frequency = 5
|
||||
pattern.Confidence = 0.9
|
||||
|
||||
err = patternStore.UpdatePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", retrieved.Name)
|
||||
assert.Equal(t, 5, retrieved.Frequency)
|
||||
assert.Equal(t, 0.9, retrieved.Confidence)
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternByName(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "Unique Pattern",
|
||||
Type: models.PatternTypeRefactor,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
_, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve by name
|
||||
retrieved, err := patternStore.GetPatternByName(ctx, "Unique Pattern")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, "Unique Pattern", retrieved.Name)
|
||||
assert.Equal(t, models.PatternTypeRefactor, retrieved.Type)
|
||||
|
||||
// Non-existent pattern
|
||||
notFound, err := patternStore.GetPatternByName(ctx, "Nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, notFound)
|
||||
}
|
||||
|
||||
func TestPatternStore_GetActivePatterns(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Create active patterns
|
||||
for i := 0; i < 3; i++ {
|
||||
pattern := &models.Pattern{
|
||||
Name: "Pattern " + string(rune('A'+i)),
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: i + 1, // Different frequencies for sorting
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.8,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create deprecated pattern (should not be included)
|
||||
deprecatedPattern := &models.Pattern{
|
||||
Name: "Deprecated Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 100,
|
||||
Status: models.PatternStatusDeprecated,
|
||||
Confidence: 0.9,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err := patternStore.StorePattern(ctx, deprecatedPattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get active patterns
|
||||
patterns, err := patternStore.GetActivePatterns(ctx, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, patterns, 3) // Only active patterns
|
||||
|
||||
// Verify sorted by frequency DESC
|
||||
assert.Equal(t, 3, patterns[0].Frequency)
|
||||
assert.Equal(t, 2, patterns[1].Frequency)
|
||||
assert.Equal(t, 1, patterns[2].Frequency)
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternsByType(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Create patterns of different types
|
||||
bugPattern := &models.Pattern{
|
||||
Name: "Bug Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.8,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err := patternStore.StorePattern(ctx, bugPattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
refactorPattern := &models.Pattern{
|
||||
Name: "Refactor Pattern",
|
||||
Type: models.PatternTypeRefactor,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
_, err = patternStore.StorePattern(ctx, refactorPattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get only bug patterns
|
||||
bugPatterns, err := patternStore.GetPatternsByType(ctx, models.PatternTypeBug, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, bugPatterns, 1)
|
||||
assert.Equal(t, "Bug Pattern", bugPatterns[0].Name)
|
||||
assert.Equal(t, models.PatternTypeBug, bugPatterns[0].Type)
|
||||
|
||||
// Get only refactor patterns
|
||||
refactorPatterns, err := patternStore.GetPatternsByType(ctx, models.PatternTypeRefactor, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, refactorPatterns, 1)
|
||||
assert.Equal(t, "Refactor Pattern", refactorPatterns[0].Name)
|
||||
}
|
||||
|
||||
func TestPatternStore_MarkPatternDeprecated(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "To Deprecate",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as deprecated
|
||||
err = patternStore.MarkPatternDeprecated(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify status changed
|
||||
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.PatternStatusDeprecated, retrieved.Status)
|
||||
}
|
||||
|
||||
func TestPatternStore_MergePatterns(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Create source pattern
|
||||
source := &models.Pattern{
|
||||
Name: "Source Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 5,
|
||||
Projects: []string{"project-a", "project-b"},
|
||||
ObservationIDs: []int64{1, 2, 3},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
sourceID, err := patternStore.StorePattern(ctx, source)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create target pattern
|
||||
target := &models.Pattern{
|
||||
Name: "Target Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 10,
|
||||
Projects: []string{"project-b", "project-c"},
|
||||
ObservationIDs: []int64{3, 4, 5},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.8,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
targetID, err := patternStore.StorePattern(ctx, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Merge source into target
|
||||
err = patternStore.MergePatterns(ctx, sourceID, targetID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify target was updated
|
||||
mergedTarget, err := patternStore.GetPatternByID(ctx, targetID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 15, mergedTarget.Frequency) // 5 + 10
|
||||
assert.ElementsMatch(t, []string{"project-a", "project-b", "project-c"}, mergedTarget.Projects)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 4, 5}, mergedTarget.ObservationIDs)
|
||||
|
||||
// Verify source was marked as merged
|
||||
mergedSource, err := patternStore.GetPatternByID(ctx, sourceID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.PatternStatusMerged, mergedSource.Status)
|
||||
assert.True(t, mergedSource.MergedIntoID.Valid)
|
||||
assert.Equal(t, targetID, mergedSource.MergedIntoID.Int64)
|
||||
}
|
||||
|
||||
func TestPatternStore_DeletePattern(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "To Delete",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 1,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete pattern
|
||||
err = patternStore.DeletePattern(ctx, id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deleted
|
||||
deleted, err := patternStore.GetPatternByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, deleted)
|
||||
}
|
||||
|
||||
func TestPatternStore_IncrementPatternFrequency(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
pattern := &models.Pattern{
|
||||
Name: "Frequency Test",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 1,
|
||||
Projects: []string{"project-a"},
|
||||
ObservationIDs: []int64{},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Increment frequency with new project and observation
|
||||
err = patternStore.IncrementPatternFrequency(ctx, id, "project-b", 42)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify frequency incremented and new data added
|
||||
updated, err := patternStore.GetPatternByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, updated.Frequency)
|
||||
assert.ElementsMatch(t, []string{"project-a", "project-b"}, updated.Projects)
|
||||
assert.Contains(t, updated.ObservationIDs, int64(42))
|
||||
|
||||
// Last seen should be updated (rough check - within last 5 seconds)
|
||||
updatedTime, _ := time.Parse(time.RFC3339, updated.LastSeenAt)
|
||||
assert.WithinDuration(t, time.Now(), updatedTime, 5*time.Second)
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternStats(t *testing.T) {
|
||||
patternStore, _, cleanup := testPatternStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Create patterns with different statuses and types
|
||||
patterns := []*models.Pattern{
|
||||
{
|
||||
Name: "Bug 1",
|
||||
Type: models.PatternTypeBug,
|
||||
Frequency: 10,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.8,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
Name: "Refactor 1",
|
||||
Type: models.PatternTypeRefactor,
|
||||
Frequency: 5,
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
Name: "Deprecated 1",
|
||||
Type: models.PatternTypeBestPractice,
|
||||
Frequency: 3,
|
||||
Status: models.PatternStatusDeprecated,
|
||||
Confidence: 0.6,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
_, err := patternStore.StorePattern(ctx, p)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats, err := patternStore.GetPatternStats(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, stats.Total)
|
||||
assert.Equal(t, 2, stats.Active)
|
||||
assert.Equal(t, 1, stats.Deprecated)
|
||||
assert.Equal(t, 0, stats.Merged)
|
||||
assert.Equal(t, 18, stats.TotalOccurrences) // 10 + 5 + 3
|
||||
assert.InDelta(t, 0.7, stats.AvgConfidence, 0.05) // (0.8 + 0.7 + 0.6) / 3
|
||||
assert.Equal(t, 1, stats.Bugs)
|
||||
assert.Equal(t, 1, stats.Refactors)
|
||||
assert.Equal(t, 1, stats.BestPractices)
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// PromptCleanupFunc is a callback for when prompts are cleaned up.
|
||||
// Receives the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
|
||||
type PromptCleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// MaxPromptsGlobal is the hard limit of prompts across all projects.
|
||||
const MaxPromptsGlobal = 500
|
||||
|
||||
// PromptStore provides user prompt-related database operations using GORM.
|
||||
type PromptStore struct {
|
||||
db *gorm.DB
|
||||
cleanupFunc PromptCleanupFunc
|
||||
}
|
||||
|
||||
// NewPromptStore creates a new prompt store.
|
||||
func NewPromptStore(store *Store, cleanupFunc PromptCleanupFunc) *PromptStore {
|
||||
return &PromptStore{
|
||||
db: store.DB,
|
||||
cleanupFunc: cleanupFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when prompts are deleted during cleanup.
|
||||
func (s *PromptStore) SetCleanupFunc(fn PromptCleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// SaveUserPromptWithMatches saves a user prompt with matched observation count.
|
||||
// Uses INSERT OR IGNORE to be idempotent - duplicate (session, prompt_number) pairs are silently ignored.
|
||||
// This prevents duplicate prompts when the user-prompt hook fires multiple times.
|
||||
func (s *PromptStore) SaveUserPromptWithMatches(ctx context.Context, claudeSessionID string, promptNumber int, promptText string, matchedObservations int) (int64, error) {
|
||||
now := time.Now()
|
||||
|
||||
prompt := &UserPrompt{
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
PromptNumber: promptNumber,
|
||||
PromptText: promptText,
|
||||
MatchedObservations: matchedObservations,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
// INSERT OR IGNORE using OnConflict
|
||||
result := s.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "claude_session_id"}, {Name: "prompt_number"}},
|
||||
DoNothing: true,
|
||||
}).
|
||||
Create(prompt)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
// If RowsAffected is 0, the insert was ignored (duplicate) - fetch the existing ID
|
||||
if result.RowsAffected == 0 {
|
||||
var existing UserPrompt
|
||||
err := s.db.Where("claude_session_id = ? AND prompt_number = ?", claudeSessionID, promptNumber).
|
||||
First(&existing).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Return existing ID without triggering cleanup (already handled when first inserted)
|
||||
return existing.ID, nil
|
||||
}
|
||||
|
||||
// Cleanup old prompts beyond the global limit (async to not block handler)
|
||||
go func() {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldPrompts(cleanupCtx)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}()
|
||||
|
||||
return prompt.ID, nil
|
||||
}
|
||||
|
||||
// CleanupOldPrompts deletes prompts beyond the global limit.
|
||||
// Keeps the most recent MaxPromptsGlobal prompts.
|
||||
// Returns the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
|
||||
func (s *PromptStore) CleanupOldPrompts(ctx context.Context) ([]int64, error) {
|
||||
// Use a transaction to prevent TOCTOU race condition
|
||||
var idsToDelete []int64
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Find IDs to keep (most recent MaxPromptsGlobal)
|
||||
var idsToKeep []int64
|
||||
err := tx.Model(&UserPrompt{}).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(MaxPromptsGlobal).
|
||||
Pluck("id", &idsToKeep).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToKeep) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find IDs to delete (all IDs not in the keep list)
|
||||
// This happens in the same transaction to prevent race conditions
|
||||
err = tx.Model(&UserPrompt{}).
|
||||
Where("id NOT IN ?", idsToKeep).
|
||||
Pluck("id", &idsToDelete).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToDelete) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the prompts
|
||||
return tx.Delete(&UserPrompt{}, idsToDelete).Error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return idsToDelete, nil
|
||||
}
|
||||
|
||||
// GetPromptsByIDs retrieves user prompts by a list of IDs.
|
||||
func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Table("user_prompts up").
|
||||
Select("up.id, up.claude_session_id, up.prompt_number, up.prompt_text, "+
|
||||
"COALESCE(up.matched_observations, 0) as matched_observations, "+
|
||||
"up.created_at, up.created_at_epoch, "+
|
||||
"COALESCE(s.project, '') as project, "+
|
||||
"COALESCE(s.sdk_session_id, '') as sdk_session_id").
|
||||
Joins("LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id").
|
||||
Where("up.id IN ?", ids)
|
||||
|
||||
// Apply ordering
|
||||
switch orderBy {
|
||||
case "date_asc":
|
||||
query = query.Order("up.created_at_epoch ASC")
|
||||
case "date_desc", "default", "":
|
||||
query = query.Order("up.created_at_epoch DESC")
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
err := query.Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelUserPromptsWithSession(results), nil
|
||||
}
|
||||
|
||||
// GetAllRecentUserPrompts retrieves recent user prompts across all projects.
|
||||
func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Table("user_prompts up").
|
||||
Select("up.id, up.claude_session_id, up.prompt_number, up.prompt_text, " +
|
||||
"COALESCE(up.matched_observations, 0) as matched_observations, " +
|
||||
"up.created_at, up.created_at_epoch, " +
|
||||
"COALESCE(s.project, '') as project, " +
|
||||
"COALESCE(s.sdk_session_id, '') as sdk_session_id").
|
||||
Joins("LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id").
|
||||
Order("up.created_at_epoch DESC").
|
||||
Limit(limit)
|
||||
|
||||
err := query.Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelUserPromptsWithSession(results), nil
|
||||
}
|
||||
|
||||
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Table("user_prompts up").
|
||||
Select("up.id, up.claude_session_id, up.prompt_number, up.prompt_text, " +
|
||||
"COALESCE(up.matched_observations, 0) as matched_observations, " +
|
||||
"up.created_at, up.created_at_epoch, " +
|
||||
"COALESCE(s.project, '') as project, " +
|
||||
"COALESCE(s.sdk_session_id, '') as sdk_session_id").
|
||||
Joins("LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id").
|
||||
Order("up.id")
|
||||
|
||||
err := query.Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelUserPromptsWithSession(results), nil
|
||||
}
|
||||
|
||||
// FindRecentPromptByText finds a recent prompt by exact text match within a time window.
|
||||
// Returns (promptID, promptNumber, found).
|
||||
func (s *PromptStore) FindRecentPromptByText(ctx context.Context, claudeSessionID, promptText string, withinSeconds int) (int64, int, bool) {
|
||||
cutoffEpoch := time.Now().Add(-time.Duration(withinSeconds) * time.Second).UnixMilli()
|
||||
|
||||
var prompt UserPrompt
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("claude_session_id = ? AND prompt_text = ? AND created_at_epoch >= ?",
|
||||
claudeSessionID, promptText, cutoffEpoch).
|
||||
Order("created_at_epoch DESC").
|
||||
First(&prompt).Error
|
||||
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
return prompt.ID, prompt.PromptNumber, true
|
||||
}
|
||||
|
||||
// GetRecentUserPromptsByProject retrieves recent user prompts for a specific project.
|
||||
func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Table("user_prompts up").
|
||||
Select("up.id, up.claude_session_id, up.prompt_number, up.prompt_text, "+
|
||||
"COALESCE(up.matched_observations, 0) as matched_observations, "+
|
||||
"up.created_at, up.created_at_epoch, "+
|
||||
"COALESCE(s.project, '') as project, "+
|
||||
"COALESCE(s.sdk_session_id, '') as sdk_session_id").
|
||||
Joins("LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id").
|
||||
Where("s.project = ?", project).
|
||||
Order("up.created_at_epoch DESC").
|
||||
Limit(limit)
|
||||
|
||||
err := query.Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelUserPromptsWithSession(results), nil
|
||||
}
|
||||
|
||||
// toModelUserPromptsWithSession converts query results to pkg/models.UserPromptWithSession.
|
||||
func toModelUserPromptsWithSession(results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
}) []*models.UserPromptWithSession {
|
||||
prompts := make([]*models.UserPromptWithSession, len(results))
|
||||
for i, r := range results {
|
||||
project := ""
|
||||
if r.Project.Valid {
|
||||
project = r.Project.String
|
||||
}
|
||||
|
||||
sdkSessionID := ""
|
||||
if r.SDKSessionID.Valid {
|
||||
sdkSessionID = r.SDKSessionID.String
|
||||
}
|
||||
|
||||
prompts[i] = &models.UserPromptWithSession{
|
||||
UserPrompt: models.UserPrompt{
|
||||
ID: r.ID,
|
||||
ClaudeSessionID: r.ClaudeSessionID,
|
||||
PromptNumber: r.PromptNumber,
|
||||
PromptText: r.PromptText,
|
||||
MatchedObservations: r.MatchedObservations,
|
||||
CreatedAt: r.CreatedAt,
|
||||
CreatedAtEpoch: r.CreatedAtEpoch,
|
||||
},
|
||||
Project: project,
|
||||
SDKSessionID: sdkSessionID,
|
||||
}
|
||||
}
|
||||
return prompts
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// testPromptStore creates a PromptStore with a temporary database for testing.
|
||||
func testPromptStore(t *testing.T) (*PromptStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_prompt_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
promptStore := NewPromptStore(store, nil)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return promptStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestPromptStore_SaveUserPromptWithMatches(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a prompt
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "What is the codebase structure?", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
|
||||
func TestPromptStore_SaveUserPromptWithMatches_Idempotency(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save the same prompt twice (same claudeSessionID + promptNumber)
|
||||
id1, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Test prompt", 3)
|
||||
require.NoError(t, err)
|
||||
|
||||
id2, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Different text", 5)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return the same ID (INSERT OR IGNORE)
|
||||
assert.Equal(t, id1, id2, "Duplicate prompts should return same ID")
|
||||
}
|
||||
|
||||
func TestPromptStore_SaveUserPromptWithMatches_AsyncCleanup(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Track cleanup calls
|
||||
var cleanupMutex sync.Mutex
|
||||
cleanupCalled := false
|
||||
var cleanupIDs []int64
|
||||
|
||||
cleanupFunc := func(ctx context.Context, deletedIDs []int64) {
|
||||
cleanupMutex.Lock()
|
||||
defer cleanupMutex.Unlock()
|
||||
cleanupCalled = true
|
||||
cleanupIDs = deletedIDs
|
||||
}
|
||||
|
||||
promptStore.cleanupFunc = cleanupFunc
|
||||
|
||||
// Save prompts beyond the global limit (MaxPromptsGlobal = 500)
|
||||
// Insert with slower pacing to avoid database lock contention
|
||||
for i := 0; i < 505; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i+1, "Prompt", 1)
|
||||
require.NoError(t, err)
|
||||
if i > 500 {
|
||||
time.Sleep(5 * time.Millisecond) // Slow down after hitting limit
|
||||
}
|
||||
}
|
||||
|
||||
// Wait longer for async cleanup to complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Verify cleanup was called
|
||||
cleanupMutex.Lock()
|
||||
defer cleanupMutex.Unlock()
|
||||
assert.True(t, cleanupCalled, "Cleanup function should have been called")
|
||||
assert.NotEmpty(t, cleanupIDs, "Cleanup should have deleted some prompts")
|
||||
}
|
||||
|
||||
func TestPromptStore_CleanupOldPrompts(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save prompts beyond the limit
|
||||
// Async cleanup should fire after each insert beyond 500
|
||||
for i := 0; i < 505; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i+1, "Prompt", 1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for all async cleanups to complete
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// After async cleanup, we should have at most 500 prompts
|
||||
remaining, err := promptStore.GetAllPrompts(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(remaining), MaxPromptsGlobal, "Should have at most %d prompts after async cleanup", MaxPromptsGlobal)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save multiple prompts
|
||||
var ids []int64
|
||||
for i := 1; i <= 3; i++ {
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt", i)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
time.Sleep(10 * time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
orderBy string
|
||||
expected []int64
|
||||
}{
|
||||
{
|
||||
name: "Default ordering - date desc",
|
||||
orderBy: "default",
|
||||
expected: []int64{ids[2], ids[1], ids[0]}, // Newest to oldest
|
||||
},
|
||||
{
|
||||
name: "Date ascending",
|
||||
orderBy: "date_asc",
|
||||
expected: []int64{ids[0], ids[1], ids[2]}, // Oldest to newest
|
||||
},
|
||||
{
|
||||
name: "Date descending",
|
||||
orderBy: "date_desc",
|
||||
expected: []int64{ids[2], ids[1], ids[0]}, // Newest to oldest
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, ids, tt.orderBy, 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompts, 3)
|
||||
|
||||
// Verify ordering
|
||||
for i, prompt := range prompts {
|
||||
assert.Equal(t, tt.expected[i], prompt.ID, "Position %d should have ID %d", i, tt.expected[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs_Limit(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save multiple prompts
|
||||
var ids []int64
|
||||
for i := 1; i <= 5; i++ {
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt", i)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Get with limit
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, ids, "default", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 3)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs_EmptyInput(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get with empty IDs
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, []int64{}, "default", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, prompts)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs_WithSession(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a prompt
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Test prompt", 5)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get with session join
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, []int64{id}, "default", 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompts, 1)
|
||||
|
||||
// Verify session data is populated
|
||||
assert.Equal(t, "test-project", prompts[0].Project)
|
||||
assert.NotEmpty(t, prompts[0].SDKSessionID)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetAllRecentUserPrompts(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save prompts across multiple sessions with timestamps
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt A", i)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(10 * time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
for i := 1; i <= 2; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-2", i, "Prompt B", i)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(10 * time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
// Get all recent prompts
|
||||
prompts, err := promptStore.GetAllRecentUserPrompts(ctx, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 5)
|
||||
|
||||
// Verify ordering (most recent first) - last inserted should be first
|
||||
assert.Equal(t, "claude-2", prompts[0].ClaudeSessionID)
|
||||
assert.Equal(t, 2, prompts[0].PromptNumber)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetAllPrompts(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save prompts
|
||||
for i := 1; i <= 5; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for any async cleanup to complete (longer wait for race detector)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Get all prompts (for vector rebuild)
|
||||
prompts, err := promptStore.GetAllPrompts(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 5)
|
||||
|
||||
// Verify ordering by ID
|
||||
for i := 0; i < len(prompts)-1; i++ {
|
||||
assert.Less(t, prompts[i].ID, prompts[i+1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptStore_FindRecentPromptByText(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Save a prompt
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "What is the architecture?", 3)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find by exact text match within time window
|
||||
foundID, foundNumber, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "What is the architecture?", 60)
|
||||
assert.True(t, found, "Should find the prompt")
|
||||
assert.Equal(t, id, foundID)
|
||||
assert.Equal(t, 1, foundNumber)
|
||||
|
||||
// Try to find with different text
|
||||
_, _, notFound := promptStore.FindRecentPromptByText(ctx, "claude-1", "Different text", 60)
|
||||
assert.False(t, notFound, "Should not find a different prompt")
|
||||
|
||||
// Try to find outside time window
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, _, notFound = promptStore.FindRecentPromptByText(ctx, "claude-1", "What is the architecture?", 0)
|
||||
assert.False(t, notFound, "Should not find prompt outside time window")
|
||||
}
|
||||
|
||||
func TestPromptStore_GetRecentUserPromptsByProject(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions for different projects
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "")
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "project-b", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save prompts for project-a
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt A", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Save prompts for project-b
|
||||
for i := 1; i <= 2; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-2", i, "Prompt B", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get prompts for project-a
|
||||
prompts, err := promptStore.GetRecentUserPromptsByProject(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 3)
|
||||
|
||||
// Verify all prompts are from project-a
|
||||
for _, prompt := range prompts {
|
||||
assert.Equal(t, "project-a", prompt.Project)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptStore_GetRecentUserPromptsByProject_Limit(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save multiple prompts
|
||||
for i := 1; i <= 10; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt", i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for any async cleanup to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get with limit
|
||||
prompts, err := promptStore.GetRecentUserPromptsByProject(ctx, "test-project", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 5)
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// RelationStore provides relation-related database operations using GORM.
|
||||
type RelationStore struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewRelationStore creates a new relation store.
|
||||
func NewRelationStore(store *Store) *RelationStore {
|
||||
return &RelationStore{
|
||||
db: store.DB,
|
||||
}
|
||||
}
|
||||
|
||||
// StoreRelation stores a new observation relation.
|
||||
// Uses INSERT OR IGNORE to handle duplicate (source_id, target_id, relation_type) combinations.
|
||||
func (s *RelationStore) StoreRelation(ctx context.Context, relation *models.ObservationRelation) (int64, error) {
|
||||
dbRelation := &ObservationRelation{
|
||||
SourceID: relation.SourceID,
|
||||
TargetID: relation.TargetID,
|
||||
RelationType: relation.RelationType,
|
||||
Confidence: relation.Confidence,
|
||||
DetectionSource: relation.DetectionSource,
|
||||
CreatedAt: relation.CreatedAt,
|
||||
CreatedAtEpoch: relation.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if relation.Reason != "" {
|
||||
dbRelation.Reason = sql.NullString{String: relation.Reason, Valid: true}
|
||||
}
|
||||
|
||||
// INSERT OR IGNORE using OnConflict
|
||||
result := s.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "source_id"}, {Name: "target_id"}, {Name: "relation_type"}},
|
||||
DoNothing: true,
|
||||
}).
|
||||
Create(dbRelation)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
// If RowsAffected is 0, the insert was ignored (duplicate)
|
||||
if result.RowsAffected == 0 {
|
||||
var existing ObservationRelation
|
||||
err := s.db.Where("source_id = ? AND target_id = ? AND relation_type = ?",
|
||||
relation.SourceID, relation.TargetID, relation.RelationType).
|
||||
First(&existing).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return existing.ID, nil
|
||||
}
|
||||
|
||||
return dbRelation.ID, nil
|
||||
}
|
||||
|
||||
// StoreRelations stores multiple relations in a single transaction.
|
||||
func (s *RelationStore) StoreRelations(ctx context.Context, relations []*models.ObservationRelation) error {
|
||||
if len(relations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, rel := range relations {
|
||||
dbRelation := &ObservationRelation{
|
||||
SourceID: rel.SourceID,
|
||||
TargetID: rel.TargetID,
|
||||
RelationType: rel.RelationType,
|
||||
Confidence: rel.Confidence,
|
||||
DetectionSource: rel.DetectionSource,
|
||||
CreatedAt: rel.CreatedAt,
|
||||
CreatedAtEpoch: rel.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
if rel.Reason != "" {
|
||||
dbRelation.Reason = sql.NullString{String: rel.Reason, Valid: true}
|
||||
}
|
||||
|
||||
result := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "source_id"}, {Name: "target_id"}, {Name: "relation_type"}},
|
||||
DoNothing: true,
|
||||
}).Create(dbRelation)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetRelationsByObservationID retrieves all relations involving an observation (as source or target).
|
||||
func (s *RelationStore) GetRelationsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
var relations []ObservationRelation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("source_id = ? OR target_id = ?", obsID, obsID).
|
||||
Order("confidence DESC, created_at_epoch DESC").
|
||||
Find(&relations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelRelations(relations), nil
|
||||
}
|
||||
|
||||
// GetOutgoingRelations retrieves relations where the observation is the source.
|
||||
func (s *RelationStore) GetOutgoingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
var relations []ObservationRelation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("source_id = ?", obsID).
|
||||
Order("confidence DESC, created_at_epoch DESC").
|
||||
Find(&relations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelRelations(relations), nil
|
||||
}
|
||||
|
||||
// GetIncomingRelations retrieves relations where the observation is the target.
|
||||
func (s *RelationStore) GetIncomingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
var relations []ObservationRelation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("target_id = ?", obsID).
|
||||
Order("confidence DESC, created_at_epoch DESC").
|
||||
Find(&relations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelRelations(relations), nil
|
||||
}
|
||||
|
||||
// GetRelationsByType retrieves all relations of a specific type.
|
||||
func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType models.RelationType, limit int) ([]*models.ObservationRelation, error) {
|
||||
var relations []ObservationRelation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("relation_type = ?", relationType).
|
||||
Order("confidence DESC, created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&relations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelRelations(relations), nil
|
||||
}
|
||||
|
||||
// GetRelationsWithDetails retrieves relations with observation titles for display.
|
||||
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
|
||||
var results []struct {
|
||||
ObservationRelation
|
||||
SourceTitle sql.NullString `gorm:"column:source_title"`
|
||||
TargetTitle sql.NullString `gorm:"column:target_title"`
|
||||
SourceType string `gorm:"column:source_type"`
|
||||
TargetType string `gorm:"column:target_type"`
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Table("observation_relations r").
|
||||
Select("r.*, "+
|
||||
"COALESCE(src.title, '') as source_title, "+
|
||||
"COALESCE(tgt.title, '') as target_title, "+
|
||||
"src.type as source_type, "+
|
||||
"tgt.type as target_type").
|
||||
Joins("JOIN observations src ON src.id = r.source_id").
|
||||
Joins("JOIN observations tgt ON tgt.id = r.target_id").
|
||||
Where("r.source_id = ? OR r.target_id = ?", obsID, obsID).
|
||||
Order("r.confidence DESC, r.created_at_epoch DESC").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relations := make([]*models.RelationWithDetails, len(results))
|
||||
for i, r := range results {
|
||||
relations[i] = &models.RelationWithDetails{
|
||||
Relation: toModelRelation(&r.ObservationRelation),
|
||||
SourceTitle: r.SourceTitle.String,
|
||||
TargetTitle: r.TargetTitle.String,
|
||||
SourceType: models.ObservationType(r.SourceType),
|
||||
TargetType: models.ObservationType(r.TargetType),
|
||||
}
|
||||
}
|
||||
|
||||
return relations, nil
|
||||
}
|
||||
|
||||
// GetRelationGraph retrieves a relation graph centered on an observation.
|
||||
// This returns all observations within N hops from the center.
|
||||
func (s *RelationStore) GetRelationGraph(ctx context.Context, centerID int64, maxDepth int) (*models.RelationGraph, error) {
|
||||
// Get all relations involving the center observation
|
||||
relations, err := s.GetRelationsWithDetails(ctx, centerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph := &models.RelationGraph{
|
||||
CenterID: centerID,
|
||||
Relations: relations,
|
||||
}
|
||||
|
||||
// If depth > 1, recursively get relations for connected observations
|
||||
if maxDepth > 1 {
|
||||
visited := map[int64]bool{centerID: true}
|
||||
toVisit := make([]int64, 0)
|
||||
|
||||
// Collect IDs of directly connected observations
|
||||
for _, r := range relations {
|
||||
if !visited[r.Relation.SourceID] {
|
||||
toVisit = append(toVisit, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
toVisit = append(toVisit, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Get relations for connected observations (depth - 1)
|
||||
for depth := 1; depth < maxDepth && len(toVisit) > 0; depth++ {
|
||||
nextLevel := make([]int64, 0)
|
||||
for _, obsID := range toVisit {
|
||||
moreRelations, err := s.GetRelationsWithDetails(ctx, obsID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, r := range moreRelations {
|
||||
// Avoid duplicates
|
||||
exists := false
|
||||
for _, existing := range graph.Relations {
|
||||
if existing.Relation.ID == r.Relation.ID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
graph.Relations = append(graph.Relations, r)
|
||||
}
|
||||
|
||||
// Queue next level
|
||||
if !visited[r.Relation.SourceID] {
|
||||
nextLevel = append(nextLevel, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
nextLevel = append(nextLevel, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
toVisit = nextLevel
|
||||
}
|
||||
}
|
||||
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
// DeleteRelationsByObservationID deletes all relations involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *RelationStore) DeleteRelationsByObservationID(ctx context.Context, obsID int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("source_id = ? OR target_id = ?", obsID, obsID).
|
||||
Delete(&ObservationRelation{})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// GetRelationCount returns the count of relations for an observation.
|
||||
func (s *RelationStore) GetRelationCount(ctx context.Context, obsID int64) (int, error) {
|
||||
var count int64
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&ObservationRelation{}).
|
||||
Where("source_id = ? OR target_id = ?", obsID, obsID).
|
||||
Count(&count).Error
|
||||
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// GetTotalRelationCount returns the total count of all relations.
|
||||
func (s *RelationStore) GetTotalRelationCount(ctx context.Context) (int, error) {
|
||||
var count int64
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&ObservationRelation{}).
|
||||
Count(&count).Error
|
||||
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// GetHighConfidenceRelations retrieves relations with confidence above threshold.
|
||||
func (s *RelationStore) GetHighConfidenceRelations(ctx context.Context, minConfidence float64, limit int) ([]*models.ObservationRelation, error) {
|
||||
var relations []ObservationRelation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("confidence >= ?", minConfidence).
|
||||
Order("confidence DESC, created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&relations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelRelations(relations), nil
|
||||
}
|
||||
|
||||
// UpdateRelationConfidence updates the confidence of a relation.
|
||||
func (s *RelationStore) UpdateRelationConfidence(ctx context.Context, relationID int64, newConfidence float64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&ObservationRelation{}).
|
||||
Where("id = ?", relationID).
|
||||
Update("confidence", newConfidence)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// GetRelatedObservationIDs returns IDs of observations related to the given one.
|
||||
// This is useful for expanding search results.
|
||||
// Uses CASE expression for bidirectional ID lookup (GORM doesn't support this well, so we use raw SQL).
|
||||
func (s *RelationStore) GetRelatedObservationIDs(ctx context.Context, obsID int64, minConfidence float64) ([]int64, error) {
|
||||
var ids []int64
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Raw("SELECT DISTINCT CASE WHEN source_id = ? THEN target_id ELSE source_id END as related_id "+
|
||||
"FROM observation_relations "+
|
||||
"WHERE (source_id = ? OR target_id = ?) AND confidence >= ?",
|
||||
obsID, obsID, obsID, minConfidence).
|
||||
Pluck("related_id", &ids).Error
|
||||
|
||||
return ids, err
|
||||
}
|
||||
|
||||
// toModelRelation converts a GORM ObservationRelation to a pkg/models ObservationRelation.
|
||||
func toModelRelation(r *ObservationRelation) *models.ObservationRelation {
|
||||
relation := &models.ObservationRelation{
|
||||
ID: r.ID,
|
||||
SourceID: r.SourceID,
|
||||
TargetID: r.TargetID,
|
||||
RelationType: r.RelationType,
|
||||
Confidence: r.Confidence,
|
||||
DetectionSource: r.DetectionSource,
|
||||
CreatedAt: r.CreatedAt,
|
||||
CreatedAtEpoch: r.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
if r.Reason.Valid {
|
||||
relation.Reason = r.Reason.String
|
||||
}
|
||||
|
||||
return relation
|
||||
}
|
||||
|
||||
// toModelRelations converts a slice of GORM ObservationRelations to pkg/models ObservationRelations.
|
||||
func toModelRelations(relations []ObservationRelation) []*models.ObservationRelation {
|
||||
result := make([]*models.ObservationRelation, len(relations))
|
||||
for i, r := range relations {
|
||||
result[i] = toModelRelation(&r)
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
func testRelationStore(t *testing.T) (*RelationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_relation_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
relationStore := NewRelationStore(store)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return relationStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestRelationStore_StoreRelation(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relation := &models.ObservationRelation{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
Reason: "Test relation",
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id, err := relationStore.StoreRelation(ctx, relation)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
|
||||
func TestRelationStore_StoreRelation_Idempotency(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relation := &models.ObservationRelation{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
id1, err := relationStore.StoreRelation(ctx, relation)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store again with same source/target/type - should return same ID
|
||||
id2, err := relationStore.StoreRelation(ctx, relation)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, id1, id2)
|
||||
}
|
||||
|
||||
func TestRelationStore_StoreRelations(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relations := []*models.ObservationRelation{
|
||||
{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 2,
|
||||
TargetID: 3,
|
||||
RelationType: models.RelationFixes,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceTemporalProximity,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
err := relationStore.StoreRelations(ctx, relations)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both were stored
|
||||
count, err := relationStore.GetTotalRelationCount(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestRelationStore_GetRelationsByObservationID(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
// Create relations involving observation 2
|
||||
relations := []*models.ObservationRelation{
|
||||
{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 2,
|
||||
TargetID: 3,
|
||||
RelationType: models.RelationFixes,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceTemporalProximity,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
err := relationStore.StoreRelations(ctx, relations)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get relations for observation 2 (involved in both)
|
||||
result, err := relationStore.GetRelationsByObservationID(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestRelationStore_GetOutgoingAndIncomingRelations(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relations := []*models.ObservationRelation{
|
||||
{
|
||||
SourceID: 2,
|
||||
TargetID: 1,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 3,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationFixes,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceTemporalProximity,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
err := relationStore.StoreRelations(ctx, relations)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Observation 2 has 1 outgoing (to 1) and 1 incoming (from 3)
|
||||
outgoing, err := relationStore.GetOutgoingRelations(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, outgoing, 1)
|
||||
assert.Equal(t, int64(1), outgoing[0].TargetID)
|
||||
|
||||
incoming, err := relationStore.GetIncomingRelations(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, incoming, 1)
|
||||
assert.Equal(t, int64(3), incoming[0].SourceID)
|
||||
}
|
||||
|
||||
func TestRelationStore_GetRelationCount(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relations := []*models.ObservationRelation{
|
||||
{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 2,
|
||||
TargetID: 3,
|
||||
RelationType: models.RelationFixes,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceTemporalProximity,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
err := relationStore.StoreRelations(ctx, relations)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := relationStore.GetRelationCount(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
count, err = relationStore.GetRelationCount(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestRelationStore_DeleteRelationsByObservationID(t *testing.T) {
|
||||
relationStore, _, cleanup := testRelationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
relations := []*models.ObservationRelation{
|
||||
{
|
||||
SourceID: 1,
|
||||
TargetID: 2,
|
||||
RelationType: models.RelationCauses,
|
||||
Confidence: 0.8,
|
||||
DetectionSource: models.DetectionSourceFileOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 2,
|
||||
TargetID: 3,
|
||||
RelationType: models.RelationFixes,
|
||||
Confidence: 0.9,
|
||||
DetectionSource: models.DetectionSourceTemporalProximity,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
{
|
||||
SourceID: 4,
|
||||
TargetID: 5,
|
||||
RelationType: models.RelationRelatesTo,
|
||||
Confidence: 0.7,
|
||||
DetectionSource: models.DetectionSourceConceptOverlap,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
}
|
||||
|
||||
err := relationStore.StoreRelations(ctx, relations)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete relations involving observation 2
|
||||
err = relationStore.DeleteRelationsByObservationID(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only 1 relation remains (4->5)
|
||||
total, err := relationStore.GetTotalRelationCount(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, total)
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// UpdateObservationFeedback updates the user feedback for an observation.
|
||||
// Feedback values: -1 (thumbs down), 0 (neutral), 1 (thumbs up).
|
||||
func (s *ObservationStore) UpdateObservationFeedback(ctx context.Context, id int64, feedback int) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"user_feedback": feedback,
|
||||
"score_updated_at_epoch": now,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// IncrementRetrievalCount increments the retrieval counter for the given observation IDs.
|
||||
// This is called when observations are returned in search results.
|
||||
func (s *ObservationStore) IncrementRetrievalCount(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Use raw SQL for increment expression
|
||||
result := s.db.WithContext(ctx).
|
||||
Exec("UPDATE observations SET retrieval_count = COALESCE(retrieval_count, 0) + 1, last_retrieved_at_epoch = ? WHERE id IN ?",
|
||||
now, ids)
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// UpdateImportanceScore updates the importance score for a single observation.
|
||||
func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, score float64) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"importance_score": score,
|
||||
"score_updated_at_epoch": now,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||
// This is more efficient than individual updates for batch recalculation.
|
||||
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
if len(scores) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
for id, score := range scores {
|
||||
err := tx.Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"importance_score": score,
|
||||
"score_updated_at_epoch": now,
|
||||
}).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
cutoff := time.Now().Add(-threshold).UnixMilli()
|
||||
|
||||
var observations []Observation
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("score_updated_at_epoch IS NULL OR score_updated_at_epoch < ?", cutoff).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&observations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(observations), nil
|
||||
}
|
||||
|
||||
// GetConceptWeights returns all concept weights from the database.
|
||||
func (s *ObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||
var weights []struct {
|
||||
Concept string
|
||||
Weight float64
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
Table("concept_weights").
|
||||
Select("concept, weight").
|
||||
Scan(&weights).Error
|
||||
|
||||
if err != nil {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
|
||||
if len(weights) == 0 {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
|
||||
result := make(map[string]float64, len(weights))
|
||||
for _, w := range weights {
|
||||
result[w.Concept] = w.Weight
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetConceptWeights stores concept weights in the database using UPSERT.
|
||||
func (s *ObservationStore) SetConceptWeights(ctx context.Context, weights map[string]float64) error {
|
||||
if len(weights) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for concept, weight := range weights {
|
||||
// UPSERT using raw SQL since GORM's ON CONFLICT is complex for this case
|
||||
err := tx.Exec(`
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`, concept, weight).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConceptWeight updates a single concept weight in the database using UPSERT.
|
||||
func (s *ObservationStore) UpdateConceptWeight(ctx context.Context, concept string, weight float64) error {
|
||||
return s.db.WithContext(ctx).Exec(`
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`, concept, weight).Error
|
||||
}
|
||||
|
||||
// FeedbackStats contains statistics about observation feedback and scoring.
|
||||
type FeedbackStats struct {
|
||||
Total int `json:"total"`
|
||||
Positive int `json:"positive"`
|
||||
Negative int `json:"negative"`
|
||||
Neutral int `json:"neutral"`
|
||||
AvgScore float64 `json:"avg_score"`
|
||||
AvgRetrieval float64 `json:"avg_retrieval"`
|
||||
}
|
||||
|
||||
// GetObservationFeedbackStats returns statistics about user feedback.
|
||||
func (s *ObservationStore) GetObservationFeedbackStats(ctx context.Context, project string) (*FeedbackStats, error) {
|
||||
var stats FeedbackStats
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Select(`
|
||||
COUNT(*) as total,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||
`)
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ? OR scope = 'global'", project)
|
||||
}
|
||||
|
||||
err := query.Scan(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetTopScoringObservations returns the highest-scoring observations.
|
||||
func (s *ObservationStore) GetTopScoringObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var observations []Observation
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC").
|
||||
Limit(limit)
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ? OR scope = 'global'", project)
|
||||
}
|
||||
|
||||
err := query.Find(&observations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(observations), nil
|
||||
}
|
||||
|
||||
// GetMostRetrievedObservations returns the most frequently retrieved observations.
|
||||
func (s *ObservationStore) GetMostRetrievedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var observations []Observation
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Where("retrieval_count > 0").
|
||||
Order("retrieval_count DESC, created_at_epoch DESC").
|
||||
Limit(limit)
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ? OR scope = 'global'", project)
|
||||
}
|
||||
|
||||
err := query.Find(&observations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(observations), nil
|
||||
}
|
||||
|
||||
// ResetObservationScores resets all observation scores to their default values.
|
||||
// This is useful for testing or when changing the scoring algorithm.
|
||||
func (s *ObservationStore) ResetObservationScores(ctx context.Context) error {
|
||||
// Use Where("1 = 1") to explicitly allow bulk update of all rows
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("1 = 1").
|
||||
Updates(map[string]interface{}{
|
||||
"importance_score": 1.0,
|
||||
"score_updated_at_epoch": nil,
|
||||
})
|
||||
|
||||
return result.Error
|
||||
}
|
||||
@@ -0,0 +1,355 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
func TestObservationStore_UpdateImportanceScore(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observation
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
|
||||
// Update score
|
||||
err := obsStore.UpdateImportanceScore(ctx, obsID, 5.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, obsID)
|
||||
assert.Equal(t, 5.0, dbObs.ImportanceScore)
|
||||
}
|
||||
|
||||
func TestObservationStore_IncrementRetrievalCount(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
|
||||
err := obsStore.IncrementRetrievalCount(ctx, []int64{obsID})
|
||||
require.NoError(t, err)
|
||||
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, obsID)
|
||||
assert.Equal(t, 1, dbObs.RetrievalCount)
|
||||
}
|
||||
|
||||
func TestObservationStore_IncrementRetrievalCount_Multiple(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create 3 observations
|
||||
ids := make([]int64, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
ids[i] = obsID
|
||||
}
|
||||
|
||||
// Increment all
|
||||
err := obsStore.IncrementRetrievalCount(ctx, ids)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all were incremented
|
||||
for _, id := range ids {
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, id)
|
||||
assert.Equal(t, 1, dbObs.RetrievalCount)
|
||||
}
|
||||
|
||||
// Increment again
|
||||
err = obsStore.IncrementRetrievalCount(ctx, ids)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all are now 2
|
||||
for _, id := range ids {
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, id)
|
||||
assert.Equal(t, 2, dbObs.RetrievalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_UpdateImportanceScores_Bulk(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create 3 observations
|
||||
ids := make([]int64, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
ids[i] = obsID
|
||||
}
|
||||
|
||||
// Bulk update scores
|
||||
scores := map[int64]float64{
|
||||
ids[0]: 2.5,
|
||||
ids[1]: 3.7,
|
||||
ids[2]: 1.2,
|
||||
}
|
||||
|
||||
err := obsStore.UpdateImportanceScores(ctx, scores)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify scores
|
||||
for id, expectedScore := range scores {
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, id)
|
||||
assert.Equal(t, expectedScore, dbObs.ImportanceScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_UpdateObservationFeedback(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
obs := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test"}
|
||||
obsID, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs, int(sessionID), 1)
|
||||
|
||||
// Set thumbs up
|
||||
err := obsStore.UpdateObservationFeedback(ctx, obsID, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dbObs Observation
|
||||
store.DB.First(&dbObs, obsID)
|
||||
assert.Equal(t, 1, dbObs.UserFeedback)
|
||||
|
||||
// Set thumbs down
|
||||
err = obsStore.UpdateObservationFeedback(ctx, obsID, -1)
|
||||
require.NoError(t, err)
|
||||
|
||||
store.DB.First(&dbObs, obsID)
|
||||
assert.Equal(t, -1, dbObs.UserFeedback)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationFeedbackStats(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create observations with different feedback
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test1"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
obsStore.UpdateObservationFeedback(ctx, obsID1, 1) // thumbs up
|
||||
obsStore.UpdateImportanceScore(ctx, obsID1, 3.0)
|
||||
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeBugfix, Title: "Test2"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
obsStore.UpdateObservationFeedback(ctx, obsID2, -1) // thumbs down
|
||||
obsStore.UpdateImportanceScore(ctx, obsID2, 2.0)
|
||||
|
||||
obs3 := &models.ParsedObservation{Type: models.ObsTypeFeature, Title: "Test3"}
|
||||
obsID3, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs3, int(sessionID), 3)
|
||||
// neutral (0)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID3, 1.5)
|
||||
obsStore.IncrementRetrievalCount(ctx, []int64{obsID1, obsID2, obsID3})
|
||||
|
||||
// Get stats
|
||||
stats, err := obsStore.GetObservationFeedbackStats(ctx, "test-project")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, stats.Total)
|
||||
assert.Equal(t, 1, stats.Positive)
|
||||
assert.Equal(t, 1, stats.Negative)
|
||||
assert.Equal(t, 1, stats.Neutral)
|
||||
assert.InDelta(t, 2.166, stats.AvgScore, 0.01) // (3.0 + 2.0 + 1.5) / 3
|
||||
assert.InDelta(t, 1.0, stats.AvgRetrieval, 0.01)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetTopScoringObservations(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create observations with different scores
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "High"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID1, 5.0)
|
||||
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeBugfix, Title: "Low"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID2, 1.0)
|
||||
|
||||
obs3 := &models.ParsedObservation{Type: models.ObsTypeFeature, Title: "Medium"}
|
||||
obsID3, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs3, int(sessionID), 3)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID3, 3.0)
|
||||
|
||||
// Get top 2
|
||||
topObs, err := obsStore.GetTopScoringObservations(ctx, "test-project", 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, topObs, 2)
|
||||
assert.True(t, topObs[0].Title.Valid)
|
||||
assert.Equal(t, "High", topObs[0].Title.String)
|
||||
assert.Equal(t, 5.0, topObs[0].ImportanceScore)
|
||||
assert.True(t, topObs[1].Title.Valid)
|
||||
assert.Equal(t, "Medium", topObs[1].Title.String)
|
||||
assert.Equal(t, 3.0, topObs[1].ImportanceScore)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetMostRetrievedObservations(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create observations
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Popular"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeBugfix, Title: "Unpopular"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
|
||||
// Increment retrieval counts - each call increments by 1
|
||||
obsStore.IncrementRetrievalCount(ctx, []int64{obsID1}) // increment by 1
|
||||
obsStore.IncrementRetrievalCount(ctx, []int64{obsID1}) // increment by 1 again (total: 2)
|
||||
obsStore.IncrementRetrievalCount(ctx, []int64{obsID1}) // increment by 1 again (total: 3)
|
||||
obsStore.IncrementRetrievalCount(ctx, []int64{obsID2}) // increment by 1 (total: 1)
|
||||
|
||||
// Get most retrieved - should return obsID1 (Popular) with count 3
|
||||
topObs, err := obsStore.GetMostRetrievedObservations(ctx, "test-project", 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, topObs, 2)
|
||||
assert.True(t, topObs[0].Title.Valid)
|
||||
assert.Equal(t, "Popular", topObs[0].Title.String)
|
||||
assert.Equal(t, 3, topObs[0].RetrievalCount)
|
||||
assert.True(t, topObs[1].Title.Valid)
|
||||
assert.Equal(t, "Unpopular", topObs[1].Title.String)
|
||||
assert.Equal(t, 1, topObs[1].RetrievalCount)
|
||||
}
|
||||
|
||||
func TestObservationStore_SetConceptWeights(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
// Set weights
|
||||
weights := map[string]float64{
|
||||
"security": 2.0,
|
||||
"performance": 1.5,
|
||||
"best-practice": 1.8,
|
||||
}
|
||||
|
||||
err := obsStore.SetConceptWeights(ctx, weights)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get weights back
|
||||
retrieved, err := obsStore.GetConceptWeights(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2.0, retrieved["security"])
|
||||
assert.Equal(t, 1.5, retrieved["performance"])
|
||||
assert.Equal(t, 1.8, retrieved["best-practice"])
|
||||
|
||||
// Update weights (UPSERT)
|
||||
weights["security"] = 2.5
|
||||
weights["scalability"] = 1.2
|
||||
|
||||
err = obsStore.SetConceptWeights(ctx, weights)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err = obsStore.GetConceptWeights(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2.5, retrieved["security"]) // updated
|
||||
assert.Equal(t, 1.5, retrieved["performance"]) // unchanged
|
||||
assert.Equal(t, 1.2, retrieved["scalability"]) // new
|
||||
assert.Equal(t, 1.8, retrieved["best-practice"]) // unchanged
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsNeedingScoreUpdate(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create observation with no score update
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Needs Update"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
|
||||
// Create observation with recent score update
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeBugfix, Title: "Recently Updated"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID2, 2.0)
|
||||
|
||||
// Get observations needing update (within 1 hour threshold)
|
||||
needsUpdate, err := obsStore.GetObservationsNeedingScoreUpdate(ctx, 1*time.Hour, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only obs1 should need update (obs2 was just updated)
|
||||
assert.Len(t, needsUpdate, 1)
|
||||
assert.Equal(t, obsID1, needsUpdate[0].ID)
|
||||
assert.True(t, needsUpdate[0].Title.Valid)
|
||||
assert.Equal(t, "Needs Update", needsUpdate[0].Title.String)
|
||||
}
|
||||
|
||||
func TestObservationStore_ResetObservationScores(t *testing.T) {
|
||||
obsStore, store, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
sessionID, _ := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
|
||||
// Create observations with custom scores
|
||||
obs1 := &models.ParsedObservation{Type: models.ObsTypeDiscovery, Title: "Test1"}
|
||||
obsID1, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs1, int(sessionID), 1)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID1, 5.0)
|
||||
|
||||
obs2 := &models.ParsedObservation{Type: models.ObsTypeBugfix, Title: "Test2"}
|
||||
obsID2, _, _ := obsStore.StoreObservation(ctx, "claude-1", "test-project", obs2, int(sessionID), 2)
|
||||
obsStore.UpdateImportanceScore(ctx, obsID2, 3.0)
|
||||
|
||||
// Reset all scores
|
||||
err := obsStore.ResetObservationScores(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all scores are 1.0
|
||||
var dbObs1, dbObs2 Observation
|
||||
store.DB.First(&dbObs1, obsID1)
|
||||
store.DB.First(&dbObs2, obsID2)
|
||||
|
||||
assert.Equal(t, 1.0, dbObs1.ImportanceScore)
|
||||
assert.Equal(t, 1.0, dbObs2.ImportanceScore)
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SessionStore provides session-related database operations using GORM.
|
||||
type SessionStore struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store.
|
||||
func NewSessionStore(store *Store) *SessionStore {
|
||||
return &SessionStore{db: store.DB}
|
||||
}
|
||||
|
||||
// CreateSDKSession creates a new SDK session (idempotent - returns existing ID if exists).
|
||||
// This is the KEY to how claude-mnemonic stays unified across hooks.
|
||||
func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, project, userPrompt string) (int64, error) {
|
||||
now := time.Now()
|
||||
|
||||
session := &SDKSession{
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
SDKSessionID: func() sql.NullString {
|
||||
return sql.NullString{String: claudeSessionID, Valid: true}
|
||||
}(),
|
||||
Project: project,
|
||||
UserPrompt: func() sql.NullString {
|
||||
if userPrompt != "" {
|
||||
return sql.NullString{String: userPrompt, Valid: true}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
}(),
|
||||
Status: "active",
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
StartedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
|
||||
// CRITICAL: INSERT OR IGNORE makes this idempotent
|
||||
// Use OnConflict with DoNothing to achieve INSERT OR IGNORE behavior
|
||||
result := s.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "claude_session_id"}},
|
||||
DoNothing: true,
|
||||
}).
|
||||
Create(session)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
// Check if insert happened
|
||||
if result.RowsAffected == 0 {
|
||||
// Session exists - UPDATE project and user_prompt if we have non-empty values
|
||||
if project != "" {
|
||||
updates := map[string]interface{}{
|
||||
"project": project,
|
||||
}
|
||||
if userPrompt != "" {
|
||||
updates["user_prompt"] = userPrompt
|
||||
}
|
||||
s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("claude_session_id = ?", claudeSessionID).
|
||||
Updates(updates)
|
||||
}
|
||||
|
||||
// Fetch existing session
|
||||
var existing SDKSession
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("claude_session_id = ?", claudeSessionID).
|
||||
First(&existing).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return existing.ID, nil
|
||||
}
|
||||
|
||||
return session.ID, nil
|
||||
}
|
||||
|
||||
// GetSessionByID retrieves a session by its database ID.
|
||||
func (s *SessionStore) GetSessionByID(ctx context.Context, id int64) (*models.SDKSession, error) {
|
||||
var sess SDKSession
|
||||
err := s.db.WithContext(ctx).First(&sess, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toModelSDKSession(&sess), nil
|
||||
}
|
||||
|
||||
// FindAnySDKSession finds any session by Claude session ID (any status).
|
||||
func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID string) (*models.SDKSession, error) {
|
||||
var sess SDKSession
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("claude_session_id = ?", claudeSessionID).
|
||||
First(&sess).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toModelSDKSession(&sess), nil
|
||||
}
|
||||
|
||||
// IncrementPromptCounter increments the prompt counter and returns the new value.
|
||||
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
// Atomic increment using GORM expression
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("id = ?", id).
|
||||
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Fetch updated value
|
||||
var sess SDKSession
|
||||
err = s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return sess.PromptCounter, nil
|
||||
}
|
||||
|
||||
// GetPromptCounter returns the current prompt counter for a session.
|
||||
func (s *SessionStore) GetPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
var sess SDKSession
|
||||
err := s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sess.PromptCounter, nil
|
||||
}
|
||||
|
||||
// GetSessionsToday returns the count of sessions started today.
|
||||
func (s *SessionStore) GetSessionsToday(ctx context.Context) (int, error) {
|
||||
// Get start of today in milliseconds
|
||||
now := time.Now()
|
||||
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
startEpoch := startOfDay.UnixMilli()
|
||||
|
||||
var count int64
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("started_at_epoch >= ?", startEpoch).
|
||||
Count(&count).Error
|
||||
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// GetAllProjects returns all unique project names.
|
||||
func (s *SessionStore) GetAllProjects(ctx context.Context) ([]string, error) {
|
||||
var projects []string
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Distinct("project").
|
||||
Where("project IS NOT NULL AND project != ''").
|
||||
Order("project ASC").
|
||||
Pluck("project", &projects).Error
|
||||
|
||||
return projects, err
|
||||
}
|
||||
|
||||
// toModelSDKSession converts a GORM SDKSession to pkg/models.SDKSession.
|
||||
func toModelSDKSession(sess *SDKSession) *models.SDKSession {
|
||||
return &models.SDKSession{
|
||||
ID: sess.ID,
|
||||
ClaudeSessionID: sess.ClaudeSessionID,
|
||||
SDKSessionID: sess.SDKSessionID,
|
||||
Project: sess.Project,
|
||||
UserPrompt: sess.UserPrompt,
|
||||
WorkerPort: sess.WorkerPort,
|
||||
PromptCounter: int64(sess.PromptCounter),
|
||||
Status: models.SessionStatus(sess.Status),
|
||||
StartedAt: sess.StartedAt,
|
||||
StartedAtEpoch: sess.StartedAtEpoch,
|
||||
CompletedAt: sess.CompletedAt,
|
||||
CompletedAtEpoch: sess.CompletedAtEpoch,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// testSessionStore creates a SessionStore with a temporary database for testing.
|
||||
func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_session_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
sessionStore := NewSessionStore(store)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return sessionStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestSessionStore_CreateSDKSession(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new session
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "initial prompt")
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Retrieve and verify
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
|
||||
assert.Equal(t, "test-project", sess.Project)
|
||||
assert.Equal(t, models.SessionStatusActive, sess.Status)
|
||||
assert.True(t, sess.UserPrompt.Valid)
|
||||
assert.Equal(t, "initial prompt", sess.UserPrompt.String)
|
||||
}
|
||||
|
||||
func TestSessionStore_CreateSDKSession_Idempotent(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create first session
|
||||
id1, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "prompt 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create again with same claude_session_id but different project
|
||||
id2, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-b", "prompt 2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return same ID (idempotent)
|
||||
assert.Equal(t, id1, id2)
|
||||
|
||||
// Should have updated project to project-b
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "project-b", sess.Project)
|
||||
assert.Equal(t, "prompt 2", sess.UserPrompt.String)
|
||||
}
|
||||
|
||||
func TestSessionStore_CreateSDKSession_EmptyPrompt(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session with empty prompt
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-2", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify prompt is NULL
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sess.UserPrompt.Valid)
|
||||
}
|
||||
|
||||
func TestSessionStore_FindAnySDKSession(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find it
|
||||
sess, err := sessionStore.FindAnySDKSession(ctx, "claude-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
|
||||
|
||||
// Try to find non-existent
|
||||
sess, err = sessionStore.FindAnySDKSession(ctx, "claude-nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, sess)
|
||||
}
|
||||
|
||||
func TestSessionStore_GetSessionByID_NotFound(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to get non-existent session
|
||||
sess, err := sessionStore.GetSessionByID(ctx, 99999)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, sess)
|
||||
}
|
||||
|
||||
func TestSessionStore_IncrementPromptCounter(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial counter should be 0
|
||||
counter, err := sessionStore.GetPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, counter)
|
||||
|
||||
// Increment
|
||||
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, counter)
|
||||
|
||||
// Increment again
|
||||
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, counter)
|
||||
|
||||
// Verify via GetPromptCounter
|
||||
counter, err = sessionStore.GetPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, counter)
|
||||
}
|
||||
|
||||
func TestSessionStore_GetSessionsToday(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially should be 0
|
||||
count, err := sessionStore.GetSessionsToday(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// Create some sessions
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-1", "project-1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "project-2", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should now have 2 sessions today
|
||||
count, err = sessionStore.GetSessionsToday(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestSessionStore_GetAllProjects(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially should be empty
|
||||
projects, err := sessionStore.GetAllProjects(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, projects)
|
||||
|
||||
// Create sessions with different projects
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "project-b", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-3", "project-a", "") // Duplicate project
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get distinct projects in alphabetical order
|
||||
projects, err = sessionStore.GetAllProjects(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"project-a", "project-b"}, projects)
|
||||
}
|
||||
|
||||
func TestSessionStore_SessionFields(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "test prompt")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify all fields
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
|
||||
// Verify all fields
|
||||
assert.Equal(t, id, sess.ID)
|
||||
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
|
||||
assert.True(t, sess.SDKSessionID.Valid)
|
||||
assert.Equal(t, "claude-1", sess.SDKSessionID.String) // Should be same as ClaudeSessionID
|
||||
assert.Equal(t, "test-project", sess.Project)
|
||||
assert.True(t, sess.UserPrompt.Valid)
|
||||
assert.Equal(t, "test prompt", sess.UserPrompt.String)
|
||||
assert.Equal(t, int64(0), sess.PromptCounter)
|
||||
assert.Equal(t, models.SessionStatusActive, sess.Status)
|
||||
assert.NotEmpty(t, sess.StartedAt)
|
||||
assert.Greater(t, sess.StartedAtEpoch, int64(0))
|
||||
assert.False(t, sess.CompletedAt.Valid) // Should be NULL
|
||||
assert.False(t, sess.CompletedAtEpoch.Valid) // Should be NULL
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build !sqlite_omit_load_extension
|
||||
// +build !sqlite_omit_load_extension
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
// This file ensures mattn/go-sqlite3 is built with FTS5 and other extensions enabled.
|
||||
// The build tag ensures extensions are not omitted.
|
||||
@@ -0,0 +1,117 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Store represents the GORM database connection with sqlite-vec support.
|
||||
type Store struct {
|
||||
DB *gorm.DB
|
||||
sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL
|
||||
}
|
||||
|
||||
// Config holds database configuration.
|
||||
type Config struct {
|
||||
Path string // Path to SQLite database file
|
||||
MaxConns int // Maximum number of open connections (default: 4)
|
||||
LogLevel logger.LogLevel // GORM log level (logger.Silent for production)
|
||||
}
|
||||
|
||||
// NewStore creates a new Store with WAL mode enabled and sqlite-vec registered.
|
||||
// CRITICAL: WAL mode and foreign keys are enabled via pragmas for concurrent reads.
|
||||
func NewStore(cfg Config) (*Store, error) {
|
||||
// 1. Register sqlite-vec extension (must be done before opening database)
|
||||
sqlite_vec.Auto()
|
||||
|
||||
// 2. Build connection string (foreign keys enabled in DSN)
|
||||
// Use sqlite3 driver (mattn/go-sqlite3) which has FTS5 support
|
||||
dsn := cfg.Path + "?_foreign_keys=ON"
|
||||
|
||||
// 3. Open raw database connection with mattn/go-sqlite3 (has FTS5 support)
|
||||
sqlDB, err := sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
|
||||
// 4. Wrap with GORM using existing connection
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{
|
||||
Logger: logger.Default.LogMode(cfg.LogLevel),
|
||||
// PrepareStmt enables prepared statement caching for performance
|
||||
PrepareStmt: true,
|
||||
// Disable default timestamp fields (we manage created_at manually)
|
||||
NowFunc: nil,
|
||||
})
|
||||
if err != nil {
|
||||
_ = sqlDB.Close() // Explicitly ignore close error during cleanup
|
||||
return nil, fmt.Errorf("open gorm: %w", err)
|
||||
}
|
||||
|
||||
// 5. Configure connection pool (same settings as current implementation)
|
||||
maxConns := cfg.MaxConns
|
||||
if maxConns <= 0 {
|
||||
maxConns = 4
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxConns)
|
||||
sqlDB.SetMaxIdleConns(maxConns)
|
||||
sqlDB.SetConnMaxLifetime(0) // Never expire (SQLite connections are cheap)
|
||||
|
||||
// 6. Verify connection
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
|
||||
store := &Store{
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
}
|
||||
|
||||
// 7. Run migrations FIRST (before PRAGMA commands)
|
||||
if err := runMigrations(db, sqlDB); err != nil {
|
||||
return nil, fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
// 8. CRITICAL: Set WAL mode and synchronous mode via raw SQL
|
||||
// Use raw sqlDB to avoid GORM transaction issues
|
||||
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
return nil, fmt.Errorf("set WAL mode: %w", err)
|
||||
}
|
||||
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
|
||||
return nil, fmt.Errorf("set synchronous mode: %w", err)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (s *Store) Close() error {
|
||||
return s.sqlDB.Close()
|
||||
}
|
||||
|
||||
// Ping verifies the database connection is alive.
|
||||
func (s *Store) Ping() error {
|
||||
return s.sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetRawDB returns the underlying *sql.DB for operations GORM can't handle.
|
||||
// Use this for:
|
||||
// - FTS5 full-text search queries (MATCH operator)
|
||||
// - sqlite-vec vector operations
|
||||
// - Complex raw SQL queries
|
||||
func (s *Store) GetRawDB() *sql.DB {
|
||||
return s.sqlDB
|
||||
}
|
||||
|
||||
// GetDB returns the GORM DB instance for standard queries.
|
||||
func (s *Store) GetDB() *gorm.DB {
|
||||
return s.DB
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
// Create temporary directory for test database
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
// Create store with migrations
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Verify connection works
|
||||
sqlDB := store.GetRawDB()
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
t.Fatalf("ping failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify WAL mode is enabled
|
||||
var journalMode string
|
||||
err = store.DB.Raw("PRAGMA journal_mode").Scan(&journalMode).Error
|
||||
if err != nil {
|
||||
t.Fatalf("query journal_mode failed: %v", err)
|
||||
}
|
||||
if journalMode != "wal" {
|
||||
t.Errorf("expected WAL mode, got %q", journalMode)
|
||||
}
|
||||
|
||||
// Verify core tables exist
|
||||
tables := []string{
|
||||
"sdk_sessions",
|
||||
"observations",
|
||||
"session_summaries",
|
||||
"user_prompts",
|
||||
"observation_conflicts",
|
||||
"observation_relations",
|
||||
"patterns",
|
||||
"concept_weights",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
exists := store.DB.Migrator().HasTable(table)
|
||||
if !exists {
|
||||
t.Errorf("table %q does not exist", table)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify FTS5 virtual tables exist (cannot use Migrator().HasTable for virtual tables)
|
||||
ftsTables := []string{
|
||||
"user_prompts_fts",
|
||||
"observations_fts",
|
||||
"session_summaries_fts",
|
||||
"patterns_fts",
|
||||
}
|
||||
|
||||
for _, table := range ftsTables {
|
||||
var count int
|
||||
err := store.DB.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&count).Error
|
||||
if err != nil {
|
||||
t.Errorf("check FTS table %q failed: %v", table, err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("FTS table %q does not exist", table)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify vectors table exists (virtual table)
|
||||
var vectorsCount int
|
||||
err = store.DB.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='vectors'").Scan(&vectorsCount).Error
|
||||
if err != nil {
|
||||
t.Errorf("check vectors table failed: %v", err)
|
||||
}
|
||||
if vectorsCount != 1 {
|
||||
t.Errorf("vectors table does not exist")
|
||||
}
|
||||
|
||||
// Verify concept_weights seed data exists
|
||||
var conceptCount int64
|
||||
store.DB.Model(&ConceptWeight{}).Count(&conceptCount)
|
||||
if conceptCount != 12 {
|
||||
t.Errorf("expected 12 concept weights, got %d", conceptCount)
|
||||
}
|
||||
|
||||
t.Logf("✅ Phase 1 Foundation: All migrations successful")
|
||||
t.Logf(" - Core tables: %d", len(tables))
|
||||
t.Logf(" - FTS5 tables: %d", len(ftsTables))
|
||||
t.Logf(" - Vector table: 1")
|
||||
t.Logf(" - Seed data: %d concept weights", conceptCount)
|
||||
}
|
||||
|
||||
func TestMigrationIdempotency(t *testing.T) {
|
||||
// Create temporary directory for test database
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_idempotency_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
// Run migrations first time
|
||||
store1, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore (first) failed: %v", err)
|
||||
}
|
||||
store1.Close()
|
||||
|
||||
// Run migrations second time (should be idempotent)
|
||||
store2, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore (second) failed: %v", err)
|
||||
}
|
||||
defer store2.Close()
|
||||
|
||||
// Verify concept_weights seed data is still exactly 12 (INSERT OR IGNORE)
|
||||
var conceptCount int64
|
||||
store2.DB.Model(&ConceptWeight{}).Count(&conceptCount)
|
||||
if conceptCount != 12 {
|
||||
t.Errorf("expected 12 concept weights after second migration, got %d", conceptCount)
|
||||
}
|
||||
|
||||
t.Logf("✅ Migrations are idempotent")
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SummaryStore provides summary-related database operations using GORM.
|
||||
type SummaryStore struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSummaryStore creates a new summary store.
|
||||
func NewSummaryStore(store *Store) *SummaryStore {
|
||||
return &SummaryStore{db: store.DB}
|
||||
}
|
||||
|
||||
// StoreSummary stores a new session summary.
|
||||
func (s *SummaryStore) StoreSummary(ctx context.Context, sdkSessionID, project string, summary *models.ParsedSummary, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||
now := time.Now()
|
||||
nowEpoch := now.UnixMilli()
|
||||
|
||||
// Ensure session exists (auto-create if missing)
|
||||
if err := EnsureSessionExists(ctx, s.db, sdkSessionID, project); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
dbSummary := &SessionSummary{
|
||||
SDKSessionID: sdkSessionID,
|
||||
Project: project,
|
||||
Request: nullString(summary.Request),
|
||||
Investigated: nullString(summary.Investigated),
|
||||
Learned: nullString(summary.Learned),
|
||||
Completed: nullString(summary.Completed),
|
||||
NextSteps: nullString(summary.NextSteps),
|
||||
Notes: nullString(summary.Notes),
|
||||
PromptNumber: func() sql.NullInt64 {
|
||||
if promptNumber > 0 {
|
||||
return sql.NullInt64{Int64: int64(promptNumber), Valid: true}
|
||||
}
|
||||
return sql.NullInt64{Valid: false}
|
||||
}(),
|
||||
DiscoveryTokens: discoveryTokens,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: nowEpoch,
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).Create(dbSummary).Error
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return dbSummary.ID, nowEpoch, nil
|
||||
}
|
||||
|
||||
// GetSummariesByIDs retrieves summaries by a list of IDs.
|
||||
func (s *SummaryStore) GetSummariesByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.SessionSummary, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var dbSummaries []SessionSummary
|
||||
query := s.db.WithContext(ctx).Where("id IN ?", ids)
|
||||
|
||||
// Apply ordering
|
||||
switch orderBy {
|
||||
case "date_asc":
|
||||
query = query.Order("created_at_epoch ASC")
|
||||
case "date_desc", "default", "":
|
||||
query = query.Order("created_at_epoch DESC")
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
err := query.Find(&dbSummaries).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelSessionSummaries(dbSummaries), nil
|
||||
}
|
||||
|
||||
// GetRecentSummaries retrieves recent summaries for a project.
|
||||
func (s *SummaryStore) GetRecentSummaries(ctx context.Context, project string, limit int) ([]*models.SessionSummary, error) {
|
||||
var dbSummaries []SessionSummary
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("project = ?", project).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&dbSummaries).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelSessionSummaries(dbSummaries), nil
|
||||
}
|
||||
|
||||
// GetAllRecentSummaries retrieves recent summaries across all projects.
|
||||
func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]*models.SessionSummary, error) {
|
||||
var dbSummaries []SessionSummary
|
||||
err := s.db.WithContext(ctx).
|
||||
Order("created_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&dbSummaries).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelSessionSummaries(dbSummaries), nil
|
||||
}
|
||||
|
||||
// GetAllSummaries retrieves all summaries (for vector rebuild).
|
||||
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
|
||||
var dbSummaries []SessionSummary
|
||||
err := s.db.WithContext(ctx).
|
||||
Order("id").
|
||||
Find(&dbSummaries).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelSessionSummaries(dbSummaries), nil
|
||||
}
|
||||
|
||||
// toModelSessionSummary converts a GORM SessionSummary to pkg/models.SessionSummary.
|
||||
func toModelSessionSummary(s *SessionSummary) *models.SessionSummary {
|
||||
return &models.SessionSummary{
|
||||
ID: s.ID,
|
||||
SDKSessionID: s.SDKSessionID,
|
||||
Project: s.Project,
|
||||
Request: s.Request,
|
||||
Investigated: s.Investigated,
|
||||
Learned: s.Learned,
|
||||
Completed: s.Completed,
|
||||
NextSteps: s.NextSteps,
|
||||
Notes: s.Notes,
|
||||
PromptNumber: s.PromptNumber,
|
||||
DiscoveryTokens: s.DiscoveryTokens,
|
||||
CreatedAt: s.CreatedAt,
|
||||
CreatedAtEpoch: s.CreatedAtEpoch,
|
||||
}
|
||||
}
|
||||
|
||||
// toModelSessionSummaries converts a slice of GORM SessionSummary to pkg/models.SessionSummary.
|
||||
func toModelSessionSummaries(summaries []SessionSummary) []*models.SessionSummary {
|
||||
result := make([]*models.SessionSummary, len(summaries))
|
||||
for i := range summaries {
|
||||
result[i] = toModelSessionSummary(&summaries[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// nullString converts a string to sql.NullString.
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
//go:build fts5
|
||||
|
||||
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// testSummaryStore creates a SummaryStore with a temporary database for testing.
|
||||
func testSummaryStore(t *testing.T) (*SummaryStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "gorm_summary_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
cfg := Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 4,
|
||||
LogLevel: logger.Silent,
|
||||
}
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("NewStore failed: %v", err)
|
||||
}
|
||||
|
||||
summaryStore := NewSummaryStore(store)
|
||||
|
||||
cleanup := func() {
|
||||
store.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return summaryStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestSummaryStore_StoreSummary(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
sessionStore := NewSessionStore(store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store a summary
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Build a feature",
|
||||
Investigated: "Examined the codebase",
|
||||
Learned: "Discovered patterns",
|
||||
Completed: "Implemented solution",
|
||||
NextSteps: "Write tests",
|
||||
Notes: "Additional notes",
|
||||
}
|
||||
|
||||
id, epoch, err := summaryStore.StoreSummary(ctx, "claude-1", "test-project", summary, 1, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
assert.Greater(t, epoch, int64(0))
|
||||
}
|
||||
|
||||
func TestSummaryStore_StoreSummary_AutoCreateSession(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store summary without pre-creating session
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Test auto-create",
|
||||
}
|
||||
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "claude-auto", "auto-project", summary, 1, 50)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetRecentSummaries(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple summaries
|
||||
for i := 1; i <= 5; i++ {
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Request " + string(rune('0'+i)),
|
||||
}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "claude-1", "project-a", summary, i, 10)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Store summary for different project
|
||||
summary := &models.ParsedSummary{Request: "Other project"}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "claude-2", "project-b", summary, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get recent summaries for project-a
|
||||
summaries, err := summaryStore.GetRecentSummaries(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 5)
|
||||
|
||||
// Verify ordering (most recent first)
|
||||
assert.Equal(t, "project-a", summaries[0].Project)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetAllRecentSummaries(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store summaries across projects
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "claude-1", "project-a", &models.ParsedSummary{Request: "A1"}, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = summaryStore.StoreSummary(ctx, "claude-2", "project-b", &models.ParsedSummary{Request: "B1"}, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = summaryStore.StoreSummary(ctx, "claude-3", "project-c", &models.ParsedSummary{Request: "C1"}, 1, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get all recent summaries
|
||||
summaries, err := summaryStore.GetAllRecentSummaries(ctx, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 3)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetSummariesByIDs(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple summaries
|
||||
var ids []int64
|
||||
for i := 1; i <= 3; i++ {
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "claude-1", "project-a", &models.ParsedSummary{Request: "Test"}, i, 10)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Get by IDs
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, ids, "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 3)
|
||||
|
||||
// Get with limit
|
||||
summaries, err = summaryStore.GetSummariesByIDs(ctx, ids, "date_desc", 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 2)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetSummariesByIDs_EmptyInput(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get with empty IDs
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{}, "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, summaries)
|
||||
}
|
||||
|
||||
func TestSummaryStore_SummaryFields(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store a summary with all fields
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Full request",
|
||||
Investigated: "Full investigation",
|
||||
Learned: "Full learning",
|
||||
Completed: "Full completion",
|
||||
NextSteps: "Full next steps",
|
||||
Notes: "Full notes",
|
||||
}
|
||||
|
||||
id, epoch, err := summaryStore.StoreSummary(ctx, "claude-1", "test-project", summary, 5, 200)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify all fields
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, summaries, 1)
|
||||
|
||||
s := summaries[0]
|
||||
assert.Equal(t, id, s.ID)
|
||||
assert.Equal(t, "claude-1", s.SDKSessionID)
|
||||
assert.Equal(t, "test-project", s.Project)
|
||||
assert.True(t, s.Request.Valid)
|
||||
assert.Equal(t, "Full request", s.Request.String)
|
||||
assert.True(t, s.Investigated.Valid)
|
||||
assert.Equal(t, "Full investigation", s.Investigated.String)
|
||||
assert.True(t, s.Learned.Valid)
|
||||
assert.Equal(t, "Full learning", s.Learned.String)
|
||||
assert.True(t, s.Completed.Valid)
|
||||
assert.Equal(t, "Full completion", s.Completed.String)
|
||||
assert.True(t, s.NextSteps.Valid)
|
||||
assert.Equal(t, "Full next steps", s.NextSteps.String)
|
||||
assert.True(t, s.Notes.Valid)
|
||||
assert.Equal(t, "Full notes", s.Notes.String)
|
||||
assert.True(t, s.PromptNumber.Valid)
|
||||
assert.Equal(t, int64(5), s.PromptNumber.Int64)
|
||||
assert.Equal(t, int64(200), s.DiscoveryTokens)
|
||||
assert.NotEmpty(t, s.CreatedAt)
|
||||
assert.Equal(t, epoch, s.CreatedAtEpoch)
|
||||
}
|
||||
|
||||
func TestSummaryStore_EmptySummary(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store a summary with empty fields
|
||||
summary := &models.ParsedSummary{}
|
||||
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "claude-1", "test-project", summary, 0, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify NULL fields
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, summaries, 1)
|
||||
|
||||
s := summaries[0]
|
||||
assert.False(t, s.Request.Valid)
|
||||
assert.False(t, s.Investigated.Valid)
|
||||
assert.False(t, s.Learned.Valid)
|
||||
assert.False(t, s.Completed.Valid)
|
||||
assert.False(t, s.NextSteps.Valid)
|
||||
assert.False(t, s.Notes.Valid)
|
||||
assert.False(t, s.PromptNumber.Valid)
|
||||
assert.Equal(t, int64(0), s.DiscoveryTokens)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetAllSummaries(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Store multiple summaries
|
||||
for i := 1; i <= 5; i++ {
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "claude-1", "project-a", &models.ParsedSummary{Request: "Test"}, i, 10)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get all summaries
|
||||
summaries, err := summaryStore.GetAllSummaries(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 5)
|
||||
|
||||
// Verify ordering by ID
|
||||
for i := 0; i < len(summaries)-1; i++ {
|
||||
assert.Less(t, summaries[i].ID, summaries[i+1].ID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
// Package db defines database interfaces for the claude-mnemonic stores.
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// ObservationReader defines read operations for observations.
|
||||
type ObservationReader interface {
|
||||
GetObservationByID(ctx context.Context, id int64) (*models.Observation, error)
|
||||
GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error)
|
||||
GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error)
|
||||
GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error)
|
||||
GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error)
|
||||
GetAllObservations(ctx context.Context) ([]*models.Observation, error)
|
||||
SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error)
|
||||
GetObservationCount(ctx context.Context, project string) (int, error)
|
||||
}
|
||||
|
||||
// ObservationWriter defines write operations for observations.
|
||||
type ObservationWriter interface {
|
||||
StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error)
|
||||
DeleteObservations(ctx context.Context, ids []int64) (int64, error)
|
||||
}
|
||||
|
||||
// ObservationStore combines read and write operations for observations.
|
||||
type ObservationStore interface {
|
||||
ObservationReader
|
||||
ObservationWriter
|
||||
}
|
||||
|
||||
// SummaryReader defines read operations for summaries.
|
||||
type SummaryReader interface {
|
||||
GetSummariesByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.SessionSummary, error)
|
||||
GetRecentSummaries(ctx context.Context, project string, limit int) ([]*models.SessionSummary, error)
|
||||
GetAllRecentSummaries(ctx context.Context, limit int) ([]*models.SessionSummary, error)
|
||||
GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error)
|
||||
}
|
||||
|
||||
// SummaryWriter defines write operations for summaries.
|
||||
type SummaryWriter interface {
|
||||
StoreSummary(ctx context.Context, sdkSessionID, project string, summary *models.ParsedSummary, promptNumber int, discoveryTokens int64) (int64, int64, error)
|
||||
}
|
||||
|
||||
// SummaryStore combines read and write operations for summaries.
|
||||
type SummaryStore interface {
|
||||
SummaryReader
|
||||
SummaryWriter
|
||||
}
|
||||
|
||||
// PromptReader defines read operations for prompts.
|
||||
type PromptReader interface {
|
||||
GetPromptsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.UserPromptWithSession, error)
|
||||
GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error)
|
||||
GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error)
|
||||
GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error)
|
||||
FindRecentPromptByText(ctx context.Context, claudeSessionID, promptText string, withinSeconds int) (int64, int, bool)
|
||||
}
|
||||
|
||||
// PromptWriter defines write operations for prompts.
|
||||
type PromptWriter interface {
|
||||
SaveUserPromptWithMatches(ctx context.Context, claudeSessionID string, promptNumber int, promptText string, matchedObservations int) (int64, error)
|
||||
}
|
||||
|
||||
// PromptStore combines read and write operations for prompts.
|
||||
type PromptStore interface {
|
||||
PromptReader
|
||||
PromptWriter
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SupersededRetentionDays is the number of days to keep superseded observations before deletion.
|
||||
const SupersededRetentionDays = 3
|
||||
|
||||
// ConflictStore provides conflict-related database operations.
|
||||
type ConflictStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewConflictStore creates a new conflict store.
|
||||
func NewConflictStore(store *Store) *ConflictStore {
|
||||
return &ConflictStore{store: store}
|
||||
}
|
||||
|
||||
// StoreConflict stores a new observation conflict.
|
||||
func (s *ConflictStore) StoreConflict(ctx context.Context, conflict *models.ObservationConflict) (int64, error) {
|
||||
const query = `
|
||||
INSERT INTO observation_conflicts
|
||||
(newer_obs_id, older_obs_id, conflict_type, resolution, reason, detected_at, detected_at_epoch, resolved, resolved_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
conflict.NewerObsID, conflict.OlderObsID,
|
||||
string(conflict.ConflictType), string(conflict.Resolution),
|
||||
conflict.Reason, conflict.DetectedAt, conflict.DetectedAtEpoch,
|
||||
conflict.Resolved, conflict.ResolvedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// MarkObservationSuperseded marks an observation as superseded.
|
||||
func (s *ConflictStore) MarkObservationSuperseded(ctx context.Context, obsID int64) error {
|
||||
const query = `UPDATE observations SET is_superseded = 1 WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkObservationsSuperseded marks multiple observations as superseded.
|
||||
func (s *ConflictStore) MarkObservationsSuperseded(ctx context.Context, obsIDs []int64) error {
|
||||
if len(obsIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `UPDATE observations SET is_superseded = 1 WHERE id IN (?` + repeatPlaceholders(len(obsIDs)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||
args := int64SliceToInterface(obsIDs)
|
||||
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetConflictsByObservationID retrieves all conflicts involving an observation.
|
||||
func (s *ConflictStore) GetConflictsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationConflict, error) {
|
||||
const query = `
|
||||
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||
detected_at, detected_at_epoch, resolved, resolved_at
|
||||
FROM observation_conflicts
|
||||
WHERE newer_obs_id = ? OR older_obs_id = ?
|
||||
ORDER BY detected_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanConflictRows(rows)
|
||||
}
|
||||
|
||||
// GetUnresolvedConflicts retrieves all unresolved conflicts.
|
||||
func (s *ConflictStore) GetUnresolvedConflicts(ctx context.Context, limit int) ([]*models.ObservationConflict, error) {
|
||||
const query = `
|
||||
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||
detected_at, detected_at_epoch, resolved, resolved_at
|
||||
FROM observation_conflicts
|
||||
WHERE resolved = 0
|
||||
ORDER BY detected_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanConflictRows(rows)
|
||||
}
|
||||
|
||||
// GetSupersededObservationIDs returns IDs of all observations that have been superseded.
|
||||
func (s *ConflictStore) GetSupersededObservationIDs(ctx context.Context, project string) ([]int64, error) {
|
||||
const query = `
|
||||
SELECT DISTINCT older_obs_id
|
||||
FROM observation_conflicts oc
|
||||
JOIN observations o ON o.id = oc.older_obs_id
|
||||
WHERE oc.resolution = 'prefer_newer'
|
||||
AND (o.project = ? OR o.scope = 'global')
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
|
||||
// ResolveConflict marks a conflict as resolved.
|
||||
func (s *ConflictStore) ResolveConflict(ctx context.Context, conflictID int64, resolution models.ConflictResolution) error {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
const query = `
|
||||
UPDATE observation_conflicts
|
||||
SET resolved = 1, resolved_at = ?, resolution = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, now, string(resolution), conflictID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteConflictsByObservationID deletes all conflicts involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *ConflictStore) DeleteConflictsByObservationID(ctx context.Context, obsID int64) error {
|
||||
const query = `DELETE FROM observation_conflicts WHERE newer_obs_id = ? OR older_obs_id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ConflictWithDetails contains a conflict with its observation details.
|
||||
type ConflictWithDetails struct {
|
||||
Conflict *models.ObservationConflict
|
||||
NewerObsTitle string
|
||||
OlderObsTitle string
|
||||
}
|
||||
|
||||
// CleanupSupersededObservations deletes observations that have been superseded for longer than
|
||||
// SupersededRetentionDays. Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
// Calculate cutoff time (3 days ago in milliseconds)
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -SupersededRetentionDays).UnixMilli()
|
||||
|
||||
// First, find the IDs that will be deleted
|
||||
// We delete observations that:
|
||||
// 1. Are marked as superseded
|
||||
// 2. Have a conflict record where they are the older observation
|
||||
// 3. The conflict was detected more than 3 days ago
|
||||
const selectQuery = `
|
||||
SELECT DISTINCT o.id FROM observations o
|
||||
JOIN observation_conflicts oc ON o.id = oc.older_obs_id
|
||||
WHERE o.is_superseded = 1
|
||||
AND o.project = ?
|
||||
AND oc.detected_at_epoch < ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, selectQuery, project, cutoffEpoch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var toDelete []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(toDelete) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete the conflict records first (due to foreign key constraints)
|
||||
for _, obsID := range toDelete {
|
||||
if err := s.DeleteConflictsByObservationID(ctx, obsID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the observations
|
||||
deleteQuery := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(toDelete)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||
args := int64SliceToInterface(toDelete)
|
||||
_, err = s.store.db.ExecContext(ctx, deleteQuery, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDelete, nil
|
||||
}
|
||||
|
||||
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
|
||||
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
|
||||
const query = `
|
||||
SELECT oc.id, oc.newer_obs_id, oc.older_obs_id, oc.conflict_type, oc.resolution, oc.reason,
|
||||
oc.detected_at, oc.detected_at_epoch, oc.resolved, oc.resolved_at,
|
||||
COALESCE(newer.title, '') as newer_title,
|
||||
COALESCE(older.title, '') as older_title
|
||||
FROM observation_conflicts oc
|
||||
JOIN observations newer ON newer.id = oc.newer_obs_id
|
||||
JOIN observations older ON older.id = oc.older_obs_id
|
||||
WHERE newer.project = ? OR older.project = ?
|
||||
ORDER BY oc.detected_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*ConflictWithDetails
|
||||
for rows.Next() {
|
||||
var c models.ObservationConflict
|
||||
var cwd ConflictWithDetails
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.NewerObsID, &c.OlderObsID,
|
||||
&c.ConflictType, &c.Resolution, &c.Reason,
|
||||
&c.DetectedAt, &c.DetectedAtEpoch,
|
||||
&c.Resolved, &c.ResolvedAt,
|
||||
&cwd.NewerObsTitle, &cwd.OlderObsTitle,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cwd.Conflict = &c
|
||||
results = append(results, &cwd)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// scanConflictRows scans multiple conflicts from rows.
|
||||
func (s *ConflictStore) scanConflictRows(rows interface {
|
||||
Next() bool
|
||||
Scan(...interface{}) error
|
||||
Err() error
|
||||
}) ([]*models.ObservationConflict, error) {
|
||||
var conflicts []*models.ObservationConflict
|
||||
for rows.Next() {
|
||||
var c models.ObservationConflict
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.NewerObsID, &c.OlderObsID,
|
||||
&c.ConflictType, &c.Resolution, &c.Reason,
|
||||
&c.DetectedAt, &c.DetectedAtEpoch,
|
||||
&c.Resolved, &c.ResolvedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conflicts = append(conflicts, &c)
|
||||
}
|
||||
return conflicts, rows.Err()
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// EnsureSessionExists creates a session if it doesn't exist.
|
||||
// This is shared between ObservationStore and SummaryStore to avoid duplication.
|
||||
func EnsureSessionExists(ctx context.Context, store *Store, sdkSessionID, project string) error {
|
||||
const checkQuery = `SELECT id FROM sdk_sessions WHERE sdk_session_id = ?`
|
||||
var id int64
|
||||
err := store.QueryRowContext(ctx, checkQuery, sdkSessionID).Scan(&id)
|
||||
if err == nil {
|
||||
return nil // Session exists
|
||||
}
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
// Auto-create session
|
||||
now := time.Now()
|
||||
const insertQuery = `
|
||||
INSERT INTO sdk_sessions
|
||||
(claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
|
||||
VALUES (?, ?, ?, ?, ?, 'active')
|
||||
`
|
||||
_, err = store.ExecContext(ctx, insertQuery,
|
||||
sdkSessionID, sdkSessionID, project,
|
||||
now.Format(time.RFC3339), now.UnixMilli(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// nullString converts a string to sql.NullString.
|
||||
func nullString(s string) sql.NullString {
|
||||
return sql.NullString{String: s, Valid: s != ""}
|
||||
}
|
||||
|
||||
// nullInt converts an int to sql.NullInt64.
|
||||
func nullInt(i int) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: int64(i), Valid: i > 0}
|
||||
}
|
||||
|
||||
// repeatPlaceholders generates n comma-prefixed placeholders for SQL IN clauses.
|
||||
// e.g., repeatPlaceholders(2) returns ", ?, ?"
|
||||
func repeatPlaceholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
result := ""
|
||||
for i := 0; i < n; i++ {
|
||||
result += ", ?"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// int64SliceToInterface converts []int64 to []interface{} for SQL queries.
|
||||
func int64SliceToInterface(ids []int64) []interface{} {
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
args[i] = id
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// ParseLimitParam parses a limit query parameter with a default value.
|
||||
func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultLimit
|
||||
}
|
||||
|
||||
// scanSummary scans a single summary from a row scanner.
|
||||
func scanSummary(scanner interface{ Scan(...interface{}) error }) (*models.SessionSummary, error) {
|
||||
var summary models.SessionSummary
|
||||
if err := scanner.Scan(
|
||||
&summary.ID, &summary.SDKSessionID, &summary.Project,
|
||||
&summary.Request, &summary.Investigated, &summary.Learned, &summary.Completed,
|
||||
&summary.NextSteps, &summary.Notes, &summary.PromptNumber, &summary.DiscoveryTokens,
|
||||
&summary.CreatedAt, &summary.CreatedAtEpoch,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
// scanSummaryRows scans multiple summaries from rows.
|
||||
func scanSummaryRows(rows *sql.Rows) ([]*models.SessionSummary, error) {
|
||||
var summaries []*models.SessionSummary
|
||||
for rows.Next() {
|
||||
summary, err := scanSummary(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
return summaries, rows.Err()
|
||||
}
|
||||
|
||||
// scanPromptWithSession scans a single prompt with session info from a row scanner.
|
||||
func scanPromptWithSession(scanner interface{ Scan(...interface{}) error }) (*models.UserPromptWithSession, error) {
|
||||
var prompt models.UserPromptWithSession
|
||||
if err := scanner.Scan(
|
||||
&prompt.ID, &prompt.ClaudeSessionID, &prompt.PromptNumber, &prompt.PromptText,
|
||||
&prompt.MatchedObservations, &prompt.CreatedAt, &prompt.CreatedAtEpoch,
|
||||
&prompt.Project, &prompt.SDKSessionID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &prompt, nil
|
||||
}
|
||||
|
||||
// scanPromptWithSessionRows scans multiple prompts with session info from rows.
|
||||
func scanPromptWithSessionRows(rows *sql.Rows) ([]*models.UserPromptWithSession, error) {
|
||||
var prompts []*models.UserPromptWithSession
|
||||
for rows.Next() {
|
||||
prompt, err := scanPromptWithSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts = append(prompts, prompt)
|
||||
}
|
||||
return prompts, rows.Err()
|
||||
}
|
||||
|
||||
// BuildGetByIDsQuery builds a query for fetching records by IDs with optional ordering and limit.
|
||||
// Returns the query string and args slice.
|
||||
func BuildGetByIDsQuery(baseQuery string, ids []int64, orderBy string, limit int) (string, []interface{}) {
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := baseQuery + ` WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||
ORDER BY created_at_epoch `
|
||||
|
||||
if orderBy == "date_asc" {
|
||||
query += "ASC"
|
||||
} else {
|
||||
query += "DESC"
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
}
|
||||
|
||||
args := int64SliceToInterface(ids)
|
||||
if limit > 0 {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
return query, args
|
||||
}
|
||||
@@ -1,254 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"empty_string", "", "", false},
|
||||
{"non_empty_string", "hello", "hello", true},
|
||||
{"whitespace", " ", " ", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := nullString(tt.input)
|
||||
assert.Equal(t, tt.expected, result.String)
|
||||
assert.Equal(t, tt.valid, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input int
|
||||
expected int64
|
||||
valid bool
|
||||
}{
|
||||
{"zero", 0, 0, false},
|
||||
{"positive", 42, 42, true},
|
||||
{"negative", -1, -1, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := nullInt(tt.input)
|
||||
assert.Equal(t, tt.expected, result.Int64)
|
||||
assert.Equal(t, tt.valid, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepeatPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
n int
|
||||
expected string
|
||||
}{
|
||||
{"zero", 0, ""},
|
||||
{"negative", -1, ""},
|
||||
{"one", 1, ", ?"},
|
||||
{"two", 2, ", ?, ?"},
|
||||
{"three", 3, ", ?, ?, ?"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := repeatPlaceholders(tt.n)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64SliceToInterface(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []int64
|
||||
expected []interface{}
|
||||
}{
|
||||
{"empty", []int64{}, []interface{}{}},
|
||||
{"single", []int64{42}, []interface{}{int64(42)}},
|
||||
{"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := int64SliceToInterface(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLimitParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
defaultLimit int
|
||||
expected int
|
||||
}{
|
||||
{"no_param_uses_default", "", 10, 10},
|
||||
{"valid_limit", "limit=20", 10, 20},
|
||||
{"invalid_limit_uses_default", "limit=abc", 10, 10},
|
||||
{"zero_limit_uses_default", "limit=0", 10, 10},
|
||||
{"negative_limit_uses_default", "limit=-5", 10, 10},
|
||||
{"large_limit", "limit=1000", 10, 1000},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?"+tt.query, nil)
|
||||
result := ParseLimitParam(req, tt.defaultLimit)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSummary(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert a test summary
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query and scan
|
||||
row := db.QueryRow(`
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries WHERE sdk_session_id = ?
|
||||
`, "sdk-123")
|
||||
|
||||
summary, err := scanSummary(row)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, summary)
|
||||
assert.Equal(t, "sdk-123", summary.SDKSessionID)
|
||||
assert.Equal(t, "test-project", summary.Project)
|
||||
assert.Equal(t, "test request", summary.Request.String)
|
||||
assert.Equal(t, "test investigated", summary.Investigated.String)
|
||||
assert.Equal(t, "test learned", summary.Learned.String)
|
||||
assert.Equal(t, "test completed", summary.Completed.String)
|
||||
assert.Equal(t, "test next steps", summary.NextSteps.String)
|
||||
assert.Equal(t, "test notes", summary.Notes.String)
|
||||
}
|
||||
|
||||
func TestScanSummaryRows(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert multiple summaries
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES
|
||||
('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000),
|
||||
('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries WHERE sdk_session_id = ? ORDER BY id
|
||||
`, "sdk-123")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
summaries, err := scanSummaryRows(rows)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 2)
|
||||
assert.Equal(t, "request 1", summaries[0].Request.String)
|
||||
assert.Equal(t, "request 2", summaries[1].Request.String)
|
||||
}
|
||||
|
||||
func TestScanPromptWithSession(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert a test prompt
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
|
||||
VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query with session join
|
||||
row := db.QueryRow(`
|
||||
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
|
||||
FROM user_prompts p
|
||||
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
|
||||
WHERE p.claude_session_id = ?
|
||||
`, "claude-123")
|
||||
|
||||
prompt, err := scanPromptWithSession(row)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, prompt)
|
||||
assert.Equal(t, "claude-123", prompt.ClaudeSessionID)
|
||||
assert.Equal(t, 1, prompt.PromptNumber)
|
||||
assert.Equal(t, "test prompt", prompt.PromptText)
|
||||
assert.Equal(t, 5, prompt.MatchedObservations)
|
||||
assert.Equal(t, "test-project", prompt.Project)
|
||||
assert.Equal(t, "sdk-123", prompt.SDKSessionID)
|
||||
}
|
||||
|
||||
func TestScanPromptWithSessionRows(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert multiple prompts
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
|
||||
VALUES
|
||||
('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000),
|
||||
('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
|
||||
FROM user_prompts p
|
||||
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
|
||||
WHERE p.claude_session_id = ? ORDER BY p.id
|
||||
`, "claude-123")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
prompts, err := scanPromptWithSessionRows(rows)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 2)
|
||||
assert.Equal(t, "prompt one", prompts[0].PromptText)
|
||||
assert.Equal(t, "prompt two", prompts[1].PromptText)
|
||||
}
|
||||
|
||||
func TestParseLimitParam_HTTPRequest(t *testing.T) {
|
||||
// Test with an actual HTTP request
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
limit := ParseLimitParam(r, 25)
|
||||
if limit != 50 {
|
||||
t.Errorf("Expected limit 50, got %d", limit)
|
||||
}
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
@@ -1,583 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Migration represents a database schema migration.
|
||||
type Migration struct {
|
||||
Version int
|
||||
Name string
|
||||
SQL string
|
||||
}
|
||||
|
||||
// Migrations is the list of all database migrations in order.
|
||||
var Migrations = []Migration{
|
||||
{
|
||||
Version: 4,
|
||||
Name: "sdk_agent_architecture",
|
||||
SQL: `
|
||||
-- SDK Sessions (main session tracking)
|
||||
CREATE TABLE IF NOT EXISTS sdk_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
claude_session_id TEXT UNIQUE NOT NULL,
|
||||
sdk_session_id TEXT UNIQUE,
|
||||
project TEXT NOT NULL,
|
||||
user_prompt TEXT,
|
||||
started_at TEXT NOT NULL,
|
||||
started_at_epoch INTEGER NOT NULL,
|
||||
completed_at TEXT,
|
||||
completed_at_epoch INTEGER,
|
||||
status TEXT CHECK(status IN ('active', 'completed', 'failed')) NOT NULL DEFAULT 'active'
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_project ON sdk_sessions(project);
|
||||
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_status ON sdk_sessions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_started ON sdk_sessions(started_at_epoch DESC);
|
||||
|
||||
-- Observations (extracted learnings)
|
||||
CREATE TABLE IF NOT EXISTS observations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sdk_session_id TEXT NOT NULL,
|
||||
project TEXT NOT NULL,
|
||||
text TEXT,
|
||||
type TEXT NOT NULL CHECK(type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change')),
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_sdk_session ON observations(sdk_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_project ON observations(project);
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_type ON observations(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_created ON observations(created_at_epoch DESC);
|
||||
|
||||
-- Session Summaries
|
||||
CREATE TABLE IF NOT EXISTS session_summaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sdk_session_id TEXT NOT NULL,
|
||||
project TEXT NOT NULL,
|
||||
request TEXT,
|
||||
investigated TEXT,
|
||||
learned TEXT,
|
||||
completed TEXT,
|
||||
next_steps TEXT,
|
||||
files_read TEXT,
|
||||
files_edited TEXT,
|
||||
notes TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_session_summaries_sdk_session ON session_summaries(sdk_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_session_summaries_project ON session_summaries(project);
|
||||
CREATE INDEX IF NOT EXISTS idx_session_summaries_created ON session_summaries(created_at_epoch DESC);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 5,
|
||||
Name: "worker_port_column",
|
||||
SQL: `ALTER TABLE sdk_sessions ADD COLUMN worker_port INTEGER;`,
|
||||
},
|
||||
{
|
||||
Version: 6,
|
||||
Name: "prompt_tracking_columns",
|
||||
SQL: `
|
||||
ALTER TABLE sdk_sessions ADD COLUMN prompt_counter INTEGER DEFAULT 0;
|
||||
ALTER TABLE observations ADD COLUMN prompt_number INTEGER;
|
||||
ALTER TABLE session_summaries ADD COLUMN prompt_number INTEGER;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 8,
|
||||
Name: "observation_hierarchical_fields",
|
||||
SQL: `
|
||||
ALTER TABLE observations ADD COLUMN title TEXT;
|
||||
ALTER TABLE observations ADD COLUMN subtitle TEXT;
|
||||
ALTER TABLE observations ADD COLUMN facts TEXT;
|
||||
ALTER TABLE observations ADD COLUMN narrative TEXT;
|
||||
ALTER TABLE observations ADD COLUMN concepts TEXT;
|
||||
ALTER TABLE observations ADD COLUMN files_read TEXT;
|
||||
ALTER TABLE observations ADD COLUMN files_modified TEXT;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 10,
|
||||
Name: "user_prompts_table",
|
||||
SQL: `
|
||||
-- User prompts table
|
||||
CREATE TABLE IF NOT EXISTS user_prompts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
claude_session_id TEXT NOT NULL,
|
||||
prompt_number INTEGER NOT NULL,
|
||||
prompt_text TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(claude_session_id) REFERENCES sdk_sessions(claude_session_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_prompts_claude_session ON user_prompts(claude_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_prompts_created ON user_prompts(created_at_epoch DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_prompts_prompt_number ON user_prompts(prompt_number);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_prompts_lookup ON user_prompts(claude_session_id, prompt_number);
|
||||
|
||||
-- FTS5 virtual table for user prompts
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS user_prompts_fts USING fts5(
|
||||
prompt_text,
|
||||
content='user_prompts',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- Triggers for FTS5 sync
|
||||
CREATE TRIGGER IF NOT EXISTS user_prompts_ai AFTER INSERT ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(rowid, prompt_text)
|
||||
VALUES (new.id, new.prompt_text);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS user_prompts_ad AFTER DELETE ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
|
||||
VALUES('delete', old.id, old.prompt_text);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS user_prompts_au AFTER UPDATE ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
|
||||
VALUES('delete', old.id, old.prompt_text);
|
||||
INSERT INTO user_prompts_fts(rowid, prompt_text)
|
||||
VALUES (new.id, new.prompt_text);
|
||||
END;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 11,
|
||||
Name: "discovery_tokens_column",
|
||||
SQL: `
|
||||
ALTER TABLE observations ADD COLUMN discovery_tokens INTEGER DEFAULT 0;
|
||||
ALTER TABLE session_summaries ADD COLUMN discovery_tokens INTEGER DEFAULT 0;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 12,
|
||||
Name: "observations_fts",
|
||||
SQL: `
|
||||
-- FTS5 virtual table for observations
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(
|
||||
title, subtitle, narrative,
|
||||
content='observations',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- Triggers for FTS5 sync
|
||||
CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 13,
|
||||
Name: "session_summaries_fts",
|
||||
SQL: `
|
||||
-- FTS5 virtual table for session summaries
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS session_summaries_fts USING fts5(
|
||||
request, investigated, learned, completed, next_steps, notes,
|
||||
content='session_summaries',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- Triggers for FTS5 sync
|
||||
CREATE TRIGGER IF NOT EXISTS session_summaries_ai AFTER INSERT ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS session_summaries_ad AFTER DELETE ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS session_summaries_au AFTER UPDATE ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
|
||||
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
|
||||
END;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 14,
|
||||
Name: "observation_scope_column",
|
||||
SQL: `
|
||||
-- Add scope column for project isolation
|
||||
-- 'project' = only visible within same project (default)
|
||||
-- 'global' = visible across all projects (best practices, patterns)
|
||||
ALTER TABLE observations ADD COLUMN scope TEXT DEFAULT 'project' CHECK(scope IN ('project', 'global'));
|
||||
|
||||
-- Index for efficient scope-based queries
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_scope ON observations(scope);
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_project_scope ON observations(project, scope);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 15,
|
||||
Name: "observation_file_mtimes",
|
||||
SQL: `
|
||||
-- Store file modification times at observation creation
|
||||
-- JSON object: {"path": mtime_epoch_ms, ...}
|
||||
-- Used to detect staleness when files change
|
||||
ALTER TABLE observations ADD COLUMN file_mtimes TEXT;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 16,
|
||||
Name: "prompt_matched_observations",
|
||||
SQL: `
|
||||
-- Track how many observations were found relevant for each prompt
|
||||
-- Displayed in dashboard timeline
|
||||
ALTER TABLE user_prompts ADD COLUMN matched_observations INTEGER DEFAULT 0;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 17,
|
||||
Name: "sqlite_vec_vectors",
|
||||
SQL: `
|
||||
-- Vector embeddings table using sqlite-vec
|
||||
-- Each document (narrative, fact, summary field, prompt) gets one vector
|
||||
-- Uses all-MiniLM-L6-v2 embeddings (384 dimensions)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT
|
||||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 18,
|
||||
Name: "user_prompts_unique_constraint",
|
||||
SQL: `
|
||||
-- Add unique constraint to prevent duplicate prompts
|
||||
-- This fixes a bug where the user-prompt hook could fire multiple times
|
||||
-- creating duplicate prompt records with incrementing numbers
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_prompts_session_number_unique
|
||||
ON user_prompts(claude_session_id, prompt_number);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 19,
|
||||
Name: "vectors_with_model_version",
|
||||
SQL: `
|
||||
-- Drop old vectors table (virtual tables cannot be altered)
|
||||
DROP TABLE IF EXISTS vectors;
|
||||
|
||||
-- Recreate vectors table with model_version column
|
||||
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 20,
|
||||
Name: "importance_scoring",
|
||||
SQL: `
|
||||
-- Importance scoring system for observations
|
||||
-- Implements multi-factor scoring: type weight, recency decay, user feedback, concept weights, retrieval boost
|
||||
|
||||
-- Cached importance score (recalculated periodically)
|
||||
ALTER TABLE observations ADD COLUMN importance_score REAL DEFAULT 1.0;
|
||||
|
||||
-- User feedback: -1 = thumbs down, 0 = neutral, 1 = thumbs up
|
||||
ALTER TABLE observations ADD COLUMN user_feedback INTEGER DEFAULT 0;
|
||||
|
||||
-- Retrieval tracking: how many times this observation was returned in searches
|
||||
ALTER TABLE observations ADD COLUMN retrieval_count INTEGER DEFAULT 0;
|
||||
|
||||
-- Last time this observation was retrieved (for analytics)
|
||||
ALTER TABLE observations ADD COLUMN last_retrieved_at_epoch INTEGER;
|
||||
|
||||
-- Timestamp of last score recalculation
|
||||
ALTER TABLE observations ADD COLUMN score_updated_at_epoch INTEGER;
|
||||
|
||||
-- Index for importance-based sorting (primary ordering strategy)
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_importance
|
||||
ON observations(importance_score DESC, created_at_epoch DESC);
|
||||
|
||||
-- Index for finding observations needing score recalculation
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_score_updated
|
||||
ON observations(score_updated_at_epoch);
|
||||
|
||||
-- Configurable concept weights table
|
||||
-- Allows runtime tuning of how much each concept contributes to importance
|
||||
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||
concept TEXT PRIMARY KEY,
|
||||
weight REAL NOT NULL DEFAULT 0.1,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Seed default concept weights (security highest, tooling lowest)
|
||||
INSERT OR IGNORE INTO concept_weights (concept, weight, updated_at) VALUES
|
||||
('security', 0.30, datetime('now')),
|
||||
('gotcha', 0.25, datetime('now')),
|
||||
('best-practice', 0.20, datetime('now')),
|
||||
('anti-pattern', 0.20, datetime('now')),
|
||||
('architecture', 0.15, datetime('now')),
|
||||
('performance', 0.15, datetime('now')),
|
||||
('error-handling', 0.15, datetime('now')),
|
||||
('pattern', 0.10, datetime('now')),
|
||||
('testing', 0.10, datetime('now')),
|
||||
('debugging', 0.10, datetime('now')),
|
||||
('workflow', 0.05, datetime('now')),
|
||||
('tooling', 0.05, datetime('now'));
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 21,
|
||||
Name: "observation_conflicts",
|
||||
SQL: `
|
||||
-- Observation conflicts table for tracking contradictions and superseded observations
|
||||
-- Implements Issue #5: Contradiction & Obsolescence Detection
|
||||
CREATE TABLE IF NOT EXISTS observation_conflicts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
newer_obs_id INTEGER NOT NULL,
|
||||
older_obs_id INTEGER NOT NULL,
|
||||
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
|
||||
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
|
||||
reason TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
detected_at_epoch INTEGER NOT NULL,
|
||||
resolved INTEGER DEFAULT 0,
|
||||
resolved_at TEXT,
|
||||
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Index for looking up conflicts by observation ID
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_newer ON observation_conflicts(newer_obs_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_older ON observation_conflicts(older_obs_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_unresolved ON observation_conflicts(resolved, detected_at_epoch DESC);
|
||||
|
||||
-- Add is_superseded column to observations for quick filtering
|
||||
-- Set to 1 when this observation has been superseded by a newer one
|
||||
ALTER TABLE observations ADD COLUMN is_superseded INTEGER DEFAULT 0;
|
||||
|
||||
-- Index for filtering out superseded observations in queries
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_superseded ON observations(is_superseded, importance_score DESC);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 22,
|
||||
Name: "patterns_table",
|
||||
SQL: `
|
||||
-- Pattern Recognition Engine (Issue #7)
|
||||
-- Tracks recurring patterns detected across observations
|
||||
-- Enables Claude to reference historical insights: "I've encountered this pattern 12 times."
|
||||
CREATE TABLE IF NOT EXISTS patterns (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
|
||||
description TEXT,
|
||||
signature TEXT, -- JSON array of keywords/concepts for detection
|
||||
recommendation TEXT, -- What works for this pattern
|
||||
frequency INTEGER DEFAULT 1, -- How many times encountered
|
||||
projects TEXT, -- JSON array of projects where seen
|
||||
observation_ids TEXT, -- JSON array of source observation IDs
|
||||
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
|
||||
merged_into_id INTEGER, -- If status is 'merged', which pattern it merged into
|
||||
confidence REAL DEFAULT 0.5, -- Detection confidence (0.0-1.0)
|
||||
last_seen_at TEXT NOT NULL,
|
||||
last_seen_at_epoch INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
-- Indexes for efficient pattern queries
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_status ON patterns(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_confidence ON patterns(confidence DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_last_seen ON patterns(last_seen_at_epoch DESC);
|
||||
|
||||
-- FTS5 virtual table for pattern search
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS patterns_fts USING fts5(
|
||||
name, description, recommendation,
|
||||
content='patterns',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- Triggers for FTS5 sync
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_ai AFTER INSERT ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_ad AFTER DELETE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_au AFTER UPDATE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 23,
|
||||
Name: "observation_relations",
|
||||
SQL: `
|
||||
-- Knowledge Graph: Observation Relations (Issue #4)
|
||||
-- Tracks explicit relationships between observations for knowledge graph navigation.
|
||||
-- Enables queries like "What caused this bug?" or "What depends on this decision?"
|
||||
CREATE TABLE IF NOT EXISTS observation_relations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_id INTEGER NOT NULL,
|
||||
target_id INTEGER NOT NULL,
|
||||
relation_type TEXT NOT NULL CHECK(relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from')),
|
||||
confidence REAL NOT NULL DEFAULT 0.5,
|
||||
detection_source TEXT NOT NULL CHECK(detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression')),
|
||||
reason TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(source_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(target_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
UNIQUE(source_id, target_id, relation_type)
|
||||
);
|
||||
|
||||
-- Index for finding relations by source observation
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_source ON observation_relations(source_id);
|
||||
|
||||
-- Index for finding relations by target observation
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_target ON observation_relations(target_id);
|
||||
|
||||
-- Index for relation type queries
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_type ON observation_relations(relation_type);
|
||||
|
||||
-- Index for confidence-based filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_confidence ON observation_relations(confidence DESC);
|
||||
|
||||
-- Index for finding all relations involving an observation (either direction)
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_both ON observation_relations(source_id, target_id);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationManager handles database schema migrations.
|
||||
type MigrationManager struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewMigrationManager creates a new migration manager.
|
||||
func NewMigrationManager(db *sql.DB) *MigrationManager {
|
||||
return &MigrationManager{db: db}
|
||||
}
|
||||
|
||||
// EnsureSchemaVersionsTable creates the schema_versions table if it doesn't exist.
|
||||
func (m *MigrationManager) EnsureSchemaVersionsTable() error {
|
||||
_, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_versions (
|
||||
id INTEGER PRIMARY KEY,
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
applied_at TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAppliedVersions returns all applied migration versions.
|
||||
func (m *MigrationManager) GetAppliedVersions() (map[int]bool, error) {
|
||||
rows, err := m.db.Query("SELECT version FROM schema_versions ORDER BY version")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
versions := make(map[int]bool)
|
||||
for rows.Next() {
|
||||
var version int
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
versions[version] = true
|
||||
}
|
||||
return versions, rows.Err()
|
||||
}
|
||||
|
||||
// ApplyMigration applies a single migration.
|
||||
func (m *MigrationManager) ApplyMigration(migration Migration) error {
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute migration SQL
|
||||
if _, err := tx.Exec(migration.SQL); err != nil {
|
||||
return fmt.Errorf("execute migration %d (%s): %w", migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Record migration
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO schema_versions (version, applied_at) VALUES (?, ?)",
|
||||
migration.Version, time.Now().Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RunMigrations applies all pending migrations.
|
||||
func (m *MigrationManager) RunMigrations() error {
|
||||
if err := m.EnsureSchemaVersionsTable(); err != nil {
|
||||
return fmt.Errorf("ensure schema_versions table: %w", err)
|
||||
}
|
||||
|
||||
applied, err := m.GetAppliedVersions()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get applied versions: %w", err)
|
||||
}
|
||||
|
||||
for _, migration := range Migrations {
|
||||
if applied[migration.Version] {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.ApplyMigration(migration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMigrationManager(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
require.NotNil(t, manager)
|
||||
assert.Equal(t, db, manager.db)
|
||||
}
|
||||
|
||||
func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
|
||||
// Should create table without error
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Table should exist
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count) // Empty table
|
||||
|
||||
// Calling again should not error (IF NOT EXISTS)
|
||||
err = manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, versions)
|
||||
}
|
||||
|
||||
func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert some versions
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, versions, 2)
|
||||
assert.True(t, versions[1])
|
||||
assert.True(t, versions[2])
|
||||
assert.False(t, versions[3]) // Not applied
|
||||
}
|
||||
|
||||
func TestMigrationManager_ApplyMigration(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply a simple migration
|
||||
migration := Migration{
|
||||
Version: 100,
|
||||
Name: "test_migration",
|
||||
SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)",
|
||||
}
|
||||
|
||||
err = manager.ApplyMigration(migration)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify table was created
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
// Verify migration was recorded
|
||||
var version int
|
||||
err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, version)
|
||||
}
|
||||
|
||||
func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to apply invalid migration
|
||||
migration := Migration{
|
||||
Version: 100,
|
||||
Name: "invalid_migration",
|
||||
SQL: "INVALID SQL SYNTAX",
|
||||
}
|
||||
|
||||
err = manager.ApplyMigration(migration)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "execute migration 100")
|
||||
}
|
||||
|
||||
func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a test migration manager with a subset of migrations
|
||||
manager := NewMigrationManager(db)
|
||||
|
||||
// First ensure schema versions table exists
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply first migration manually
|
||||
err = manager.ApplyMigration(Migrations[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the first migration version was recorded
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, versions[Migrations[0].Version])
|
||||
}
|
||||
|
||||
func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark some migrations as already applied
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get applied versions
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, versions[4])
|
||||
}
|
||||
|
||||
func TestMigration_Struct(t *testing.T) {
|
||||
m := Migration{
|
||||
Version: 1,
|
||||
Name: "test",
|
||||
SQL: "SELECT 1",
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, m.Version)
|
||||
assert.Equal(t, "test", m.Name)
|
||||
assert.Equal(t, "SELECT 1", m.SQL)
|
||||
}
|
||||
|
||||
func TestMigrations_List(t *testing.T) {
|
||||
// Verify migrations are ordered correctly
|
||||
assert.NotEmpty(t, Migrations)
|
||||
|
||||
// Verify all migrations have required fields
|
||||
for i, m := range Migrations {
|
||||
assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i)
|
||||
assert.NotEmpty(t, m.Name, "Migration %d has empty name", i)
|
||||
assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i)
|
||||
}
|
||||
|
||||
// Verify key migrations exist
|
||||
versionSet := make(map[int]bool)
|
||||
for _, m := range Migrations {
|
||||
versionSet[m.Version] = true
|
||||
}
|
||||
|
||||
assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration")
|
||||
assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration")
|
||||
}
|
||||
@@ -1,657 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// observationColumns is the standard list of columns to select for observations.
|
||||
// This ensures consistency across all observation queries and includes importance scoring fields.
|
||||
const observationColumns = `id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
|
||||
title, subtitle, facts, narrative, concepts, files_read, files_modified, file_mtimes,
|
||||
prompt_number, discovery_tokens, created_at, created_at_epoch,
|
||||
COALESCE(importance_score, 1.0) as importance_score,
|
||||
COALESCE(user_feedback, 0) as user_feedback,
|
||||
COALESCE(retrieval_count, 0) as retrieval_count,
|
||||
last_retrieved_at_epoch, score_updated_at_epoch,
|
||||
COALESCE(is_superseded, 0) as is_superseded`
|
||||
|
||||
// CleanupFunc is a callback for when observations are cleaned up.
|
||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// ObservationStore provides observation-related database operations.
|
||||
type ObservationStore struct {
|
||||
store *Store
|
||||
cleanupFunc CleanupFunc
|
||||
conflictStore *ConflictStore
|
||||
relationStore *RelationStore
|
||||
}
|
||||
|
||||
// NewObservationStore creates a new observation store.
|
||||
func NewObservationStore(store *Store) *ObservationStore {
|
||||
return &ObservationStore{store: store}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when observations are deleted during cleanup.
|
||||
func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// SetConflictStore sets the conflict store for conflict detection.
|
||||
func (s *ObservationStore) SetConflictStore(conflictStore *ConflictStore) {
|
||||
s.conflictStore = conflictStore
|
||||
}
|
||||
|
||||
// SetRelationStore sets the relation store for relationship detection.
|
||||
func (s *ObservationStore) SetRelationStore(relationStore *RelationStore) {
|
||||
s.relationStore = relationStore
|
||||
}
|
||||
|
||||
// StoreObservation stores a new observation.
|
||||
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||
now := time.Now()
|
||||
nowEpoch := now.UnixMilli()
|
||||
|
||||
// Ensure session exists (auto-create if missing)
|
||||
if err := s.ensureSessionExists(ctx, sdkSessionID, project); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Determine scope: use parsed scope if set, otherwise auto-determine from concepts
|
||||
scope := obs.Scope
|
||||
if scope == "" {
|
||||
scope = models.DetermineScope(obs.Concepts)
|
||||
}
|
||||
|
||||
factsJSON, _ := json.Marshal(obs.Facts)
|
||||
conceptsJSON, _ := json.Marshal(obs.Concepts)
|
||||
filesReadJSON, _ := json.Marshal(obs.FilesRead)
|
||||
filesModifiedJSON, _ := json.Marshal(obs.FilesModified)
|
||||
fileMtimesJSON, _ := json.Marshal(obs.FileMtimes)
|
||||
|
||||
const query = `
|
||||
INSERT INTO observations
|
||||
(sdk_session_id, project, scope, type, title, subtitle, facts, narrative, concepts,
|
||||
files_read, files_modified, file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
sdkSessionID, project, string(scope), string(obs.Type),
|
||||
nullString(obs.Title), nullString(obs.Subtitle),
|
||||
string(factsJSON), nullString(obs.Narrative), string(conceptsJSON),
|
||||
string(filesReadJSON), string(filesModifiedJSON), string(fileMtimesJSON),
|
||||
nullInt(promptNumber), discoveryTokens,
|
||||
now.Format(time.RFC3339), nowEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
|
||||
// Cleanup old observations beyond the limit for this project (async to not block handler)
|
||||
if project != "" {
|
||||
go func(proj string) {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}(project)
|
||||
}
|
||||
|
||||
// Detect conflicts with existing observations (async to not block handler)
|
||||
if s.conflictStore != nil && project != "" {
|
||||
go func(newObsID int64, proj string, parsedObs *models.ParsedObservation) {
|
||||
conflictCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
s.detectAndStoreConflicts(conflictCtx, newObsID, proj, parsedObs)
|
||||
}(id, project, obs)
|
||||
}
|
||||
|
||||
// Detect relationships with existing observations (async to not block handler)
|
||||
if s.relationStore != nil && project != "" {
|
||||
go func(newObsID int64, proj string) {
|
||||
relationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
s.detectAndStoreRelations(relationCtx, newObsID, proj)
|
||||
}(id, project)
|
||||
}
|
||||
|
||||
return id, nowEpoch, nil
|
||||
}
|
||||
|
||||
// detectAndStoreConflicts detects conflicts between a new observation and existing ones.
|
||||
func (s *ObservationStore) detectAndStoreConflicts(ctx context.Context, newObsID int64, project string, parsedObs *models.ParsedObservation) {
|
||||
// Fetch the newly stored observation
|
||||
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||
if err != nil || newObs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch recent observations from the same project to check for conflicts
|
||||
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Detect conflicts
|
||||
conflicts := models.DetectConflictsWithExisting(newObs, existing)
|
||||
|
||||
// Store conflicts and mark superseded observations
|
||||
var supersededIDs []int64
|
||||
for _, result := range conflicts {
|
||||
for _, olderID := range result.OlderObsIDs {
|
||||
conflict := models.NewObservationConflict(
|
||||
newObsID, olderID,
|
||||
result.Type, result.Resolution, result.Reason,
|
||||
)
|
||||
if _, err := s.conflictStore.StoreConflict(ctx, conflict); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If resolution is to prefer newer, mark older as superseded
|
||||
if result.Resolution == models.ResolutionPreferNewer {
|
||||
supersededIDs = append(supersededIDs, olderID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark superseded observations
|
||||
if len(supersededIDs) > 0 {
|
||||
_ = s.conflictStore.MarkObservationsSuperseded(ctx, supersededIDs)
|
||||
}
|
||||
|
||||
// Cleanup old superseded observations (older than 3 days)
|
||||
deletedIDs, _ := s.conflictStore.CleanupSupersededObservations(ctx, project)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, deletedIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// MinRelationConfidence is the minimum confidence threshold for storing relations.
|
||||
const MinRelationConfidence = 0.4
|
||||
|
||||
// detectAndStoreRelations detects relationships between a new observation and existing ones.
|
||||
func (s *ObservationStore) detectAndStoreRelations(ctx context.Context, newObsID int64, project string) {
|
||||
// Fetch the newly stored observation
|
||||
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||
if err != nil || newObs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch recent observations from the same project to check for relations
|
||||
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Detect relationships using the models package detection logic
|
||||
results := models.DetectRelationsWithExisting(newObs, existing, MinRelationConfidence)
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert detection results to relation objects
|
||||
relations := make([]*models.ObservationRelation, len(results))
|
||||
for i, r := range results {
|
||||
relations[i] = models.NewObservationRelation(
|
||||
r.SourceID, r.TargetID,
|
||||
r.RelationType, r.Confidence,
|
||||
r.DetectionSource, r.Reason,
|
||||
)
|
||||
}
|
||||
|
||||
// Store all relations
|
||||
_ = s.relationStore.StoreRelations(ctx, relations)
|
||||
}
|
||||
|
||||
// ensureSessionExists creates a session if it doesn't exist.
|
||||
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
|
||||
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
|
||||
}
|
||||
|
||||
// GetObservationByID retrieves an observation by ID.
|
||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + ` FROM observations WHERE id = ?`
|
||||
|
||||
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return obs, err
|
||||
}
|
||||
|
||||
// GetObservationsByIDs retrieves observations by a list of IDs.
|
||||
// Results are ordered by importance_score DESC by default, with created_at_epoch as secondary sort.
|
||||
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||
ORDER BY `
|
||||
|
||||
// Default to importance-based ordering
|
||||
switch orderBy {
|
||||
case "date_asc":
|
||||
query += "created_at_epoch ASC"
|
||||
case "date_desc":
|
||||
query += "created_at_epoch DESC"
|
||||
case "importance":
|
||||
query += "importance_score DESC, created_at_epoch DESC"
|
||||
default:
|
||||
// Default: importance first, then recency
|
||||
query += "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
}
|
||||
|
||||
// Convert []int64 to []interface{}
|
||||
args := int64SliceToInterface(ids)
|
||||
if limit > 0 {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetRecentObservations retrieves recent observations for a project.
|
||||
// This includes project-scoped observations for the specified project AND global observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
|
||||
OR scope = 'global'
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||
// This excludes observations that have been marked as superseded by newer ones.
|
||||
// Use this for context injection where you want to avoid outdated/contradicted advice.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE ((project = ? AND (scope IS NULL OR scope = 'project'))
|
||||
OR scope = 'global')
|
||||
AND COALESCE(is_superseded, 0) = 0
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetSupersededObservations retrieves observations that have been superseded by newer ones.
|
||||
// Use this for verification/debugging to see which observations were marked as outdated.
|
||||
// Results are ordered by created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE project = ?
|
||||
AND COALESCE(is_superseded, 0) = 1
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrict retrieves observations strictly for a specific project.
|
||||
// Unlike GetRecentObservations, this does NOT include global observations from other projects.
|
||||
// Use this for dashboard filtering where the user expects to see only that project's data.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE project = ?
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetObservationCount returns the count of observations for a project (including global).
|
||||
func (s *ObservationStore) GetObservationCount(ctx context.Context, project string) (int, error) {
|
||||
const query = `
|
||||
SELECT COUNT(*) FROM observations
|
||||
WHERE project = ? OR scope = 'global'
|
||||
`
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query, project).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetAllRecentObservations retrieves recent observations across all projects.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations.
|
||||
// Results are ordered by FTS rank (relevance), then by importance_score.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
// Extract keywords from the query (words > 3 chars, not common)
|
||||
keywords := extractKeywords(query)
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build FTS5 query: keyword1 OR keyword2 OR keyword3
|
||||
ftsTerms := strings.Join(keywords, " OR ")
|
||||
|
||||
// Use FTS5 to search title, subtitle, and narrative
|
||||
// Include importance scoring columns and order by rank then importance
|
||||
ftsQuery := `
|
||||
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
|
||||
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
|
||||
o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
|
||||
COALESCE(o.importance_score, 1.0) as importance_score,
|
||||
COALESCE(o.user_feedback, 0) as user_feedback,
|
||||
COALESCE(o.retrieval_count, 0) as retrieval_count,
|
||||
o.last_retrieved_at_epoch, o.score_updated_at_epoch,
|
||||
COALESCE(o.is_superseded, 0) as is_superseded
|
||||
FROM observations o
|
||||
JOIN observations_fts fts ON o.id = fts.rowid
|
||||
WHERE observations_fts MATCH ?
|
||||
AND (o.project = ? OR o.scope = 'global')
|
||||
ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, ftsQuery, ftsTerms, project, limit)
|
||||
if err != nil {
|
||||
// FTS failed, try LIKE fallback
|
||||
return s.searchObservationsLike(ctx, keywords, project, limit)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
observations, err := scanObservationRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If FTS returned nothing, try LIKE search
|
||||
if len(observations) == 0 {
|
||||
return s.searchObservationsLike(ctx, keywords, project, limit)
|
||||
}
|
||||
|
||||
return observations, nil
|
||||
}
|
||||
|
||||
// searchObservationsLike performs fallback LIKE search on observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build LIKE conditions for each keyword
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, kw := range keywords {
|
||||
pattern := "%" + kw + "%"
|
||||
conditions = append(conditions, "(title LIKE ? OR subtitle LIKE ? OR narrative LIKE ?)")
|
||||
args = append(args, pattern, pattern, pattern)
|
||||
}
|
||||
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE (` + strings.Join(conditions, " OR ") + `)
|
||||
AND (project = ? OR scope = 'global')
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
|
||||
rows, err := s.store.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// extractKeywords extracts significant words from a query.
|
||||
func extractKeywords(query string) []string {
|
||||
// Common words to skip
|
||||
stopWords := map[string]bool{
|
||||
"what": true, "is": true, "the": true, "a": true, "an": true,
|
||||
"how": true, "does": true, "do": true, "can": true, "could": true,
|
||||
"would": true, "should": true, "where": true, "when": true, "why": true,
|
||||
"which": true, "who": true, "this": true, "that": true, "these": true,
|
||||
"those": true, "it": true, "its": true, "for": true, "from": true,
|
||||
"with": true, "about": true, "into": true, "through": true, "during": true,
|
||||
"before": true, "after": true, "above": true, "below": true, "to": true,
|
||||
"of": true, "in": true, "on": true, "at": true, "by": true, "and": true,
|
||||
"or": true, "but": true, "if": true, "then": true, "else": true,
|
||||
"function": true, "method": true, "class": true, "file": true,
|
||||
"code": true, "work": true, "works": true, "working": true,
|
||||
"please": true, "help": true, "me": true, "my": true, "i": true,
|
||||
"tell": true, "show": true, "explain": true, "describe": true,
|
||||
}
|
||||
|
||||
// Split and filter
|
||||
words := strings.FieldsFunc(strings.ToLower(query), func(r rune) bool {
|
||||
return !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_')
|
||||
})
|
||||
|
||||
var keywords []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, word := range words {
|
||||
// Skip short words, stop words, and duplicates
|
||||
if len(word) < 4 || stopWords[word] || seen[word] {
|
||||
continue
|
||||
}
|
||||
seen[word] = true
|
||||
keywords = append(keywords, word)
|
||||
}
|
||||
|
||||
return keywords
|
||||
}
|
||||
|
||||
// DeleteObservations deletes multiple observations by ID.
|
||||
func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
query := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
result, err := s.store.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// MaxObservationsPerProject is the hard limit of observations per project.
|
||||
const MaxObservationsPerProject = 100
|
||||
|
||||
// CleanupOldObservations deletes observations beyond the limit for a project.
|
||||
// Keeps the most recent MaxObservationsPerProject observations per project.
|
||||
// Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
// First, find IDs that will be deleted
|
||||
const selectQuery = `
|
||||
SELECT id FROM observations
|
||||
WHERE project = ? AND id NOT IN (
|
||||
SELECT id FROM observations
|
||||
WHERE project = ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, selectQuery, project, project, MaxObservationsPerProject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var toDelete []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(toDelete) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete the observations
|
||||
const deleteQuery = `
|
||||
DELETE FROM observations
|
||||
WHERE project = ? AND id NOT IN (
|
||||
SELECT id FROM observations
|
||||
WHERE project = ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
_, err = s.store.ExecContext(ctx, deleteQuery, project, project, MaxObservationsPerProject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDelete, nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// scanObservation scans a single observation from a row scanner.
|
||||
// This reduces code duplication across all observation query methods.
|
||||
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
|
||||
var obs models.Observation
|
||||
if err := scanner.Scan(
|
||||
&obs.ID, &obs.SDKSessionID, &obs.Project, &obs.Scope, &obs.Type,
|
||||
&obs.Title, &obs.Subtitle, &obs.Facts, &obs.Narrative,
|
||||
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
|
||||
&obs.PromptNumber, &obs.DiscoveryTokens,
|
||||
&obs.CreatedAt, &obs.CreatedAtEpoch,
|
||||
// Importance scoring fields
|
||||
&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
|
||||
&obs.LastRetrievedAt, &obs.ScoreUpdatedAt,
|
||||
// Conflict detection fields
|
||||
&obs.IsSuperseded,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &obs, nil
|
||||
}
|
||||
|
||||
// scanObservationRows scans multiple observations from rows.
|
||||
// Caller must close rows after calling this function.
|
||||
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
var observations []*models.Observation
|
||||
for rows.Next() {
|
||||
obs, err := scanObservation(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
observations = append(observations, obs)
|
||||
}
|
||||
return observations, rows.Err()
|
||||
}
|
||||
|
||||
// Note: nullString, nullInt, and repeatPlaceholders are in helpers.go
|
||||
@@ -1,947 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// testObservationStoreBasic creates an ObservationStore with base tables (no FTS5).
|
||||
func testObservationStoreBasic(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createBaseTables(t, db)
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
obsStore := NewObservationStore(store)
|
||||
|
||||
return obsStore, store, cleanup
|
||||
}
|
||||
|
||||
// testObservationStore creates an ObservationStore with a test database including FTS5.
|
||||
func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createAllTables(t, db)
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
obsStore := NewObservationStore(store)
|
||||
|
||||
return obsStore, store, cleanup
|
||||
}
|
||||
|
||||
// ObservationStoreSuite is a test suite for ObservationStore operations.
|
||||
type ObservationStoreSuite struct {
|
||||
suite.Suite
|
||||
obsStore *ObservationStore
|
||||
store *Store
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func (s *ObservationStoreSuite) SetupTest() {
|
||||
s.obsStore, s.store, s.cleanup = testObservationStoreBasic(s.T())
|
||||
}
|
||||
|
||||
func (s *ObservationStoreSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStoreSuite(t *testing.T) {
|
||||
suite.Run(t, new(ObservationStoreSuite))
|
||||
}
|
||||
|
||||
// TestStoreObservation_TableDriven tests observation storage with various scenarios.
|
||||
func (s *ObservationStoreSuite) TestStoreObservation_TableDriven() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sdkSessionID string
|
||||
project string
|
||||
obs *models.ParsedObservation
|
||||
promptNum int
|
||||
tokens int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic discovery observation",
|
||||
sdkSessionID: "session-basic",
|
||||
project: "project-a",
|
||||
obs: &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test Title",
|
||||
Subtitle: "Test Subtitle",
|
||||
Narrative: "Test narrative content",
|
||||
Facts: []string{"Fact 1", "Fact 2"},
|
||||
Concepts: []string{"testing", "golang"},
|
||||
},
|
||||
promptNum: 1,
|
||||
tokens: 100,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bugfix observation",
|
||||
sdkSessionID: "session-bugfix",
|
||||
project: "project-b",
|
||||
obs: &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Fixed null pointer",
|
||||
Narrative: "Fixed null pointer exception in handler",
|
||||
FilesModified: []string{"handler.go"},
|
||||
},
|
||||
promptNum: 2,
|
||||
tokens: 50,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "global scope observation",
|
||||
sdkSessionID: "session-global",
|
||||
project: "project-c",
|
||||
obs: &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Security best practice",
|
||||
Narrative: "Always validate user input",
|
||||
Concepts: []string{"security", "best-practice"},
|
||||
},
|
||||
promptNum: 1,
|
||||
tokens: 75,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "observation with files",
|
||||
sdkSessionID: "session-files",
|
||||
project: "project-d",
|
||||
obs: &models.ParsedObservation{
|
||||
Type: models.ObsTypeFeature,
|
||||
Title: "Added authentication",
|
||||
Narrative: "Implemented JWT authentication",
|
||||
FilesRead: []string{"config.go", "auth.go"},
|
||||
FilesModified: []string{"handler.go", "middleware.go"},
|
||||
FileMtimes: map[string]int64{"handler.go": 1234567890, "middleware.go": 1234567891},
|
||||
},
|
||||
promptNum: 3,
|
||||
tokens: 200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "minimal observation",
|
||||
sdkSessionID: "session-minimal",
|
||||
project: "project-e",
|
||||
obs: &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
},
|
||||
promptNum: 0,
|
||||
tokens: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
id, epoch, err := s.obsStore.StoreObservation(ctx, tt.sdkSessionID, tt.project, tt.obs, tt.promptNum, tt.tokens)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
s.NoError(err)
|
||||
s.Greater(id, int64(0))
|
||||
s.Greater(epoch, int64(0))
|
||||
|
||||
// Retrieve and verify
|
||||
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
|
||||
s.NoError(err)
|
||||
s.NotNil(retrieved)
|
||||
s.Equal(id, retrieved.ID)
|
||||
s.Equal(tt.project, retrieved.Project)
|
||||
s.Equal(tt.obs.Type, retrieved.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetObservationByID_NotFound tests retrieval of non-existent observation.
|
||||
func (s *ObservationStoreSuite) TestGetObservationByID_NotFound() {
|
||||
ctx := context.Background()
|
||||
|
||||
obs, err := s.obsStore.GetObservationByID(ctx, 99999)
|
||||
s.NoError(err)
|
||||
s.Nil(obs)
|
||||
}
|
||||
|
||||
// TestGetRecentObservations_TableDriven tests recent observations retrieval.
|
||||
func (s *ObservationStoreSuite) TestGetRecentObservations_TableDriven() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 15 observations
|
||||
for i := 0; i < 15; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation " + string(rune('A'+i)),
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(ctx, "session-"+string(rune('0'+i)), "project-a", obs, i, 10)
|
||||
s.NoError(err)
|
||||
time.Sleep(time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
project string
|
||||
limit int
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "limit 5",
|
||||
project: "project-a",
|
||||
limit: 5,
|
||||
wantCount: 5,
|
||||
},
|
||||
{
|
||||
name: "limit 10",
|
||||
project: "project-a",
|
||||
limit: 10,
|
||||
wantCount: 10,
|
||||
},
|
||||
{
|
||||
name: "limit higher than count",
|
||||
project: "project-a",
|
||||
limit: 50,
|
||||
wantCount: 15,
|
||||
},
|
||||
{
|
||||
name: "different project (no results)",
|
||||
project: "project-b",
|
||||
limit: 10,
|
||||
wantCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
observations, err := s.obsStore.GetRecentObservations(ctx, tt.project, tt.limit)
|
||||
s.NoError(err)
|
||||
s.Len(observations, tt.wantCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteObservations_TableDriven tests observation deletion.
|
||||
func (s *ObservationStoreSuite) TestDeleteObservations_TableDriven() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "To delete " + string(rune('A'+i)),
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(ctx, "session-del", "project-del", obs, i, 10)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
toDelete []int64
|
||||
wantDeleted int64
|
||||
wantRemain int
|
||||
}{
|
||||
{
|
||||
name: "delete none",
|
||||
toDelete: []int64{},
|
||||
wantDeleted: 0,
|
||||
wantRemain: 5,
|
||||
},
|
||||
{
|
||||
name: "delete one",
|
||||
toDelete: ids[0:1],
|
||||
wantDeleted: 1,
|
||||
wantRemain: 4,
|
||||
},
|
||||
{
|
||||
name: "delete remaining",
|
||||
toDelete: ids[1:],
|
||||
wantDeleted: 4,
|
||||
wantRemain: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
deleted, err := s.obsStore.DeleteObservations(ctx, tt.toDelete)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.wantDeleted, deleted)
|
||||
|
||||
remaining, err := s.obsStore.GetAllRecentObservations(ctx, 100)
|
||||
s.NoError(err)
|
||||
s.Len(remaining, tt.wantRemain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetObservationsByIDs tests retrieval by multiple IDs.
|
||||
func (s *ObservationStoreSuite) TestGetObservationsByIDs() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "By ID " + string(rune('A'+i)),
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(ctx, "session-byid", "project-byid", obs, i, 10)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryIDs []int64
|
||||
orderBy string
|
||||
limit int
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "empty IDs",
|
||||
queryIDs: []int64{},
|
||||
orderBy: "date_desc",
|
||||
limit: 10,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "single ID",
|
||||
queryIDs: ids[0:1],
|
||||
orderBy: "date_desc",
|
||||
limit: 10,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "all IDs",
|
||||
queryIDs: ids,
|
||||
orderBy: "date_desc",
|
||||
limit: 10,
|
||||
wantCount: 5,
|
||||
},
|
||||
{
|
||||
name: "with limit less than IDs",
|
||||
queryIDs: ids,
|
||||
orderBy: "date_desc",
|
||||
limit: 3,
|
||||
wantCount: 3,
|
||||
},
|
||||
{
|
||||
name: "ascending order",
|
||||
queryIDs: ids,
|
||||
orderBy: "date_asc",
|
||||
limit: 10,
|
||||
wantCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
observations, err := s.obsStore.GetObservationsByIDs(ctx, tt.queryIDs, tt.orderBy, tt.limit)
|
||||
if tt.wantCount == 0 {
|
||||
s.NoError(err)
|
||||
s.Nil(observations)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.Len(observations, tt.wantCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGlobalScope tests global vs project scope.
|
||||
func (s *ObservationStoreSuite) TestGlobalScope() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create project-scoped observation
|
||||
projectObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project specific",
|
||||
Concepts: []string{"project-specific"},
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(ctx, "session-scope", "project-a", projectObs, 1, 10)
|
||||
s.NoError(err)
|
||||
|
||||
// Create global-scoped observation (security concept triggers global)
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global security",
|
||||
Concepts: []string{"security"},
|
||||
}
|
||||
_, _, err = s.obsStore.StoreObservation(ctx, "session-scope", "project-a", globalObs, 2, 10)
|
||||
s.NoError(err)
|
||||
|
||||
// Project-a should see both
|
||||
resultsA, err := s.obsStore.GetRecentObservations(ctx, "project-a", 10)
|
||||
s.NoError(err)
|
||||
s.Len(resultsA, 2)
|
||||
|
||||
// Project-b should only see global
|
||||
resultsB, err := s.obsStore.GetRecentObservations(ctx, "project-b", 10)
|
||||
s.NoError(err)
|
||||
s.Len(resultsB, 1)
|
||||
s.Equal("Global security", resultsB[0].Title.String)
|
||||
s.Equal(models.ScopeGlobal, resultsB[0].Scope)
|
||||
}
|
||||
|
||||
// TestSetCleanupFunc tests the cleanup function callback.
|
||||
func (s *ObservationStoreSuite) TestSetCleanupFunc() {
|
||||
ctx := context.Background()
|
||||
|
||||
var calledWith []int64
|
||||
s.obsStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
calledWith = deletedIDs
|
||||
})
|
||||
|
||||
// Store an observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test cleanup",
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(ctx, "session-cleanup", "project-cleanup", obs, 1, 10)
|
||||
s.NoError(err)
|
||||
|
||||
// Cleanup should not have been called since nothing was deleted
|
||||
s.Empty(calledWith)
|
||||
}
|
||||
|
||||
// TestGetObservationCount tests observation counting.
|
||||
func (s *ObservationStoreSuite) TestGetObservationCount() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations for project-a
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", obs, i, 10)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Create global observation
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Concepts: []string{"security"},
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", globalObs, 6, 10)
|
||||
s.NoError(err)
|
||||
|
||||
// Project-a should count 6 (5 project + 1 global)
|
||||
count, err := s.obsStore.GetObservationCount(ctx, "project-a")
|
||||
s.NoError(err)
|
||||
s.Equal(6, count)
|
||||
|
||||
// Project-b should count 1 (only global)
|
||||
count, err = s.obsStore.GetObservationCount(ctx, "project-b")
|
||||
s.NoError(err)
|
||||
s.Equal(1, count)
|
||||
}
|
||||
|
||||
func TestObservationStore_StoreAndRetrieve(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test Observation",
|
||||
Subtitle: "A subtitle",
|
||||
Narrative: "This is a test observation about testing",
|
||||
Facts: []string{"Fact 1", "Fact 2"},
|
||||
Concepts: []string{"testing", "golang"},
|
||||
FilesRead: []string{"test.go"},
|
||||
FilesModified: []string{},
|
||||
}
|
||||
|
||||
id, epoch, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
assert.Greater(t, epoch, int64(0))
|
||||
|
||||
// Retrieve by ID
|
||||
retrieved, err := obsStore.GetObservationByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
|
||||
assert.Equal(t, id, retrieved.ID)
|
||||
assert.Equal(t, "session-1", retrieved.SDKSessionID)
|
||||
assert.Equal(t, "project-a", retrieved.Project)
|
||||
assert.Equal(t, models.ObsTypeDiscovery, retrieved.Type)
|
||||
assert.Equal(t, "Test Observation", retrieved.Title.String)
|
||||
assert.Equal(t, "A subtitle", retrieved.Subtitle.String)
|
||||
assert.Equal(t, "This is a test observation about testing", retrieved.Narrative.String)
|
||||
assert.Equal(t, []string{"Fact 1", "Fact 2"}, []string(retrieved.Facts))
|
||||
assert.Equal(t, []string{"testing", "golang"}, []string(retrieved.Concepts))
|
||||
}
|
||||
|
||||
func TestObservationStore_GetRecentObservations(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple observations
|
||||
for i := 0; i < 10; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation " + string(rune('A'+i)),
|
||||
Narrative: "Content " + string(rune('A'+i)),
|
||||
Concepts: []string{"test"},
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
// Get recent with limit 5
|
||||
recent, err := obsStore.GetRecentObservations(ctx, "project-a", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, recent, 5)
|
||||
|
||||
// Get recent with limit 20 (more than exists)
|
||||
recent, err = obsStore.GetRecentObservations(ctx, "project-a", 20)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, recent, 10)
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
// FTS5 tables are created by testObservationStore via testutil.CreateAllTables
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations with different content
|
||||
observations := []struct {
|
||||
title string
|
||||
narrative string
|
||||
}{
|
||||
{"Authentication implementation", "JWT based authentication flow"},
|
||||
{"Database setup", "PostgreSQL configuration and migrations"},
|
||||
{"Caching layer", "Redis caching implementation"},
|
||||
{"User authentication fix", "Fixed authentication bug in login"},
|
||||
{"API endpoints", "REST API implementation details"},
|
||||
}
|
||||
|
||||
for _, o := range observations {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: o.title,
|
||||
Narrative: o.narrative,
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Search for authentication - should find 2 observations
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 50)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 2, "should find at least 2 authentication-related observations")
|
||||
|
||||
// Search for database - should find 1 observation
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "database PostgreSQL", "project-a", 50)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(results), 1, "should find at least 1 database-related observation")
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS_LimitRespected(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
// FTS5 tables are created by testObservationStore via testutil.CreateAllTables
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 20 observations with similar content
|
||||
for i := 0; i < 20; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Testing observation " + string(rune('A'+i)),
|
||||
Narrative: "This is about testing and quality assurance " + string(rune('A'+i)),
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Search with limit 5
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 5)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 5, "should respect limit of 5")
|
||||
|
||||
// Search with limit 15
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 15)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 15, "should respect limit of 15")
|
||||
|
||||
// Search with limit 50 (our new default)
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 50)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 50, "should respect limit of 50")
|
||||
assert.Equal(t, 20, len(results), "should return all 20 matching observations")
|
||||
}
|
||||
|
||||
func TestObservationStore_GlobalScope(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a project-scoped observation
|
||||
projectObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project specific code",
|
||||
Narrative: "This is specific to project-a",
|
||||
Concepts: []string{"project-specific"},
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a global-scoped observation (has a globalizable concept)
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Security best practice",
|
||||
Narrative: "Always validate user input",
|
||||
Concepts: []string{"security", "best-practice"}, // "security" is in GlobalizableConcepts
|
||||
}
|
||||
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get recent for project-a - should see both
|
||||
results, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Get recent for project-b - should only see global observation
|
||||
results, err = obsStore.GetRecentObservations(ctx, "project-b", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "Security best practice", results[0].Title.String)
|
||||
assert.Equal(t, models.ScopeGlobal, results[0].Scope)
|
||||
}
|
||||
|
||||
func TestObservationStore_DeleteObservations(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation " + string(rune('A'+i)),
|
||||
}
|
||||
id, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Verify all exist
|
||||
all, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, all, 5)
|
||||
|
||||
// Delete first 3
|
||||
deleted, err := obsStore.DeleteObservations(ctx, ids[:3])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), deleted)
|
||||
|
||||
// Verify only 2 remain
|
||||
remaining, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, remaining, 2)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationCount(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations for different projects
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project A observation " + string(rune('0'+i)),
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project B observation " + string(rune('0'+i)),
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-b", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create a global observation
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global observation",
|
||||
Concepts: []string{"best-practice"}, // Makes it global
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Count for project-a includes its own + global
|
||||
count, err := obsStore.GetObservationCount(ctx, "project-a")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, count) // 5 project-a + 1 global
|
||||
|
||||
// Count for project-b includes its own + global
|
||||
count, err = obsStore.GetObservationCount(ctx, "project-b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, count) // 3 project-b + 1 global
|
||||
}
|
||||
|
||||
func TestObservationStore_CleanupOldObservations(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create more observations than the limit (MaxObservationsPerProject = 100)
|
||||
// We'll create a smaller number and verify the logic works
|
||||
for i := 0; i < 10; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Observation " + string(rune('A'+i)),
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Cleanup should return empty since we're under the limit
|
||||
deletedIDs, err := obsStore.CleanupOldObservations(ctx, "project-a")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, deletedIDs)
|
||||
|
||||
// All 10 should still exist
|
||||
count, err := obsStore.GetObservationCount(ctx, "project-a")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 10, count)
|
||||
}
|
||||
|
||||
func TestObservationStore_SetCleanupFunc(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Track cleanup calls
|
||||
var cleanupCalledWith []int64
|
||||
obsStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
cleanupCalledWith = deletedIDs
|
||||
})
|
||||
|
||||
// Store an observation (should trigger cleanup, but won't delete anything under limit)
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test observation",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup func should not have been called since nothing was deleted
|
||||
assert.Empty(t, cleanupCalledWith)
|
||||
}
|
||||
|
||||
func TestExtractKeywords(t *testing.T) {
|
||||
tests := []struct {
|
||||
query string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
query: "What is the authentication flow?",
|
||||
expected: []string{"authentication", "flow"},
|
||||
},
|
||||
{
|
||||
query: "How does the database connection work?",
|
||||
expected: []string{"database", "connection"},
|
||||
},
|
||||
{
|
||||
query: "JWT token validation",
|
||||
expected: []string{"token", "validation"},
|
||||
},
|
||||
{
|
||||
query: "the a an is are", // All stop words
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.query, func(t *testing.T) {
|
||||
keywords := extractKeywords(tt.query)
|
||||
for _, exp := range tt.expected {
|
||||
assert.Contains(t, keywords, exp, "should contain keyword: "+exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create project-scoped observation for project-a
|
||||
projectObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project A specific",
|
||||
Narrative: "Only for project-a",
|
||||
Concepts: []string{"local-concept"},
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create global observation from project-a
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global security practice",
|
||||
Narrative: "Best practice for all",
|
||||
Concepts: []string{"security", "best-practice"},
|
||||
}
|
||||
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create observation for project-b
|
||||
projectBObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project B specific",
|
||||
Narrative: "Only for project-b",
|
||||
}
|
||||
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GetObservationsByProjectStrict for project-a should only return project-a observations
|
||||
// This is different from GetRecentObservations which includes globals from other projects
|
||||
results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only observations created in project-a
|
||||
|
||||
// Verify both are from project-a
|
||||
for _, obs := range results {
|
||||
assert.Equal(t, "project-a", obs.Project)
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrict for project-b should only return project-b observations
|
||||
results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "Project B specific", results[0].Title.String)
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create an observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test observation",
|
||||
Narrative: "Some content here",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Search with only stop words (should return nil)
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, results)
|
||||
|
||||
// Search with empty query
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, results)
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
for i := 0; i < 15; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Authentication test " + string(rune('A'+i)),
|
||||
Narrative: "Auth related content",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Search with limit 0 (should default to 10)
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 10)
|
||||
|
||||
// Search with negative limit (should default to 10)
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 10)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetAllRecentObservations(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations across different projects
|
||||
projects := []string{"project-a", "project-b", "project-c"}
|
||||
for _, proj := range projects {
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: proj + " observation " + string(rune('A'+i)),
|
||||
Narrative: "Content for " + proj,
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all recent observations
|
||||
results, err := obsStore.GetAllRecentObservations(ctx, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 9) // 3 projects * 3 observations
|
||||
|
||||
// Verify they are in descending order by epoch
|
||||
for i := 1; i < len(results); i++ {
|
||||
assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch)
|
||||
}
|
||||
|
||||
// Test with limit
|
||||
results, err = obsStore.GetAllRecentObservations(ctx, 5)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 5)
|
||||
}
|
||||
@@ -1,507 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// setupPatternTestStore creates a test store with patterns table.
|
||||
func setupPatternTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
db, _, cleanup := testDB(t)
|
||||
t.Cleanup(cleanup)
|
||||
createBaseTables(t, db)
|
||||
return newStoreFromDB(db)
|
||||
}
|
||||
|
||||
func TestPatternStore_StoreAndGet(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test pattern
|
||||
pattern := &models.Pattern{
|
||||
Name: "Test Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "A test pattern", Valid: true},
|
||||
Signature: []string{"nil", "error"},
|
||||
Recommendation: sql.NullString{String: "Always check for nil", Valid: true},
|
||||
Frequency: 1,
|
||||
Projects: []string{"project1"},
|
||||
ObservationIDs: []int64{1, 2},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||
LastSeenEpoch: time.Now().UnixMilli(),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// Store pattern
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("StorePattern() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Errorf("Expected positive ID, got %d", id)
|
||||
}
|
||||
|
||||
// Get pattern by ID
|
||||
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByID() error = %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != pattern.Name {
|
||||
t.Errorf("Expected name %s, got %s", pattern.Name, retrieved.Name)
|
||||
}
|
||||
if retrieved.Type != pattern.Type {
|
||||
t.Errorf("Expected type %s, got %s", pattern.Type, retrieved.Type)
|
||||
}
|
||||
if len(retrieved.Signature) != len(pattern.Signature) {
|
||||
t.Errorf("Expected %d signature elements, got %d",
|
||||
len(pattern.Signature), len(retrieved.Signature))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetByName(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
pattern := createTestPattern("Unique Name Pattern")
|
||||
_, err := patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("StorePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Get by name
|
||||
retrieved, err := patternStore.GetPatternByName(ctx, "Unique Name Pattern")
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByName() error = %v", err)
|
||||
}
|
||||
if retrieved == nil {
|
||||
t.Fatal("Expected pattern, got nil")
|
||||
}
|
||||
if retrieved.Name != "Unique Name Pattern" {
|
||||
t.Errorf("Expected name 'Unique Name Pattern', got '%s'", retrieved.Name)
|
||||
}
|
||||
|
||||
// Get non-existent pattern
|
||||
nonExistent, err := patternStore.GetPatternByName(ctx, "Non Existent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByName() error = %v", err)
|
||||
}
|
||||
if nonExistent != nil {
|
||||
t.Errorf("Expected nil for non-existent pattern, got %v", nonExistent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetActivePatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple patterns with different statuses
|
||||
active1 := createTestPattern("Active 1")
|
||||
active1.Frequency = 5
|
||||
active2 := createTestPattern("Active 2")
|
||||
active2.Frequency = 3
|
||||
deprecated := createTestPattern("Deprecated")
|
||||
deprecated.Status = models.PatternStatusDeprecated
|
||||
|
||||
patternStore.StorePattern(ctx, active1)
|
||||
patternStore.StorePattern(ctx, active2)
|
||||
patternStore.StorePattern(ctx, deprecated)
|
||||
|
||||
// Get active patterns
|
||||
patterns, err := patternStore.GetActivePatterns(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetActivePatterns() error = %v", err)
|
||||
}
|
||||
|
||||
if len(patterns) != 2 {
|
||||
t.Errorf("Expected 2 active patterns, got %d", len(patterns))
|
||||
}
|
||||
|
||||
// Check order (should be by frequency descending)
|
||||
if len(patterns) >= 2 {
|
||||
if patterns[0].Frequency < patterns[1].Frequency {
|
||||
t.Errorf("Patterns not ordered by frequency descending")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternsByType(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns of different types
|
||||
bugPattern := createTestPattern("Bug Pattern")
|
||||
bugPattern.Type = models.PatternTypeBug
|
||||
|
||||
refactorPattern := createTestPattern("Refactor Pattern")
|
||||
refactorPattern.Type = models.PatternTypeRefactor
|
||||
|
||||
patternStore.StorePattern(ctx, bugPattern)
|
||||
patternStore.StorePattern(ctx, refactorPattern)
|
||||
|
||||
// Get by type
|
||||
bugs, err := patternStore.GetPatternsByType(ctx, models.PatternTypeBug, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||
}
|
||||
if len(bugs) != 1 {
|
||||
t.Errorf("Expected 1 bug pattern, got %d", len(bugs))
|
||||
}
|
||||
|
||||
refactors, err := patternStore.GetPatternsByType(ctx, models.PatternTypeRefactor, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||
}
|
||||
if len(refactors) != 1 {
|
||||
t.Errorf("Expected 1 refactor pattern, got %d", len(refactors))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternsByProject(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with different projects
|
||||
pattern1 := createTestPattern("Pattern 1")
|
||||
pattern1.Projects = []string{"project-a", "project-b"}
|
||||
|
||||
pattern2 := createTestPattern("Pattern 2")
|
||||
pattern2.Projects = []string{"project-b", "project-c"}
|
||||
|
||||
patternStore.StorePattern(ctx, pattern1)
|
||||
patternStore.StorePattern(ctx, pattern2)
|
||||
|
||||
// Get by project
|
||||
projectA, err := patternStore.GetPatternsByProject(ctx, "project-a", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||
}
|
||||
if len(projectA) != 1 {
|
||||
t.Errorf("Expected 1 pattern for project-a, got %d", len(projectA))
|
||||
}
|
||||
|
||||
projectB, err := patternStore.GetPatternsByProject(ctx, "project-b", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||
}
|
||||
if len(projectB) != 2 {
|
||||
t.Errorf("Expected 2 patterns for project-b, got %d", len(projectB))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_UpdatePattern(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("Original Name")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Update pattern
|
||||
pattern.ID = id
|
||||
pattern.Name = "Updated Name"
|
||||
pattern.Frequency = 10
|
||||
pattern.Confidence = 0.9
|
||||
|
||||
err := patternStore.UpdatePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdatePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if updated.Name != "Updated Name" {
|
||||
t.Errorf("Expected name 'Updated Name', got '%s'", updated.Name)
|
||||
}
|
||||
if updated.Frequency != 10 {
|
||||
t.Errorf("Expected frequency 10, got %d", updated.Frequency)
|
||||
}
|
||||
if updated.Confidence != 0.9 {
|
||||
t.Errorf("Expected confidence 0.9, got %f", updated.Confidence)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_DeletePattern(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("To Delete")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Delete pattern
|
||||
err := patternStore.DeletePattern(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("DeletePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
deleted, err := patternStore.GetPatternByID(ctx, id)
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("Expected ErrNoRows, got %v", err)
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Errorf("Expected nil for deleted pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_MarkPatternDeprecated(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("To Deprecate")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Mark as deprecated
|
||||
err := patternStore.MarkPatternDeprecated(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkPatternDeprecated() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify status
|
||||
deprecated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if deprecated.Status != models.PatternStatusDeprecated {
|
||||
t.Errorf("Expected status 'deprecated', got '%s'", deprecated.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_MergePatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create source and target patterns
|
||||
source := createTestPattern("Source Pattern")
|
||||
source.Frequency = 3
|
||||
source.Projects = []string{"proj1", "proj2"}
|
||||
source.ObservationIDs = []int64{1, 2, 3}
|
||||
|
||||
target := createTestPattern("Target Pattern")
|
||||
target.Frequency = 2
|
||||
target.Projects = []string{"proj2", "proj3"}
|
||||
target.ObservationIDs = []int64{4, 5}
|
||||
|
||||
sourceID, _ := patternStore.StorePattern(ctx, source)
|
||||
targetID, _ := patternStore.StorePattern(ctx, target)
|
||||
|
||||
// Merge
|
||||
err := patternStore.MergePatterns(ctx, sourceID, targetID)
|
||||
if err != nil {
|
||||
t.Fatalf("MergePatterns() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify source is marked as merged
|
||||
mergedSource, _ := patternStore.GetPatternByID(ctx, sourceID)
|
||||
if mergedSource.Status != models.PatternStatusMerged {
|
||||
t.Errorf("Expected source status 'merged', got '%s'", mergedSource.Status)
|
||||
}
|
||||
if !mergedSource.MergedIntoID.Valid || mergedSource.MergedIntoID.Int64 != targetID {
|
||||
t.Errorf("Expected source merged_into_id to be %d", targetID)
|
||||
}
|
||||
|
||||
// Verify target has combined data
|
||||
mergedTarget, _ := patternStore.GetPatternByID(ctx, targetID)
|
||||
expectedFrequency := 5 // 3 + 2
|
||||
if mergedTarget.Frequency != expectedFrequency {
|
||||
t.Errorf("Expected merged frequency %d, got %d", expectedFrequency, mergedTarget.Frequency)
|
||||
}
|
||||
// Should have 3 unique projects: proj1, proj2, proj3
|
||||
if len(mergedTarget.Projects) != 3 {
|
||||
t.Errorf("Expected 3 projects after merge, got %d", len(mergedTarget.Projects))
|
||||
}
|
||||
// Should have 5 observation IDs
|
||||
if len(mergedTarget.ObservationIDs) != 5 {
|
||||
t.Errorf("Expected 5 observation IDs after merge, got %d", len(mergedTarget.ObservationIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_FindMatchingPatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with known signatures
|
||||
pattern1 := createTestPattern("Pattern 1")
|
||||
pattern1.Signature = []string{"nil", "error", "handling"}
|
||||
|
||||
pattern2 := createTestPattern("Pattern 2")
|
||||
pattern2.Signature = []string{"nil", "pointer", "check"}
|
||||
|
||||
pattern3 := createTestPattern("Pattern 3")
|
||||
pattern3.Signature = []string{"refactor", "extract", "method"}
|
||||
|
||||
patternStore.StorePattern(ctx, pattern1)
|
||||
patternStore.StorePattern(ctx, pattern2)
|
||||
patternStore.StorePattern(ctx, pattern3)
|
||||
|
||||
// Find patterns matching "nil" related signature
|
||||
matches, err := patternStore.FindMatchingPatterns(ctx, []string{"nil", "error"}, 0.3)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||
}
|
||||
|
||||
if len(matches) < 1 {
|
||||
t.Errorf("Expected at least 1 match, got %d", len(matches))
|
||||
}
|
||||
|
||||
// Verify no match for unrelated signature
|
||||
noMatches, err := patternStore.FindMatchingPatterns(ctx, []string{"completely", "different"}, 0.5)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||
}
|
||||
if len(noMatches) != 0 {
|
||||
t.Errorf("Expected 0 matches for unrelated signature, got %d", len(noMatches))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_IncrementPatternFrequency(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create pattern
|
||||
pattern := createTestPattern("Frequency Test")
|
||||
pattern.Frequency = 1
|
||||
pattern.Projects = []string{"proj1"}
|
||||
pattern.ObservationIDs = []int64{1}
|
||||
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Increment frequency
|
||||
err := patternStore.IncrementPatternFrequency(ctx, id, "proj2", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("IncrementPatternFrequency() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if updated.Frequency != 2 {
|
||||
t.Errorf("Expected frequency 2, got %d", updated.Frequency)
|
||||
}
|
||||
if len(updated.Projects) != 2 {
|
||||
t.Errorf("Expected 2 projects, got %d", len(updated.Projects))
|
||||
}
|
||||
if len(updated.ObservationIDs) != 2 {
|
||||
t.Errorf("Expected 2 observation IDs, got %d", len(updated.ObservationIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternStats(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with different types and statuses
|
||||
bug := createTestPattern("Bug")
|
||||
bug.Type = models.PatternTypeBug
|
||||
bug.Frequency = 5
|
||||
|
||||
refactor := createTestPattern("Refactor")
|
||||
refactor.Type = models.PatternTypeRefactor
|
||||
refactor.Frequency = 3
|
||||
|
||||
deprecated := createTestPattern("Deprecated")
|
||||
deprecated.Type = models.PatternTypeArchitecture
|
||||
deprecated.Status = models.PatternStatusDeprecated
|
||||
|
||||
patternStore.StorePattern(ctx, bug)
|
||||
patternStore.StorePattern(ctx, refactor)
|
||||
patternStore.StorePattern(ctx, deprecated)
|
||||
|
||||
// Get stats
|
||||
stats, err := patternStore.GetPatternStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternStats() error = %v", err)
|
||||
}
|
||||
|
||||
if stats.Total != 3 {
|
||||
t.Errorf("Expected total 3, got %d", stats.Total)
|
||||
}
|
||||
if stats.Active != 2 {
|
||||
t.Errorf("Expected 2 active, got %d", stats.Active)
|
||||
}
|
||||
if stats.Deprecated != 1 {
|
||||
t.Errorf("Expected 1 deprecated, got %d", stats.Deprecated)
|
||||
}
|
||||
if stats.Bugs != 1 {
|
||||
t.Errorf("Expected 1 bug, got %d", stats.Bugs)
|
||||
}
|
||||
if stats.Refactors != 1 {
|
||||
t.Errorf("Expected 1 refactor, got %d", stats.Refactors)
|
||||
}
|
||||
if stats.TotalOccurrences != 9 { // 5 + 3 + 1
|
||||
t.Errorf("Expected 9 total occurrences, got %d", stats.TotalOccurrences)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_CleanupCallback(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
var deletedIDs []int64
|
||||
patternStore.SetCleanupFunc(func(ctx context.Context, ids []int64) {
|
||||
deletedIDs = ids
|
||||
})
|
||||
|
||||
// Create and delete pattern
|
||||
pattern := createTestPattern("Cleanup Test")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
patternStore.DeletePattern(ctx, id)
|
||||
|
||||
if len(deletedIDs) != 1 || deletedIDs[0] != id {
|
||||
t.Errorf("Expected cleanup callback with ID %d, got %v", id, deletedIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a test pattern
|
||||
func createTestPattern(name string) *models.Pattern {
|
||||
now := time.Now()
|
||||
return &models.Pattern{
|
||||
Name: name,
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "Test description", Valid: true},
|
||||
Signature: []string{"test", "pattern"},
|
||||
Recommendation: sql.NullString{String: "Test recommendation", Valid: true},
|
||||
Frequency: 1,
|
||||
Projects: []string{"test-project"},
|
||||
ObservationIDs: []int64{1},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
}
|
||||
@@ -1,271 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// PromptCleanupFunc is a callback for when prompts are cleaned up.
|
||||
// Receives the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
|
||||
type PromptCleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// MaxPromptsGlobal is the hard limit of prompts across all projects.
|
||||
const MaxPromptsGlobal = 500
|
||||
|
||||
// PromptStore provides user prompt-related database operations.
|
||||
type PromptStore struct {
|
||||
store *Store
|
||||
cleanupFunc PromptCleanupFunc
|
||||
}
|
||||
|
||||
// NewPromptStore creates a new prompt store.
|
||||
func NewPromptStore(store *Store) *PromptStore {
|
||||
return &PromptStore{store: store}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when prompts are deleted during cleanup.
|
||||
func (s *PromptStore) SetCleanupFunc(fn PromptCleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// SaveUserPromptWithMatches saves a user prompt with matched observation count.
|
||||
// Uses INSERT OR IGNORE to be idempotent - duplicate (session, prompt_number) pairs are silently ignored.
|
||||
// This prevents duplicate prompts when the user-prompt hook fires multiple times.
|
||||
func (s *PromptStore) SaveUserPromptWithMatches(ctx context.Context, claudeSessionID string, promptNumber int, promptText string, matchedObservations int) (int64, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Use INSERT OR IGNORE for idempotency - if (claude_session_id, prompt_number) already exists,
|
||||
// the insert is silently ignored. This handles concurrent/duplicate hook invocations.
|
||||
const query = `
|
||||
INSERT OR IGNORE INTO user_prompts
|
||||
(claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
claudeSessionID, promptNumber, promptText, matchedObservations,
|
||||
now.Format(time.RFC3339), now.UnixMilli(),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
|
||||
// If id is 0, the insert was ignored (duplicate) - fetch the existing ID
|
||||
if id == 0 {
|
||||
const selectQuery = `SELECT id FROM user_prompts WHERE claude_session_id = ? AND prompt_number = ?`
|
||||
row := s.store.QueryRowContext(ctx, selectQuery, claudeSessionID, promptNumber)
|
||||
if err := row.Scan(&id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Return existing ID without triggering cleanup (already handled when first inserted)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Cleanup old prompts beyond the global limit (async to not block handler)
|
||||
go func() {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldPrompts(cleanupCtx)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}()
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// CleanupOldPrompts deletes prompts beyond the global limit.
|
||||
// Keeps the most recent MaxPromptsGlobal prompts.
|
||||
// Returns the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
|
||||
func (s *PromptStore) CleanupOldPrompts(ctx context.Context) ([]int64, error) {
|
||||
// First, find IDs that will be deleted
|
||||
const selectQuery = `
|
||||
SELECT id FROM user_prompts
|
||||
WHERE id NOT IN (
|
||||
SELECT id FROM user_prompts
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, selectQuery, MaxPromptsGlobal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var toDelete []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(toDelete) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete the prompts
|
||||
const deleteQuery = `
|
||||
DELETE FROM user_prompts
|
||||
WHERE id NOT IN (
|
||||
SELECT id FROM user_prompts
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
_, err = s.store.ExecContext(ctx, deleteQuery, MaxPromptsGlobal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDelete, nil
|
||||
}
|
||||
|
||||
// GetPromptsByIDs retrieves user prompts by a list of IDs.
|
||||
func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
WHERE up.id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||
ORDER BY up.created_at_epoch `
|
||||
|
||||
if orderBy == "date_asc" {
|
||||
query += "ASC"
|
||||
} else {
|
||||
query += "DESC"
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
}
|
||||
|
||||
args := int64SliceToInterface(ids)
|
||||
if limit > 0 {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// GetAllRecentUserPrompts retrieves recent user prompts across all sessions.
|
||||
func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
const query = `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
ORDER BY up.created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||
const query = `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
ORDER BY up.id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
|
||||
// This is used to detect duplicate hook invocations.
|
||||
// Returns (promptID, promptNumber, found).
|
||||
func (s *PromptStore) FindRecentPromptByText(ctx context.Context, claudeSessionID, promptText string, withinSeconds int) (int64, int, bool) {
|
||||
// Look for an existing prompt with the same text within the time window
|
||||
// This catches duplicate hook invocations that happen in quick succession
|
||||
const query = `
|
||||
SELECT id, prompt_number FROM user_prompts
|
||||
WHERE claude_session_id = ? AND prompt_text = ?
|
||||
AND created_at_epoch > ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
cutoff := time.Now().Add(-time.Duration(withinSeconds) * time.Second).UnixMilli()
|
||||
|
||||
var id int64
|
||||
var promptNumber int
|
||||
err := s.store.QueryRowContext(ctx, query, claudeSessionID, promptText, cutoff).Scan(&id, &promptNumber)
|
||||
if err != nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
return id, promptNumber, true
|
||||
}
|
||||
|
||||
// GetRecentUserPromptsByProject retrieves recent user prompts for a specific project.
|
||||
func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
const query = `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
WHERE s.project = ?
|
||||
ORDER BY up.created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
@@ -1,289 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testPromptStore(t *testing.T) (*PromptStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createAllTables(t, db)
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
promptStore := NewPromptStore(store)
|
||||
|
||||
return promptStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestPromptStore_SaveUserPromptWithMatches(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save a prompt
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Help me fix this bug", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify it was saved
|
||||
var count int
|
||||
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE id = ?", id).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetAllRecentUserPrompts(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save multiple prompts
|
||||
for i := 1; i <= 5; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt "+string(rune('A'+i-1)), i)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
// Get recent prompts
|
||||
prompts, err := promptStore.GetAllRecentUserPrompts(ctx, 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 3)
|
||||
|
||||
// Should be in descending order (most recent first)
|
||||
assert.Equal(t, 5, prompts[0].PromptNumber)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetRecentUserPromptsByProject(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions for different projects
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-a")
|
||||
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-b")
|
||||
|
||||
// Save prompts for both projects
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Project A prompt", 0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for i := 1; i <= 2; i++ {
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-2", i, "Project B prompt", 0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get prompts for project-a
|
||||
prompts, err := promptStore.GetRecentUserPromptsByProject(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 3)
|
||||
|
||||
// Get prompts for project-b
|
||||
prompts, err = promptStore.GetRecentUserPromptsByProject(ctx, "project-b", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 2)
|
||||
}
|
||||
|
||||
func TestPromptStore_CleanupOldPrompts(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save more prompts than the limit
|
||||
// Note: MaxPromptsGlobal is 500, but we'll test with a smaller number
|
||||
// by directly calling CleanupOldPrompts
|
||||
for i := 1; i <= 10; i++ {
|
||||
_, err := storeDB(store).Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, datetime('now'), ?)
|
||||
`, "claude-1", i, "Prompt "+string(rune('A'+i-1)), time.Now().UnixMilli()+int64(i))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify we have 10 prompts
|
||||
var count int
|
||||
err := storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 10, count)
|
||||
|
||||
// Cleanup should return empty since we're under the limit
|
||||
deletedIDs, err := promptStore.CleanupOldPrompts(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, deletedIDs)
|
||||
}
|
||||
|
||||
func TestPromptStore_SetCleanupFunc(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Track cleanup calls
|
||||
var cleanupCalledWith []int64
|
||||
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
cleanupCalledWith = deletedIDs
|
||||
})
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save a prompt (should trigger cleanup, but won't delete anything under limit)
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Test prompt", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup func should not have been called since nothing was deleted
|
||||
assert.Empty(t, cleanupCalledWith)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save some prompts and collect their IDs
|
||||
var ids []int64
|
||||
for i := 1; i <= 5; i++ {
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt "+string(rune('A'+i-1)), 0)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Get specific prompts by ID
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, ids[:3], "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 3)
|
||||
|
||||
// Test with ascending order
|
||||
prompts, err = promptStore.GetPromptsByIDs(ctx, ids, "date_asc", 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 2)
|
||||
assert.Equal(t, 1, prompts[0].PromptNumber)
|
||||
}
|
||||
|
||||
func TestPromptStore_GetPromptsByIDs_EmptyInput(t *testing.T) {
|
||||
promptStore, _, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty IDs should return nil
|
||||
prompts, err := promptStore.GetPromptsByIDs(ctx, []int64{}, "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, prompts)
|
||||
}
|
||||
|
||||
func TestPromptStore_FindRecentPromptByText(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save a prompt
|
||||
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Help me fix this bug in the code", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find the prompt by text - returns (id, promptNumber, found)
|
||||
id, promptNum, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Help me fix this bug in the code", 60)
|
||||
assert.True(t, found, "should find the exact prompt text")
|
||||
assert.Greater(t, id, int64(0))
|
||||
assert.Equal(t, 1, promptNum)
|
||||
|
||||
// Try to find non-existent prompt
|
||||
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "This prompt does not exist", 60)
|
||||
assert.False(t, found, "should not find non-existent prompt")
|
||||
|
||||
// Try with different session
|
||||
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-2", "Help me fix this bug in the code", 60)
|
||||
assert.False(t, found, "should not find prompt for different session")
|
||||
}
|
||||
|
||||
func TestPromptStore_FindRecentPromptByText_WindowSeconds(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Save a prompt with an old timestamp
|
||||
oldEpoch := time.Now().Add(-2 * time.Hour).UnixMilli()
|
||||
_, err := storeDB(store).Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, datetime('now'), ?)
|
||||
`, "claude-1", 1, "Old prompt text", oldEpoch)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Search within last hour - should not find old prompt
|
||||
_, _, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3600)
|
||||
assert.False(t, found, "should not find prompt outside window")
|
||||
|
||||
// Search within last 3 hours - should find old prompt
|
||||
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3*3600)
|
||||
assert.True(t, found, "should find prompt within extended window")
|
||||
}
|
||||
|
||||
func TestPromptStore_SaveMultiplePrompts(t *testing.T) {
|
||||
promptStore, store, cleanup := testPromptStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-x")
|
||||
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-y")
|
||||
|
||||
tests := []struct {
|
||||
claudeSessionID string
|
||||
promptNum int
|
||||
text string
|
||||
matches int
|
||||
}{
|
||||
{"claude-1", 1, "First prompt", 5},
|
||||
{"claude-1", 2, "Second prompt", 3},
|
||||
{"claude-2", 1, "Third prompt", 0},
|
||||
{"claude-1", 3, "Fourth prompt", 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
id, err := promptStore.SaveUserPromptWithMatches(ctx, tt.claudeSessionID, tt.promptNum, tt.text, tt.matches)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
|
||||
// Verify counts
|
||||
var count int
|
||||
err := storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-1'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-2'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
@@ -1,377 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// RelationStore provides relation-related database operations.
|
||||
type RelationStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewRelationStore creates a new relation store.
|
||||
func NewRelationStore(store *Store) *RelationStore {
|
||||
return &RelationStore{store: store}
|
||||
}
|
||||
|
||||
// StoreRelation stores a new observation relation.
|
||||
// Uses INSERT OR IGNORE to handle duplicate (source_id, target_id, relation_type) combinations.
|
||||
func (s *RelationStore) StoreRelation(ctx context.Context, relation *models.ObservationRelation) (int64, error) {
|
||||
const query = `
|
||||
INSERT OR IGNORE INTO observation_relations
|
||||
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
relation.SourceID, relation.TargetID,
|
||||
string(relation.RelationType), relation.Confidence,
|
||||
string(relation.DetectionSource), relation.Reason,
|
||||
relation.CreatedAt, relation.CreatedAtEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// StoreRelations stores multiple relations in a single transaction.
|
||||
func (s *RelationStore) StoreRelations(ctx context.Context, relations []*models.ObservationRelation) error {
|
||||
if len(relations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
const query = `
|
||||
INSERT OR IGNORE INTO observation_relations
|
||||
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, rel := range relations {
|
||||
_, err = stmt.ExecContext(ctx,
|
||||
rel.SourceID, rel.TargetID,
|
||||
string(rel.RelationType), rel.Confidence,
|
||||
string(rel.DetectionSource), rel.Reason,
|
||||
rel.CreatedAt, rel.CreatedAtEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetRelationsByObservationID retrieves all relations involving an observation (as source or target).
|
||||
func (s *RelationStore) GetRelationsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE source_id = ? OR target_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetOutgoingRelations retrieves relations where the observation is the source.
|
||||
func (s *RelationStore) GetOutgoingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE source_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetIncomingRelations retrieves relations where the observation is the target.
|
||||
func (s *RelationStore) GetIncomingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE target_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetRelationsByType retrieves all relations of a specific type.
|
||||
func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType models.RelationType, limit int) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE relation_type = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, string(relationType), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetRelationsWithDetails retrieves relations with observation titles for display.
|
||||
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
|
||||
const query = `
|
||||
SELECT r.id, r.source_id, r.target_id, r.relation_type, r.confidence, r.detection_source, r.reason,
|
||||
r.created_at, r.created_at_epoch,
|
||||
COALESCE(src.title, '') as source_title,
|
||||
COALESCE(tgt.title, '') as target_title,
|
||||
src.type as source_type,
|
||||
tgt.type as target_type
|
||||
FROM observation_relations r
|
||||
JOIN observations src ON src.id = r.source_id
|
||||
JOIN observations tgt ON tgt.id = r.target_id
|
||||
WHERE r.source_id = ? OR r.target_id = ?
|
||||
ORDER BY r.confidence DESC, r.created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*models.RelationWithDetails
|
||||
for rows.Next() {
|
||||
var r models.ObservationRelation
|
||||
var rwd models.RelationWithDetails
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.SourceID, &r.TargetID,
|
||||
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||
&rwd.SourceTitle, &rwd.TargetTitle,
|
||||
&rwd.SourceType, &rwd.TargetType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason.Valid {
|
||||
r.Reason = reason.String
|
||||
}
|
||||
rwd.Relation = &r
|
||||
results = append(results, &rwd)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// GetRelationGraph retrieves a relation graph centered on an observation.
|
||||
// This returns all observations within N hops from the center.
|
||||
func (s *RelationStore) GetRelationGraph(ctx context.Context, centerID int64, maxDepth int) (*models.RelationGraph, error) {
|
||||
// Get all relations involving the center observation
|
||||
relations, err := s.GetRelationsWithDetails(ctx, centerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph := &models.RelationGraph{
|
||||
CenterID: centerID,
|
||||
Relations: relations,
|
||||
}
|
||||
|
||||
// If depth > 1, recursively get relations for connected observations
|
||||
if maxDepth > 1 {
|
||||
visited := map[int64]bool{centerID: true}
|
||||
toVisit := make([]int64, 0)
|
||||
|
||||
// Collect IDs of directly connected observations
|
||||
for _, r := range relations {
|
||||
if !visited[r.Relation.SourceID] {
|
||||
toVisit = append(toVisit, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
toVisit = append(toVisit, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Get relations for connected observations (depth - 1)
|
||||
for depth := 1; depth < maxDepth && len(toVisit) > 0; depth++ {
|
||||
nextLevel := make([]int64, 0)
|
||||
for _, obsID := range toVisit {
|
||||
moreRelations, err := s.GetRelationsWithDetails(ctx, obsID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, r := range moreRelations {
|
||||
// Avoid duplicates
|
||||
exists := false
|
||||
for _, existing := range graph.Relations {
|
||||
if existing.Relation.ID == r.Relation.ID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
graph.Relations = append(graph.Relations, r)
|
||||
}
|
||||
|
||||
// Queue next level
|
||||
if !visited[r.Relation.SourceID] {
|
||||
nextLevel = append(nextLevel, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
nextLevel = append(nextLevel, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
toVisit = nextLevel
|
||||
}
|
||||
}
|
||||
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
// DeleteRelationsByObservationID deletes all relations involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *RelationStore) DeleteRelationsByObservationID(ctx context.Context, obsID int64) error {
|
||||
const query = `DELETE FROM observation_relations WHERE source_id = ? OR target_id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRelationCount returns the count of relations for an observation.
|
||||
func (s *RelationStore) GetRelationCount(ctx context.Context, obsID int64) (int, error) {
|
||||
const query = `
|
||||
SELECT COUNT(*) FROM observation_relations
|
||||
WHERE source_id = ? OR target_id = ?
|
||||
`
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query, obsID, obsID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetTotalRelationCount returns the total count of all relations.
|
||||
func (s *RelationStore) GetTotalRelationCount(ctx context.Context) (int, error) {
|
||||
const query = `SELECT COUNT(*) FROM observation_relations`
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetHighConfidenceRelations retrieves relations with confidence above threshold.
|
||||
func (s *RelationStore) GetHighConfidenceRelations(ctx context.Context, minConfidence float64, limit int) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE confidence >= ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, minConfidence, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// UpdateRelationConfidence updates the confidence of a relation.
|
||||
func (s *RelationStore) UpdateRelationConfidence(ctx context.Context, relationID int64, newConfidence float64) error {
|
||||
const query = `UPDATE observation_relations SET confidence = ? WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, newConfidence, relationID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRelatedObservationIDs returns IDs of observations related to the given one.
|
||||
// This is useful for expanding search results.
|
||||
func (s *RelationStore) GetRelatedObservationIDs(ctx context.Context, obsID int64, minConfidence float64) ([]int64, error) {
|
||||
const query = `
|
||||
SELECT DISTINCT CASE WHEN source_id = ? THEN target_id ELSE source_id END as related_id
|
||||
FROM observation_relations
|
||||
WHERE (source_id = ? OR target_id = ?) AND confidence >= ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID, obsID, minConfidence)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
|
||||
// scanRelationRows scans multiple relations from rows.
|
||||
func (s *RelationStore) scanRelationRows(rows *sql.Rows) ([]*models.ObservationRelation, error) {
|
||||
var relations []*models.ObservationRelation
|
||||
for rows.Next() {
|
||||
var r models.ObservationRelation
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.SourceID, &r.TargetID,
|
||||
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason.Valid {
|
||||
r.Reason = reason.String
|
||||
}
|
||||
relations = append(relations, &r)
|
||||
}
|
||||
return relations, rows.Err()
|
||||
}
|
||||
@@ -1,324 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// UpdateObservationFeedback updates the user feedback for an observation.
|
||||
// Feedback values: -1 (thumbs down), 0 (neutral), 1 (thumbs up).
|
||||
func (s *ObservationStore) UpdateObservationFeedback(ctx context.Context, id int64, feedback int) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET user_feedback = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, feedback, time.Now().UnixMilli(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// IncrementRetrievalCount increments the retrieval counter for the given observation IDs.
|
||||
// This is called when observations are returned in search results.
|
||||
func (s *ObservationStore) IncrementRetrievalCount(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := `
|
||||
UPDATE observations
|
||||
SET retrieval_count = COALESCE(retrieval_count, 0) + 1,
|
||||
last_retrieved_at_epoch = ?
|
||||
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||
`
|
||||
|
||||
args := make([]interface{}, 0, len(ids)+1)
|
||||
args = append(args, now)
|
||||
for _, id := range ids {
|
||||
args = append(args, id)
|
||||
}
|
||||
|
||||
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateImportanceScore updates the importance score for a single observation.
|
||||
func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, score float64) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET importance_score = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, score, time.Now().UnixMilli(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||
// This is more efficient than individual updates for batch recalculation.
|
||||
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
if len(scores) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
UPDATE observations
|
||||
SET importance_score = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for id, score := range scores {
|
||||
if _, err := stmt.ExecContext(ctx, score, now, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
cutoff := time.Now().Add(-threshold).UnixMilli()
|
||||
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE score_updated_at_epoch IS NULL OR score_updated_at_epoch < ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, cutoff, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetConceptWeights returns all concept weights from the database.
|
||||
func (s *ObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||
const query = `SELECT concept, weight FROM concept_weights`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
// Table might not exist in older databases
|
||||
if err == sql.ErrNoRows {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
weights := make(map[string]float64)
|
||||
for rows.Next() {
|
||||
var concept string
|
||||
var weight float64
|
||||
if err := rows.Scan(&concept, &weight); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
weights[concept] = weight
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no weights found, use defaults
|
||||
if len(weights) == 0 {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
|
||||
return weights, nil
|
||||
}
|
||||
|
||||
// UpdateConceptWeight updates a single concept weight.
|
||||
func (s *ObservationStore) UpdateConceptWeight(ctx context.Context, concept string, weight float64) error {
|
||||
const query = `
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, concept, weight)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateConceptWeights bulk updates multiple concept weights.
|
||||
func (s *ObservationStore) UpdateConceptWeights(ctx context.Context, weights map[string]float64) error {
|
||||
if len(weights) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for concept, weight := range weights {
|
||||
if _, err := stmt.ExecContext(ctx, concept, weight); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetObservationFeedbackStats returns statistics about user feedback.
|
||||
func (s *ObservationStore) GetObservationFeedbackStats(ctx context.Context, project string) (*FeedbackStats, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||
FROM observations
|
||||
`
|
||||
} else {
|
||||
query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||
FROM observations
|
||||
WHERE project = ? OR scope = 'global'
|
||||
`
|
||||
args = append(args, project)
|
||||
}
|
||||
|
||||
var stats FeedbackStats
|
||||
err := s.store.QueryRowContext(ctx, query, args...).Scan(
|
||||
&stats.Total,
|
||||
&stats.Positive,
|
||||
&stats.Negative,
|
||||
&stats.Neutral,
|
||||
&stats.AvgScore,
|
||||
&stats.AvgRetrieval,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// FeedbackStats contains statistics about observation feedback and scoring.
|
||||
type FeedbackStats struct {
|
||||
Total int `json:"total"`
|
||||
Positive int `json:"positive"`
|
||||
Negative int `json:"negative"`
|
||||
Neutral int `json:"neutral"`
|
||||
AvgScore float64 `json:"avg_score"`
|
||||
AvgRetrieval float64 `json:"avg_retrieval"`
|
||||
}
|
||||
|
||||
// GetTopScoringObservations returns the highest-scoring observations.
|
||||
func (s *ObservationStore) GetTopScoringObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
} else {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE project = ? OR scope = 'global'
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetMostRetrievedObservations returns the most frequently retrieved observations.
|
||||
func (s *ObservationStore) GetMostRetrievedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE retrieval_count > 0
|
||||
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
} else {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE (project = ? OR scope = 'global') AND retrieval_count > 0
|
||||
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// ResetObservationScores resets all observation scores to their default values.
|
||||
// This is useful for testing or when changing the scoring algorithm.
|
||||
func (s *ObservationStore) ResetObservationScores(ctx context.Context) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET importance_score = 1.0, score_updated_at_epoch = NULL
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
@@ -1,698 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// testScoringObservationStore creates an ObservationStore with scoring columns for testing.
|
||||
func testScoringObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createBaseTables(t, db)
|
||||
createConceptWeightsTable(t, db)
|
||||
|
||||
// Add importance index if not exists (columns already in createBaseTables)
|
||||
if _, err := db.Exec(`CREATE INDEX IF NOT EXISTS idx_observations_importance ON observations(importance_score DESC, created_at_epoch DESC)`); err != nil {
|
||||
t.Fatalf("create importance index: %v", err)
|
||||
}
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
obsStore := NewObservationStore(store)
|
||||
|
||||
return obsStore, store, cleanup
|
||||
}
|
||||
|
||||
// createConceptWeightsTable creates the concept_weights table for testing.
|
||||
func createConceptWeightsTable(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||
concept TEXT PRIMARY KEY,
|
||||
weight REAL NOT NULL DEFAULT 0.1,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create concept_weights: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ScoringStoreSuite is a test suite for scoring-related database operations.
|
||||
type ScoringStoreSuite struct {
|
||||
suite.Suite
|
||||
obsStore *ObservationStore
|
||||
store *Store
|
||||
cleanup func()
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) SetupTest() {
|
||||
s.obsStore, s.store, s.cleanup = testScoringObservationStore(s.T())
|
||||
s.ctx = context.Background()
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoringStoreSuite(t *testing.T) {
|
||||
suite.Run(t, new(ScoringStoreSuite))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FEEDBACK TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Positive() {
|
||||
// Create observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Test feedback",
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Update feedback to positive
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.UserFeedback)
|
||||
s.True(retrieved.ScoreUpdatedAt.Valid)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Negative() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(-1, retrieved.UserFeedback)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Neutral() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// First set to positive
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
s.NoError(err)
|
||||
|
||||
// Then reset to neutral
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 0)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(0, retrieved.UserFeedback)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_NonExistent() {
|
||||
// Updating non-existent observation should not fail (just no rows affected)
|
||||
err := s.obsStore.UpdateObservationFeedback(s.ctx, 99999, 1)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RETRIEVAL COUNT TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Single() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.RetrievalCount)
|
||||
s.True(retrieved.LastRetrievedAt.Valid)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Multiple() {
|
||||
var ids []int64
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
err := s.obsStore.IncrementRetrievalCount(s.ctx, ids)
|
||||
s.NoError(err)
|
||||
|
||||
for _, id := range ids {
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.RetrievalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Cumulative() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Increment multiple times
|
||||
for i := 0; i < 5; i++ {
|
||||
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(5, retrieved.RetrievalCount)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Empty() {
|
||||
err := s.obsStore.IncrementRetrievalCount(s.ctx, []int64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// IMPORTANCE SCORE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScore_Single() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(1.5, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Batch() {
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
scores := map[int64]float64{
|
||||
ids[0]: 1.5,
|
||||
ids[1]: 0.8,
|
||||
ids[2]: 1.2,
|
||||
ids[3]: 0.5,
|
||||
ids[4]: 2.0,
|
||||
}
|
||||
|
||||
err := s.obsStore.UpdateImportanceScores(s.ctx, scores)
|
||||
s.NoError(err)
|
||||
|
||||
for id, expectedScore := range scores {
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(expectedScore, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Empty() {
|
||||
err := s.obsStore.UpdateImportanceScores(s.ctx, map[int64]float64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// OBSERVATIONS NEEDING SCORE UPDATE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_NeverUpdated() {
|
||||
// Observations without score_updated_at_epoch should need update
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||
s.NoError(err)
|
||||
s.Len(observations, 1)
|
||||
s.Equal(id, observations[0].ID)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_RecentlyUpdated() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Update score (this sets score_updated_at_epoch)
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||
s.NoError(err)
|
||||
|
||||
// Should not need update (just updated)
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||
s.NoError(err)
|
||||
s.Empty(observations)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_Limit() {
|
||||
// Create 10 observations
|
||||
for i := 0; i < 10; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Request only 5
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 5)
|
||||
s.NoError(err)
|
||||
s.Len(observations, 5)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONCEPT WEIGHTS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetConceptWeights_Empty() {
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(models.DefaultConceptWeights, weights)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeight_NewConcept() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "new-concept", 0.42)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.42, weights["new-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeight_UpdateExisting() {
|
||||
// Insert first
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.1)
|
||||
s.NoError(err)
|
||||
|
||||
// Update
|
||||
err = s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.9)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.9, weights["test-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Batch() {
|
||||
weightsToSet := map[string]float64{
|
||||
"security": 0.5,
|
||||
"performance": 0.3,
|
||||
"testing": 0.2,
|
||||
}
|
||||
|
||||
err := s.obsStore.UpdateConceptWeights(s.ctx, weightsToSet)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
|
||||
for concept, expected := range weightsToSet {
|
||||
s.Equal(expected, retrieved[concept])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Empty() {
|
||||
err := s.obsStore.UpdateConceptWeights(s.ctx, map[string]float64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FEEDBACK STATS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_Empty() {
|
||||
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||
s.NoError(err)
|
||||
s.Equal(0, stats.Total)
|
||||
s.Equal(0, stats.Positive)
|
||||
s.Equal(0, stats.Negative)
|
||||
s.Equal(0, stats.Neutral)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_WithData() {
|
||||
// Create observations with different feedback
|
||||
feedbacks := []int{1, 1, 1, -1, -1, 0, 0, 0, 0, 0}
|
||||
for i, fb := range feedbacks {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
if fb != 0 {
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, fb)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||
s.NoError(err)
|
||||
s.Equal(10, stats.Total)
|
||||
s.Equal(3, stats.Positive)
|
||||
s.Equal(2, stats.Negative)
|
||||
s.Equal(5, stats.Neutral)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_ByProject() {
|
||||
// Project A observations
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
}
|
||||
|
||||
// Project B observations
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeFeature,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
|
||||
}
|
||||
|
||||
// Check project A stats
|
||||
statsA, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-a")
|
||||
s.NoError(err)
|
||||
s.Equal(5, statsA.Total)
|
||||
s.Equal(5, statsA.Positive)
|
||||
|
||||
// Check project B stats
|
||||
statsB, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-b")
|
||||
s.NoError(err)
|
||||
s.Equal(3, statsB.Total)
|
||||
s.Equal(3, statsB.Negative)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TOP SCORING OBSERVATIONS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetTopScoringObservations() {
|
||||
// Create observations with different scores
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
// Set different scores
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1)*0.5)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Get top 3
|
||||
top, err := s.obsStore.GetTopScoringObservations(s.ctx, "", 3)
|
||||
s.NoError(err)
|
||||
s.Len(top, 3)
|
||||
|
||||
// Verify ordered by score descending
|
||||
s.GreaterOrEqual(top[0].ImportanceScore, top[1].ImportanceScore)
|
||||
s.GreaterOrEqual(top[1].ImportanceScore, top[2].ImportanceScore)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetTopScoringObservations_ByProject() {
|
||||
// Project A with high scores
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 2.0)
|
||||
}
|
||||
|
||||
// Project B with low scores
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.5)
|
||||
}
|
||||
|
||||
// Get top for project A
|
||||
topA, err := s.obsStore.GetTopScoringObservations(s.ctx, "project-a", 10)
|
||||
s.NoError(err)
|
||||
s.Len(topA, 3)
|
||||
for _, obs := range topA {
|
||||
s.Equal("project-a", obs.Project)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MOST RETRIEVED OBSERVATIONS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations() {
|
||||
// Create observations with different retrieval counts
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Set different retrieval counts
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[0]}) // 10 retrievals
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[1]}) // 5 retrievals
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[2]}) // 3 retrievals
|
||||
}
|
||||
// ids[3] and ids[4] have 0 retrievals
|
||||
|
||||
// Get top 3
|
||||
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 3)
|
||||
s.NoError(err)
|
||||
s.Len(most, 3)
|
||||
|
||||
// Verify ordered by retrieval count descending
|
||||
s.Equal(10, most[0].RetrievalCount)
|
||||
s.Equal(5, most[1].RetrievalCount)
|
||||
s.Equal(3, most[2].RetrievalCount)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations_NoRetrievals() {
|
||||
// Create observations without any retrievals
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 10)
|
||||
s.NoError(err)
|
||||
s.Empty(most) // No observations with retrieval_count > 0
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RESET OBSERVATION SCORES TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestResetObservationScores() {
|
||||
// Create observations with various scores
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1))
|
||||
}
|
||||
|
||||
// Reset all scores
|
||||
err := s.obsStore.ResetObservationScores(s.ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify all scores are reset to 1.0
|
||||
observations, err := s.obsStore.GetAllRecentObservations(s.ctx, 100)
|
||||
s.NoError(err)
|
||||
for _, obs := range observations {
|
||||
s.InDelta(1.0, obs.ImportanceScore, 0.001)
|
||||
s.False(obs.ScoreUpdatedAt.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EDGE CASES
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_ZeroScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set score to 0
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.0)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(0.0, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_NegativeScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set negative score (calculator shouldn't produce this, but test DB handling)
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, -0.5)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(-0.5, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_LargeScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set very large score
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 999.999)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(999.999, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestConceptWeight_ZeroWeight() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "zero-concept", 0.0)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.0, weights["zero-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestConceptWeight_ExactBoundary() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "max-concept", 1.0)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(1.0, weights["max-concept"])
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STANDALONE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestFeedbackStats_Structure(t *testing.T) {
|
||||
stats := FeedbackStats{
|
||||
Total: 100,
|
||||
Positive: 30,
|
||||
Negative: 10,
|
||||
Neutral: 60,
|
||||
AvgScore: 1.5,
|
||||
AvgRetrieval: 5.0,
|
||||
}
|
||||
|
||||
assert.Equal(t, 100, stats.Total)
|
||||
assert.Equal(t, 30, stats.Positive)
|
||||
assert.Equal(t, 10, stats.Negative)
|
||||
assert.Equal(t, 60, stats.Neutral)
|
||||
assert.Equal(t, 1.5, stats.AvgScore)
|
||||
assert.Equal(t, 5.0, stats.AvgRetrieval)
|
||||
}
|
||||
|
||||
func TestScoringStore_Integration(t *testing.T) {
|
||||
obsStore, _, cleanup := testScoringObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Full integration test: store, feedback, retrieval, score update
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Integration test observation",
|
||||
Concepts: []string{"security"},
|
||||
}
|
||||
id, _, err := obsStore.StoreObservation(ctx, "session-int", "project-int", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add feedback
|
||||
err = obsStore.UpdateObservationFeedback(ctx, id, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Increment retrieval
|
||||
err = obsStore.IncrementRetrievalCount(ctx, []int64{id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update score
|
||||
err = obsStore.UpdateImportanceScore(ctx, id, 1.75)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify final state
|
||||
retrieved, err := obsStore.GetObservationByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, retrieved.UserFeedback)
|
||||
assert.Equal(t, 1, retrieved.RetrievalCount)
|
||||
assert.InDelta(t, 1.75, retrieved.ImportanceScore, 0.001)
|
||||
assert.True(t, retrieved.ScoreUpdatedAt.Valid)
|
||||
assert.True(t, retrieved.LastRetrievedAt.Valid)
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SessionStore provides session-related database operations.
|
||||
type SessionStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store.
|
||||
func NewSessionStore(store *Store) *SessionStore {
|
||||
return &SessionStore{store: store}
|
||||
}
|
||||
|
||||
// CreateSDKSession creates a new SDK session (idempotent - returns existing ID if exists).
|
||||
// This is the KEY to how claude-mnemonic stays unified across hooks.
|
||||
func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, project, userPrompt string) (int64, error) {
|
||||
now := time.Now()
|
||||
|
||||
// CRITICAL: INSERT OR IGNORE makes this idempotent
|
||||
const query = `
|
||||
INSERT OR IGNORE INTO sdk_sessions
|
||||
(claude_session_id, sdk_session_id, project, user_prompt, started_at, started_at_epoch, status)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'active')
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
claudeSessionID, claudeSessionID, project, userPrompt,
|
||||
now.Format(time.RFC3339), now.UnixMilli(),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Check if insert happened
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
// Session exists - UPDATE project and user_prompt if we have non-empty values
|
||||
if project != "" {
|
||||
const updateQuery = `
|
||||
UPDATE sdk_sessions
|
||||
SET project = ?, user_prompt = ?
|
||||
WHERE claude_session_id = ?
|
||||
`
|
||||
_, _ = s.store.ExecContext(ctx, updateQuery, project, userPrompt, claudeSessionID)
|
||||
}
|
||||
|
||||
// Fetch existing ID
|
||||
var id int64
|
||||
const selectQuery = `SELECT id FROM sdk_sessions WHERE claude_session_id = ? LIMIT 1`
|
||||
err := s.store.QueryRowContext(ctx, selectQuery, claudeSessionID).Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetSessionByID retrieves a session by its database ID.
|
||||
func (s *SessionStore) GetSessionByID(ctx context.Context, id int64) (*models.SDKSession, error) {
|
||||
const query = `
|
||||
SELECT id, claude_session_id, sdk_session_id, project, user_prompt,
|
||||
worker_port, prompt_counter, status, started_at, started_at_epoch,
|
||||
completed_at, completed_at_epoch
|
||||
FROM sdk_sessions
|
||||
WHERE id = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var sess models.SDKSession
|
||||
err := s.store.QueryRowContext(ctx, query, id).Scan(
|
||||
&sess.ID, &sess.ClaudeSessionID, &sess.SDKSessionID, &sess.Project, &sess.UserPrompt,
|
||||
&sess.WorkerPort, &sess.PromptCounter, &sess.Status, &sess.StartedAt, &sess.StartedAtEpoch,
|
||||
&sess.CompletedAt, &sess.CompletedAtEpoch,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
// FindAnySDKSession finds any session by Claude session ID (any status).
|
||||
func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID string) (*models.SDKSession, error) {
|
||||
const query = `
|
||||
SELECT id, claude_session_id, sdk_session_id, project, user_prompt,
|
||||
worker_port, prompt_counter, status, started_at, started_at_epoch,
|
||||
completed_at, completed_at_epoch
|
||||
FROM sdk_sessions
|
||||
WHERE claude_session_id = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var sess models.SDKSession
|
||||
err := s.store.QueryRowContext(ctx, query, claudeSessionID).Scan(
|
||||
&sess.ID, &sess.ClaudeSessionID, &sess.SDKSessionID, &sess.Project, &sess.UserPrompt,
|
||||
&sess.WorkerPort, &sess.PromptCounter, &sess.Status, &sess.StartedAt, &sess.StartedAtEpoch,
|
||||
&sess.CompletedAt, &sess.CompletedAtEpoch,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
// IncrementPromptCounter increments the prompt counter and returns the new value.
|
||||
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
const updateQuery = `
|
||||
UPDATE sdk_sessions
|
||||
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
|
||||
WHERE id = ?
|
||||
`
|
||||
if _, err := s.store.ExecContext(ctx, updateQuery, id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
const selectQuery = `SELECT prompt_counter FROM sdk_sessions WHERE id = ?`
|
||||
var counter int
|
||||
err := s.store.QueryRowContext(ctx, selectQuery, id).Scan(&counter)
|
||||
return counter, err
|
||||
}
|
||||
|
||||
// GetPromptCounter returns the current prompt counter for a session.
|
||||
func (s *SessionStore) GetPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
const query = `SELECT COALESCE(prompt_counter, 0) FROM sdk_sessions WHERE id = ?`
|
||||
var counter int
|
||||
err := s.store.QueryRowContext(ctx, query, id).Scan(&counter)
|
||||
return counter, err
|
||||
}
|
||||
|
||||
// GetSessionsToday returns the count of sessions started today.
|
||||
func (s *SessionStore) GetSessionsToday(ctx context.Context) (int, error) {
|
||||
// Get start of today in milliseconds
|
||||
now := time.Now()
|
||||
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
startEpoch := startOfDay.UnixMilli()
|
||||
|
||||
const query = `SELECT COUNT(*) FROM sdk_sessions WHERE started_at_epoch >= ?`
|
||||
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query, startEpoch).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetAllProjects returns all unique project names.
|
||||
func (s *SessionStore) GetAllProjects(ctx context.Context) ([]string, error) {
|
||||
const query = `
|
||||
SELECT DISTINCT project
|
||||
FROM sdk_sessions
|
||||
WHERE project IS NOT NULL AND project != ''
|
||||
ORDER BY project ASC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var projects []string
|
||||
for rows.Next() {
|
||||
var project string
|
||||
if err := rows.Scan(&project); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
projects = append(projects, project)
|
||||
}
|
||||
return projects, rows.Err()
|
||||
}
|
||||
@@ -1,449 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createBaseTables(t, db) // Use base tables without FTS5 for session tests
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
sessionStore := NewSessionStore(store)
|
||||
|
||||
return sessionStore, store, cleanup
|
||||
}
|
||||
|
||||
// SessionStoreSuite is a test suite for SessionStore operations.
|
||||
type SessionStoreSuite struct {
|
||||
suite.Suite
|
||||
sessionStore *SessionStore
|
||||
store *Store
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func (s *SessionStoreSuite) SetupTest() {
|
||||
s.sessionStore, s.store, s.cleanup = testSessionStore(s.T())
|
||||
}
|
||||
|
||||
func (s *SessionStoreSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreSuite(t *testing.T) {
|
||||
suite.Run(t, new(SessionStoreSuite))
|
||||
}
|
||||
|
||||
// TestCreateSDKSession_TableDriven tests session creation with various scenarios.
|
||||
func (s *SessionStoreSuite) TestCreateSDKSession_TableDriven() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claudeSessionID string
|
||||
project string
|
||||
userPrompt string
|
||||
wantErr bool
|
||||
wantID bool
|
||||
}{
|
||||
{
|
||||
name: "basic session creation",
|
||||
claudeSessionID: "claude-basic",
|
||||
project: "project-a",
|
||||
userPrompt: "hello world",
|
||||
wantErr: false,
|
||||
wantID: true,
|
||||
},
|
||||
{
|
||||
name: "empty user prompt",
|
||||
claudeSessionID: "claude-noprompt",
|
||||
project: "project-b",
|
||||
userPrompt: "",
|
||||
wantErr: false,
|
||||
wantID: true,
|
||||
},
|
||||
{
|
||||
name: "long project name",
|
||||
claudeSessionID: "claude-longproj",
|
||||
project: "/Users/test/Documents/very/long/path/to/some/project/directory",
|
||||
userPrompt: "test",
|
||||
wantErr: false,
|
||||
wantID: true,
|
||||
},
|
||||
{
|
||||
name: "unicode project name",
|
||||
claudeSessionID: "claude-unicode",
|
||||
project: "项目名称-プロジェクト",
|
||||
userPrompt: "测试 テスト",
|
||||
wantErr: false,
|
||||
wantID: true,
|
||||
},
|
||||
{
|
||||
name: "special characters in prompt",
|
||||
claudeSessionID: "claude-special",
|
||||
project: "project-special",
|
||||
userPrompt: "Fix the bug in file.go:123 with \"quotes\" and 'apostrophes'",
|
||||
wantErr: false,
|
||||
wantID: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
id, err := s.sessionStore.CreateSDKSession(ctx, tt.claudeSessionID, tt.project, tt.userPrompt)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
if tt.wantID {
|
||||
s.Greater(id, int64(0))
|
||||
}
|
||||
|
||||
// Verify created session
|
||||
sess, err := s.sessionStore.GetSessionByID(ctx, id)
|
||||
s.NoError(err)
|
||||
s.NotNil(sess)
|
||||
s.Equal(tt.claudeSessionID, sess.ClaudeSessionID)
|
||||
s.Equal(tt.project, sess.Project)
|
||||
s.Equal(models.SessionStatusActive, sess.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIdempotentSession tests that session creation is idempotent.
|
||||
func (s *SessionStoreSuite) TestIdempotentSession() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create initial session
|
||||
id1, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-1", "prompt-1")
|
||||
s.NoError(err)
|
||||
s.Greater(id1, int64(0))
|
||||
|
||||
// Create with same claude_session_id - should return same ID
|
||||
id2, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-2", "prompt-2")
|
||||
s.NoError(err)
|
||||
s.Equal(id1, id2)
|
||||
|
||||
// Verify project was updated
|
||||
sess, err := s.sessionStore.GetSessionByID(ctx, id1)
|
||||
s.NoError(err)
|
||||
s.Equal("project-2", sess.Project)
|
||||
}
|
||||
|
||||
// TestPromptCounterOperations tests prompt counter increment and retrieval.
|
||||
func (s *SessionStoreSuite) TestPromptCounterOperations() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
increments int
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "no increments",
|
||||
increments: 0,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "single increment",
|
||||
increments: 1,
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple increments",
|
||||
increments: 5,
|
||||
expectedCount: 5,
|
||||
},
|
||||
{
|
||||
name: "many increments",
|
||||
increments: 100,
|
||||
expectedCount: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// Create fresh session for each test
|
||||
id, err := s.sessionStore.CreateSDKSession(ctx, "claude-counter-"+tt.name, "project", "")
|
||||
s.NoError(err)
|
||||
|
||||
// Increment specified number of times
|
||||
var lastCount int
|
||||
for i := 0; i < tt.increments; i++ {
|
||||
lastCount, err = s.sessionStore.IncrementPromptCounter(ctx, id)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Get final count
|
||||
finalCount, err := s.sessionStore.GetPromptCounter(ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.expectedCount, finalCount)
|
||||
|
||||
if tt.increments > 0 {
|
||||
s.Equal(tt.expectedCount, lastCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindAnySDKSession tests session lookup scenarios.
|
||||
func (s *SessionStoreSuite) TestFindAnySDKSession_Scenarios() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test sessions
|
||||
_, err := s.sessionStore.CreateSDKSession(ctx, "session-find-1", "project-a", "")
|
||||
s.NoError(err)
|
||||
_, err = s.sessionStore.CreateSDKSession(ctx, "session-find-2", "project-b", "")
|
||||
s.NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claudeSessionID string
|
||||
wantFound bool
|
||||
wantProject string
|
||||
}{
|
||||
{
|
||||
name: "find existing session 1",
|
||||
claudeSessionID: "session-find-1",
|
||||
wantFound: true,
|
||||
wantProject: "project-a",
|
||||
},
|
||||
{
|
||||
name: "find existing session 2",
|
||||
claudeSessionID: "session-find-2",
|
||||
wantFound: true,
|
||||
wantProject: "project-b",
|
||||
},
|
||||
{
|
||||
name: "find non-existent session",
|
||||
claudeSessionID: "session-nonexistent",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "find with empty string",
|
||||
claudeSessionID: "",
|
||||
wantFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
sess, err := s.sessionStore.FindAnySDKSession(ctx, tt.claudeSessionID)
|
||||
s.NoError(err) // FindAnySDKSession returns nil,nil for not found
|
||||
|
||||
if tt.wantFound {
|
||||
s.NotNil(sess)
|
||||
s.Equal(tt.wantProject, sess.Project)
|
||||
} else {
|
||||
s.Nil(sess)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_CreateSDKSession(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new session
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "initial prompt")
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Retrieve and verify
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
|
||||
assert.Equal(t, "test-project", sess.Project)
|
||||
assert.Equal(t, models.SessionStatusActive, sess.Status)
|
||||
}
|
||||
|
||||
func TestSessionStore_CreateSDKSession_Idempotent(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create first session
|
||||
id1, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "prompt 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create again with same claude_session_id but different project
|
||||
id2, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-b", "prompt 2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return same ID (idempotent)
|
||||
assert.Equal(t, id1, id2)
|
||||
|
||||
// Should have updated project to project-b
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "project-b", sess.Project)
|
||||
}
|
||||
|
||||
func TestSessionStore_FindAnySDKSession(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find it
|
||||
sess, err := sessionStore.FindAnySDKSession(ctx, "claude-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
|
||||
|
||||
// Try to find non-existent
|
||||
sess, err = sessionStore.FindAnySDKSession(ctx, "claude-nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, sess)
|
||||
}
|
||||
|
||||
func TestSessionStore_IncrementPromptCounter(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial counter should be 0
|
||||
counter, err := sessionStore.GetPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, counter)
|
||||
|
||||
// Increment
|
||||
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, counter)
|
||||
|
||||
// Increment again
|
||||
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, counter)
|
||||
|
||||
// Verify via GetPromptCounter
|
||||
counter, err = sessionStore.GetPromptCounter(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, counter)
|
||||
}
|
||||
|
||||
func TestSessionStore_GetSessionsToday(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially no sessions today
|
||||
count, err := sessionStore.GetSessionsToday(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// Create some sessions
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "")
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "project-b", "")
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-3", "project-c", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have 3 sessions today
|
||||
count, err = sessionStore.GetSessionsToday(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
}
|
||||
|
||||
func TestSessionStore_GetAllProjects(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions for different projects
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "alpha-project", "")
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "beta-project", "")
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-3", "alpha-project", "") // duplicate
|
||||
require.NoError(t, err)
|
||||
_, err = sessionStore.CreateSDKSession(ctx, "claude-4", "gamma-project", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get all projects
|
||||
projects, err := sessionStore.GetAllProjects(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, projects, 3)
|
||||
assert.Contains(t, projects, "alpha-project")
|
||||
assert.Contains(t, projects, "beta-project")
|
||||
assert.Contains(t, projects, "gamma-project")
|
||||
|
||||
// Should be sorted alphabetically
|
||||
assert.Equal(t, "alpha-project", projects[0])
|
||||
}
|
||||
|
||||
func TestSessionStore_GetSessionByID_NotFound(t *testing.T) {
|
||||
sessionStore, _, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Non-existent ID should return nil, nil (not an error)
|
||||
sess, err := sessionStore.GetSessionByID(ctx, 999)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, sess)
|
||||
}
|
||||
|
||||
func TestSessionStore_SessionFields(t *testing.T) {
|
||||
sessionStore, store, cleanup := testSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session with full details
|
||||
id, err := sessionStore.CreateSDKSession(ctx, "claude-full", "full-project", "full user prompt")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually update additional fields for testing
|
||||
now := time.Now()
|
||||
_, err = storeDB(store).Exec(`
|
||||
UPDATE sdk_sessions
|
||||
SET worker_port = ?, completed_at = ?, completed_at_epoch = ?, status = 'completed'
|
||||
WHERE id = ?
|
||||
`, 37777, now.Format(time.RFC3339), now.UnixMilli(), id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify all fields
|
||||
sess, err := sessionStore.GetSessionByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sess)
|
||||
|
||||
assert.Equal(t, id, sess.ID)
|
||||
assert.Equal(t, "claude-full", sess.ClaudeSessionID)
|
||||
assert.Equal(t, "full-project", sess.Project)
|
||||
assert.Equal(t, models.SessionStatusCompleted, sess.Status)
|
||||
assert.True(t, sess.WorkerPort.Valid)
|
||||
assert.Equal(t, int64(37777), sess.WorkerPort.Int64)
|
||||
assert.True(t, sess.CompletedAt.Valid)
|
||||
assert.True(t, sess.CompletedAtEpoch.Valid)
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Store provides database operations with connection pooling and prepared statements.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
stmtCache map[string]*sql.Stmt
|
||||
stmtMu sync.RWMutex
|
||||
}
|
||||
|
||||
// StoreConfig holds configuration for the database store.
|
||||
type StoreConfig struct {
|
||||
Path string
|
||||
MaxConns int
|
||||
WALMode bool
|
||||
}
|
||||
|
||||
// NewStore creates a new database store with the given configuration.
|
||||
func NewStore(cfg StoreConfig) (*Store, error) {
|
||||
// Register sqlite-vec extension for vector operations
|
||||
sqlite_vec.Auto()
|
||||
|
||||
// Build connection string with pragmas
|
||||
connStr := cfg.Path + "?_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=ON"
|
||||
|
||||
db, err := sql.Open("sqlite3", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
maxConns := cfg.MaxConns
|
||||
if maxConns <= 0 {
|
||||
maxConns = 4
|
||||
}
|
||||
db.SetMaxOpenConns(maxConns)
|
||||
db.SetMaxIdleConns(maxConns)
|
||||
db.SetConnMaxLifetime(0) // Never expire - SQLite connections are cheap
|
||||
|
||||
// Verify connection
|
||||
if err := db.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
|
||||
store := &Store{
|
||||
db: db,
|
||||
stmtCache: make(map[string]*sql.Stmt),
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
mgr := NewMigrationManager(db)
|
||||
if err := mgr.RunMigrations(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection and all cached statements.
|
||||
func (s *Store) Close() error {
|
||||
s.stmtMu.Lock()
|
||||
defer s.stmtMu.Unlock()
|
||||
|
||||
for _, stmt := range s.stmtCache {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
s.stmtCache = nil
|
||||
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// GetStmt returns a cached prepared statement, creating it if necessary.
|
||||
func (s *Store) GetStmt(query string) (*sql.Stmt, error) {
|
||||
s.stmtMu.RLock()
|
||||
stmt, ok := s.stmtCache[query]
|
||||
s.stmtMu.RUnlock()
|
||||
if ok {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
s.stmtMu.Lock()
|
||||
defer s.stmtMu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if stmt, ok := s.stmtCache[query]; ok {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
stmt, err := s.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.stmtCache[query] = stmt
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
// ExecContext executes a query that doesn't return rows.
|
||||
func (s *Store) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := s.GetStmt(query)
|
||||
if err != nil {
|
||||
// Fall back to direct execution
|
||||
return s.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
return stmt.ExecContext(ctx, args...)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that returns rows.
|
||||
func (s *Store) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
stmt, err := s.GetStmt(query)
|
||||
if err != nil {
|
||||
// Fall back to direct execution
|
||||
return s.db.QueryContext(ctx, query, args...)
|
||||
}
|
||||
return stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
|
||||
// QueryRowContext executes a query that returns a single row.
|
||||
func (s *Store) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := s.GetStmt(query)
|
||||
if err != nil {
|
||||
// Fall back to direct execution
|
||||
return s.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
|
||||
// Ping checks if the database connection is alive.
|
||||
func (s *Store) Ping() error {
|
||||
return s.db.Ping()
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection for direct access.
|
||||
// Use this sparingly - prefer the store methods for most operations.
|
||||
func (s *Store) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// StoreSuite is a test suite for Store operations.
|
||||
type StoreSuite struct {
|
||||
suite.Suite
|
||||
db *sql.DB
|
||||
store *Store
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
// SetupTest creates a fresh database before each test.
|
||||
func (s *StoreSuite) SetupTest() {
|
||||
s.db, _, s.cleanup = testDB(s.T())
|
||||
createBaseTables(s.T(), s.db)
|
||||
s.store = newStoreFromDB(s.db)
|
||||
}
|
||||
|
||||
// TearDownTest cleans up after each test.
|
||||
func (s *StoreSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreSuite(t *testing.T) {
|
||||
suite.Run(t, new(StoreSuite))
|
||||
}
|
||||
|
||||
// TestGetStmt tests prepared statement caching.
|
||||
func (s *StoreSuite) TestGetStmt() {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple query",
|
||||
query: "SELECT 1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid query with parameter",
|
||||
query: "SELECT * FROM sdk_sessions WHERE id = ?",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid query syntax",
|
||||
query: "SELECT * FROM nonexistent_table WHERE",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
stmt, err := s.store.GetStmt(tt.query)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
s.Nil(stmt)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(stmt)
|
||||
|
||||
// Second call should return cached statement
|
||||
stmt2, err := s.store.GetStmt(tt.query)
|
||||
s.NoError(err)
|
||||
s.Same(stmt, stmt2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecContext tests query execution.
|
||||
func (s *StoreSuite) TestExecContext() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
wantErr bool
|
||||
wantAffected int64
|
||||
}{
|
||||
{
|
||||
name: "insert session",
|
||||
query: `INSERT INTO sdk_sessions (claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
|
||||
VALUES (?, ?, ?, datetime('now'), strftime('%s', 'now') * 1000, 'active')`,
|
||||
args: []interface{}{"claude-1", "sdk-1", "test-project"},
|
||||
wantErr: false,
|
||||
wantAffected: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid query",
|
||||
query: "INSERT INTO nonexistent_table VALUES (?)",
|
||||
args: []interface{}{"test"},
|
||||
wantErr: true,
|
||||
wantAffected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result, err := s.store.ExecContext(ctx, tt.query, tt.args...)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
affected, _ := result.RowsAffected()
|
||||
s.Equal(tt.wantAffected, affected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryContext tests query execution that returns rows.
|
||||
func (s *StoreSuite) TestQueryContext() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed data
|
||||
seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
wantErr bool
|
||||
wantRows int
|
||||
setupFunc func()
|
||||
assertFunc func(rows *sql.Rows)
|
||||
}{
|
||||
{
|
||||
name: "query existing session",
|
||||
query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?",
|
||||
args: []interface{}{"claude-1"},
|
||||
wantErr: false,
|
||||
wantRows: 1,
|
||||
},
|
||||
{
|
||||
name: "query non-existent session",
|
||||
query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?",
|
||||
args: []interface{}{"nonexistent"},
|
||||
wantErr: false,
|
||||
wantRows: 0,
|
||||
},
|
||||
{
|
||||
name: "query all sessions",
|
||||
query: "SELECT id, project FROM sdk_sessions",
|
||||
args: nil,
|
||||
wantErr: false,
|
||||
wantRows: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rows, err := s.store.QueryContext(ctx, tt.query, tt.args...)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
s.NoError(err)
|
||||
defer rows.Close()
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
count++
|
||||
}
|
||||
s.Equal(tt.wantRows, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryRowContext tests single row query execution.
|
||||
func (s *StoreSuite) TestQueryRowContext() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed data
|
||||
seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "query existing session",
|
||||
query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?",
|
||||
args: []interface{}{"claude-1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "query non-existent session",
|
||||
query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?",
|
||||
args: []interface{}{"nonexistent"},
|
||||
wantErr: true, // sql.ErrNoRows
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
row := s.store.QueryRowContext(ctx, tt.query, tt.args...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
if tt.wantErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.Greater(id, int64(0))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPing tests database connection health check.
|
||||
func (s *StoreSuite) TestPing() {
|
||||
err := s.store.Ping()
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// TestDB tests getting the underlying database connection.
|
||||
func (s *StoreSuite) TestDB() {
|
||||
db := s.store.DB()
|
||||
s.NotNil(db)
|
||||
s.Same(s.db, db)
|
||||
}
|
||||
|
||||
// TestClose tests closing the store.
|
||||
func (s *StoreSuite) TestClose() {
|
||||
// Create a separate store for close test
|
||||
db, _, cleanup := testDB(s.T())
|
||||
defer cleanup()
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
|
||||
// Cache a statement first
|
||||
_, err := store.GetStmt("SELECT 1")
|
||||
s.NoError(err)
|
||||
|
||||
// Close should not error
|
||||
err = store.Close()
|
||||
s.NoError(err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = store.Ping()
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// TestConcurrentStmtCache tests concurrent access to statement cache.
|
||||
func (s *StoreSuite) TestConcurrentStmtCache() {
|
||||
ctx := context.Background()
|
||||
queries := []string{
|
||||
"SELECT 1",
|
||||
"SELECT 2",
|
||||
"SELECT id FROM sdk_sessions",
|
||||
"SELECT project FROM sdk_sessions",
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(i int) {
|
||||
query := queries[i%len(queries)]
|
||||
_, _ = s.store.GetStmt(query)
|
||||
_, _ = s.store.ExecContext(ctx, "SELECT 1")
|
||||
done <- struct{}{}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// HelpersSuite tests helper functions.
|
||||
type HelpersSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestHelpersSuite(t *testing.T) {
|
||||
suite.Run(t, new(HelpersSuite))
|
||||
}
|
||||
|
||||
func (s *HelpersSuite) TestNullString() {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantStr string
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantStr: "",
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "non-empty string",
|
||||
input: "test",
|
||||
wantStr: "test",
|
||||
wantBool: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace string",
|
||||
input: " ",
|
||||
wantStr: " ",
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := nullString(tt.input)
|
||||
s.Equal(tt.wantStr, result.String)
|
||||
s.Equal(tt.wantBool, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HelpersSuite) TestNullInt() {
|
||||
tests := []struct {
|
||||
name string
|
||||
input int
|
||||
wantInt int64
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
input: 0,
|
||||
wantInt: 0,
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "negative",
|
||||
input: -1,
|
||||
wantInt: -1,
|
||||
wantBool: false,
|
||||
},
|
||||
{
|
||||
name: "positive",
|
||||
input: 42,
|
||||
wantInt: 42,
|
||||
wantBool: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := nullInt(tt.input)
|
||||
s.Equal(tt.wantInt, result.Int64)
|
||||
s.Equal(tt.wantBool, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HelpersSuite) TestRepeatPlaceholders() {
|
||||
tests := []struct {
|
||||
name string
|
||||
input int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
input: 0,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "negative",
|
||||
input: -1,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
input: 1,
|
||||
expected: ", ?",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
input: 3,
|
||||
expected: ", ?, ?, ?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := repeatPlaceholders(tt.input)
|
||||
s.Equal(tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HelpersSuite) TestInt64SliceToInterface() {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []int64
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty slice",
|
||||
input: []int64{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "single element",
|
||||
input: []int64{42},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple elements",
|
||||
input: []int64{1, 2, 3, 4, 5},
|
||||
expected: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := int64SliceToInterface(tt.input)
|
||||
s.Len(result, tt.expected)
|
||||
for i, v := range result {
|
||||
s.Equal(tt.input[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildGetByIDsQuery tests the shared query builder.
|
||||
func TestBuildGetByIDsQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseQuery string
|
||||
ids []int64
|
||||
orderBy string
|
||||
limit int
|
||||
wantQuery string
|
||||
wantArgs int
|
||||
}{
|
||||
{
|
||||
name: "single id, no limit, desc order",
|
||||
baseQuery: "SELECT * FROM test",
|
||||
ids: []int64{1},
|
||||
orderBy: "date_desc",
|
||||
limit: 0,
|
||||
wantQuery: "SELECT * FROM test WHERE id IN (?)\n\t\tORDER BY created_at_epoch DESC",
|
||||
wantArgs: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple ids with limit and asc order",
|
||||
baseQuery: "SELECT * FROM test",
|
||||
ids: []int64{1, 2, 3},
|
||||
orderBy: "date_asc",
|
||||
limit: 10,
|
||||
wantQuery: "SELECT * FROM test WHERE id IN (?, ?, ?)\n\t\tORDER BY created_at_epoch ASC LIMIT ?",
|
||||
wantArgs: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query, args := BuildGetByIDsQuery(tt.baseQuery, tt.ids, tt.orderBy, tt.limit)
|
||||
assert.Contains(t, query, "WHERE id IN")
|
||||
assert.Len(t, args, tt.wantArgs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureSessionExists tests session auto-creation.
|
||||
func TestEnsureSessionExists(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sdkSessionID string
|
||||
project string
|
||||
setup func()
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "create new session",
|
||||
sdkSessionID: "sdk-new",
|
||||
project: "project-a",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "session already exists",
|
||||
sdkSessionID: "sdk-existing",
|
||||
project: "project-b",
|
||||
setup: func() {
|
||||
seedSession(t, db, "sdk-existing", "sdk-existing", "project-b")
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setup != nil {
|
||||
tt.setup()
|
||||
}
|
||||
|
||||
err := EnsureSessionExists(ctx, store, tt.sdkSessionID, tt.project)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify session exists
|
||||
var id int64
|
||||
err := db.QueryRow("SELECT id FROM sdk_sessions WHERE sdk_session_id = ?", tt.sdkSessionID).Scan(&id)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SummaryStore provides summary-related database operations.
|
||||
type SummaryStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewSummaryStore creates a new summary store.
|
||||
func NewSummaryStore(store *Store) *SummaryStore {
|
||||
return &SummaryStore{store: store}
|
||||
}
|
||||
|
||||
// StoreSummary stores a new session summary.
|
||||
func (s *SummaryStore) StoreSummary(ctx context.Context, sdkSessionID, project string, summary *models.ParsedSummary, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||
now := time.Now()
|
||||
nowEpoch := now.UnixMilli()
|
||||
|
||||
// Ensure session exists (auto-create if missing)
|
||||
if err := s.ensureSessionExists(ctx, sdkSessionID, project); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
const query = `
|
||||
INSERT INTO session_summaries
|
||||
(sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
sdkSessionID, project,
|
||||
nullString(summary.Request), nullString(summary.Investigated),
|
||||
nullString(summary.Learned), nullString(summary.Completed),
|
||||
nullString(summary.NextSteps), nullString(summary.Notes),
|
||||
nullInt(promptNumber), discoveryTokens,
|
||||
now.Format(time.RFC3339), nowEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return id, nowEpoch, nil
|
||||
}
|
||||
|
||||
// ensureSessionExists creates a session if it doesn't exist.
|
||||
func (s *SummaryStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
|
||||
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
|
||||
}
|
||||
|
||||
// GetSummariesByIDs retrieves summaries by a list of IDs.
|
||||
func (s *SummaryStore) GetSummariesByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.SessionSummary, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const baseQuery = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries`
|
||||
|
||||
query, args := BuildGetByIDsQuery(baseQuery, ids, orderBy, limit)
|
||||
|
||||
rows, err := s.store.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
// GetRecentSummaries retrieves recent summaries for a project.
|
||||
func (s *SummaryStore) GetRecentSummaries(ctx context.Context, project string, limit int) ([]*models.SessionSummary, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries
|
||||
WHERE project = ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
// GetAllRecentSummaries retrieves recent summaries across all projects.
|
||||
func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]*models.SessionSummary, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
// GetAllSummaries retrieves all summaries (for vector rebuild).
|
||||
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testSummaryStore(t *testing.T) (*SummaryStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createAllTables(t, db)
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
summaryStore := NewSummaryStore(store)
|
||||
|
||||
return summaryStore, store, cleanup
|
||||
}
|
||||
|
||||
func TestSummaryStore_StoreSummary(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Add new feature",
|
||||
Investigated: "Looked at existing code",
|
||||
Learned: "Found the pattern to follow",
|
||||
Completed: "Implemented the feature",
|
||||
NextSteps: "Add tests",
|
||||
Notes: "Some additional notes",
|
||||
}
|
||||
|
||||
id, epoch, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 1, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
assert.Greater(t, epoch, int64(0))
|
||||
|
||||
// Verify it was saved
|
||||
var count int
|
||||
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM session_summaries WHERE id = ?", id).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestSummaryStore_StoreSummary_AutoCreateSession(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Don't create session beforehand - should be auto-created
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Test request",
|
||||
}
|
||||
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "auto-session", "test-project", summary, 1, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Verify session was auto-created
|
||||
var sessionCount int
|
||||
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM sdk_sessions WHERE sdk_session_id = ?", "auto-session").Scan(&sessionCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, sessionCount)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetRecentSummaries(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Store multiple summaries
|
||||
for i := 0; i < 5; i++ {
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Request " + string(rune('A'+i)),
|
||||
}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, i+1, 0)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond) // Ensure different timestamps
|
||||
}
|
||||
|
||||
// Get recent summaries with limit
|
||||
summaries, err := summaryStore.GetRecentSummaries(ctx, "test-project", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 3)
|
||||
|
||||
// Should be in descending order
|
||||
assert.Equal(t, int64(5), summaries[0].PromptNumber.Int64)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetAllRecentSummaries(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions for different projects
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-a")
|
||||
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-b")
|
||||
|
||||
// Store summaries for both projects
|
||||
for i := 0; i < 3; i++ {
|
||||
summary := &models.ParsedSummary{Request: "Project A request"}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "project-a", summary, i+1, 0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
summary := &models.ParsedSummary{Request: "Project B request"}
|
||||
_, _, err := summaryStore.StoreSummary(ctx, "sdk-2", "project-b", summary, i+1, 0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get all summaries (should include both projects)
|
||||
summaries, err := summaryStore.GetAllRecentSummaries(ctx, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 5)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetSummariesByIDs(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Store summaries and collect IDs
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
summary := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, i+1, 0)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, id)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Get specific summaries by ID
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, ids[:3], "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 3)
|
||||
|
||||
// Test with ascending order
|
||||
summaries, err = summaryStore.GetSummariesByIDs(ctx, ids, "date_asc", 2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 2)
|
||||
assert.Equal(t, int64(1), summaries[0].PromptNumber.Int64)
|
||||
}
|
||||
|
||||
func TestSummaryStore_GetSummariesByIDs_EmptyInput(t *testing.T) {
|
||||
summaryStore, _, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty IDs should return nil
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{}, "date_desc", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, summaries)
|
||||
}
|
||||
|
||||
func TestSummaryStore_SummaryFields(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Store a summary with all fields
|
||||
summary := &models.ParsedSummary{
|
||||
Request: "Add authentication",
|
||||
Investigated: "Reviewed existing auth code",
|
||||
Learned: "OAuth is preferred",
|
||||
Completed: "Implemented OAuth flow",
|
||||
NextSteps: "Add refresh token support",
|
||||
Notes: "Consider rate limiting",
|
||||
}
|
||||
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 5, 1500)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify all fields
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, summaries, 1)
|
||||
|
||||
s := summaries[0]
|
||||
assert.Equal(t, id, s.ID)
|
||||
assert.Equal(t, "sdk-1", s.SDKSessionID)
|
||||
assert.Equal(t, "test-project", s.Project)
|
||||
assert.Equal(t, "Add authentication", s.Request.String)
|
||||
assert.Equal(t, "Reviewed existing auth code", s.Investigated.String)
|
||||
assert.Equal(t, "OAuth is preferred", s.Learned.String)
|
||||
assert.Equal(t, "Implemented OAuth flow", s.Completed.String)
|
||||
assert.Equal(t, "Add refresh token support", s.NextSteps.String)
|
||||
assert.Equal(t, "Consider rate limiting", s.Notes.String)
|
||||
assert.Equal(t, int64(5), s.PromptNumber.Int64)
|
||||
assert.Equal(t, int64(1500), s.DiscoveryTokens)
|
||||
}
|
||||
|
||||
func TestSummaryStore_EmptySummary(t *testing.T) {
|
||||
summaryStore, store, cleanup := testSummaryStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session
|
||||
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
|
||||
|
||||
// Store an empty summary
|
||||
summary := &models.ParsedSummary{}
|
||||
|
||||
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 0, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
// Retrieve and verify null fields
|
||||
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, summaries, 1)
|
||||
|
||||
s := summaries[0]
|
||||
assert.False(t, s.Request.Valid || s.Request.String != "")
|
||||
assert.False(t, s.Investigated.Valid || s.Investigated.String != "")
|
||||
assert.False(t, s.Learned.Valid || s.Learned.String != "")
|
||||
}
|
||||
@@ -1,367 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// newStoreFromDB creates a Store from an existing database connection for testing.
|
||||
func newStoreFromDB(db *sql.DB) *Store {
|
||||
return &Store{
|
||||
db: db,
|
||||
stmtCache: make(map[string]*sql.Stmt),
|
||||
}
|
||||
}
|
||||
|
||||
// storeDB returns the underlying database connection from a store for testing.
|
||||
func storeDB(s *Store) *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// testDB creates a temporary SQLite database for testing.
|
||||
// Returns the database, path, and a cleanup function.
|
||||
func testDB(t *testing.T) (*sql.DB, string, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "claude-mnemonic-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
connStr := dbPath + "?_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=ON"
|
||||
|
||||
db, err := sql.Open("sqlite3", connStr)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
t.Fatalf("open database: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = db.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return db, dbPath, cleanup
|
||||
}
|
||||
|
||||
// createBaseTables creates the base tables without FTS5 for unit testing.
|
||||
func createBaseTables(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_versions (
|
||||
id INTEGER PRIMARY KEY,
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
applied_at TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create schema_versions: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS sdk_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
claude_session_id TEXT UNIQUE NOT NULL,
|
||||
sdk_session_id TEXT UNIQUE,
|
||||
project TEXT NOT NULL,
|
||||
user_prompt TEXT,
|
||||
started_at TEXT NOT NULL,
|
||||
started_at_epoch INTEGER NOT NULL,
|
||||
completed_at TEXT,
|
||||
completed_at_epoch INTEGER,
|
||||
status TEXT CHECK(status IN ('active', 'completed', 'failed')) NOT NULL DEFAULT 'active',
|
||||
worker_port INTEGER,
|
||||
prompt_counter INTEGER DEFAULT 0
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create sdk_sessions: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS observations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sdk_session_id TEXT NOT NULL,
|
||||
project TEXT NOT NULL,
|
||||
text TEXT,
|
||||
type TEXT NOT NULL CHECK(type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change')),
|
||||
title TEXT,
|
||||
subtitle TEXT,
|
||||
facts TEXT,
|
||||
narrative TEXT,
|
||||
concepts TEXT,
|
||||
files_read TEXT,
|
||||
files_modified TEXT,
|
||||
file_mtimes TEXT,
|
||||
scope TEXT DEFAULT 'project' CHECK(scope IN ('project', 'global')),
|
||||
prompt_number INTEGER,
|
||||
discovery_tokens INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
importance_score REAL DEFAULT 1.0,
|
||||
user_feedback INTEGER DEFAULT 0,
|
||||
retrieval_count INTEGER DEFAULT 0,
|
||||
last_retrieved_at_epoch INTEGER,
|
||||
score_updated_at_epoch INTEGER,
|
||||
is_superseded INTEGER DEFAULT 0,
|
||||
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observations: %v", err)
|
||||
}
|
||||
|
||||
// Create observation_conflicts table for conflict detection
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS observation_conflicts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
newer_obs_id INTEGER NOT NULL,
|
||||
older_obs_id INTEGER NOT NULL,
|
||||
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
|
||||
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
|
||||
reason TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
detected_at_epoch INTEGER NOT NULL,
|
||||
resolved INTEGER DEFAULT 0,
|
||||
resolved_at TEXT,
|
||||
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observation_conflicts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS session_summaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sdk_session_id TEXT NOT NULL,
|
||||
project TEXT NOT NULL,
|
||||
request TEXT,
|
||||
investigated TEXT,
|
||||
learned TEXT,
|
||||
completed TEXT,
|
||||
next_steps TEXT,
|
||||
files_read TEXT,
|
||||
files_edited TEXT,
|
||||
notes TEXT,
|
||||
prompt_number INTEGER,
|
||||
discovery_tokens INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create session_summaries: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_prompts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
claude_session_id TEXT NOT NULL,
|
||||
prompt_number INTEGER NOT NULL,
|
||||
prompt_text TEXT NOT NULL,
|
||||
matched_observations INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(claude_session_id) REFERENCES sdk_sessions(claude_session_id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create user_prompts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS patterns (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
|
||||
description TEXT,
|
||||
signature TEXT,
|
||||
recommendation TEXT,
|
||||
frequency INTEGER DEFAULT 1,
|
||||
projects TEXT,
|
||||
observation_ids TEXT,
|
||||
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
|
||||
merged_into_id INTEGER,
|
||||
confidence REAL DEFAULT 0.5,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
last_seen_at_epoch INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create patterns: %v", err)
|
||||
}
|
||||
|
||||
indexes := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_project ON sdk_sessions(project)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_sdk_session ON observations(sdk_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project ON observations(project)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_scope ON observations(scope)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_created ON observations(created_at_epoch DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_session_summaries_sdk_session ON session_summaries(sdk_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_session_summaries_project ON session_summaries(project)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_prompts_claude_session ON user_prompts(claude_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_prompts_created ON user_prompts(created_at_epoch DESC)`,
|
||||
}
|
||||
for _, idx := range indexes {
|
||||
if _, err := db.Exec(idx); err != nil {
|
||||
t.Fatalf("create index: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// seedSession creates a test session in the database.
|
||||
func seedSession(t *testing.T, db *sql.DB, claudeSessionID, sdkSessionID, project string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO sdk_sessions (claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
|
||||
VALUES (?, ?, ?, datetime('now'), strftime('%s', 'now') * 1000, 'active')
|
||||
`, claudeSessionID, sdkSessionID, project)
|
||||
if err != nil {
|
||||
t.Fatalf("seed session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasFTS5 checks if FTS5 is available in the SQLite build.
|
||||
func hasFTS5(db *sql.DB) bool {
|
||||
_, err := db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
|
||||
return true
|
||||
}
|
||||
|
||||
// createFTSTables creates FTS5 virtual tables and triggers for full-text search.
|
||||
func createFTSTables(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
if !hasFTS5(db) {
|
||||
t.Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
_, err := db.Exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(
|
||||
title, subtitle, narrative,
|
||||
content='observations',
|
||||
content_rowid='id'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observations_fts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observations_ai trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES ('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observations_ad trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN
|
||||
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
|
||||
VALUES ('delete', old.id, old.title, old.subtitle, old.narrative);
|
||||
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
|
||||
VALUES (new.id, new.title, new.subtitle, new.narrative);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observations_au trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS session_summaries_fts USING fts5(
|
||||
request, investigated, learned, completed, next_steps, notes,
|
||||
content='session_summaries',
|
||||
content_rowid='id'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create session_summaries_fts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create summaries_ai trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON session_summaries BEGIN
|
||||
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
|
||||
VALUES ('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create summaries_ad trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS user_prompts_fts USING fts5(
|
||||
prompt_text,
|
||||
content='user_prompts',
|
||||
content_rowid='id'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create user_prompts_fts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS prompts_ai AFTER INSERT ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(rowid, prompt_text)
|
||||
VALUES (new.id, new.prompt_text);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create prompts_ai trigger: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS prompts_ad AFTER DELETE ON user_prompts BEGIN
|
||||
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
|
||||
VALUES ('delete', old.id, old.prompt_text);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create prompts_ad trigger: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// createAllTables creates all tables including FTS5 for comprehensive testing.
|
||||
func createAllTables(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
createBaseTables(t, db)
|
||||
createFTSTables(t, db)
|
||||
}
|
||||
@@ -292,19 +292,19 @@ func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
||||
}
|
||||
defer inputIdsTensor.Destroy()
|
||||
defer func() { _ = inputIdsTensor.Destroy() }()
|
||||
|
||||
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
||||
}
|
||||
defer attentionMaskTensor.Destroy()
|
||||
defer func() { _ = attentionMaskTensor.Destroy() }()
|
||||
|
||||
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
||||
}
|
||||
defer tokenTypeIdsTensor.Destroy()
|
||||
defer func() { _ = tokenTypeIdsTensor.Destroy() }()
|
||||
|
||||
// Create output tensor based on pooling strategy
|
||||
var outputShape ort.Shape
|
||||
@@ -324,7 +324,7 @@ func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create output tensor: %w", err)
|
||||
}
|
||||
defer outputTensor.Destroy()
|
||||
defer func() { _ = outputTensor.Destroy() }()
|
||||
|
||||
// Run inference
|
||||
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
||||
|
||||
+120
-5
@@ -9,7 +9,11 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -19,15 +23,41 @@ type Server struct {
|
||||
version string
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
|
||||
// Store dependencies for enhanced tools
|
||||
observationStore *gorm.ObservationStore
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
sessionStore *gorm.SessionStore
|
||||
vectorClient *sqlitevec.Client
|
||||
scoreCalculator *scoring.Calculator
|
||||
recalculator *scoring.Recalculator
|
||||
}
|
||||
|
||||
// NewServer creates a new MCP server.
|
||||
func NewServer(searchMgr *search.Manager, version string) *Server {
|
||||
func NewServer(
|
||||
searchMgr *search.Manager,
|
||||
version string,
|
||||
observationStore *gorm.ObservationStore,
|
||||
patternStore *gorm.PatternStore,
|
||||
relationStore *gorm.RelationStore,
|
||||
sessionStore *gorm.SessionStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
scoreCalculator *scoring.Calculator,
|
||||
recalculator *scoring.Recalculator,
|
||||
) *Server {
|
||||
return &Server{
|
||||
searchMgr: searchMgr,
|
||||
version: version,
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
searchMgr: searchMgr,
|
||||
version: version,
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
observationStore: observationStore,
|
||||
patternStore: patternStore,
|
||||
relationStore: relationStore,
|
||||
sessionStore: sessionStore,
|
||||
vectorClient: vectorClient,
|
||||
scoreCalculator: scoreCalculator,
|
||||
recalculator: recalculator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,6 +363,19 @@ func (s *Server) handleToolsList(req *Request) *Response {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "find_related_observations",
|
||||
Description: "Find observations related to a given observation ID filtered by confidence threshold. Returns related observations sorted by confidence score. Useful for discovering relevant context.",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"required": []string{"id"},
|
||||
"properties": map[string]any{
|
||||
"id": map[string]any{"type": "number", "description": "Observation ID"},
|
||||
"min_confidence": map[string]any{"type": "number", "default": 0.5, "minimum": 0.0, "maximum": 1.0, "description": "Minimum confidence threshold"},
|
||||
"limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return &Response{
|
||||
@@ -388,6 +431,12 @@ func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response {
|
||||
|
||||
// callTool dispatches to the appropriate tool handler.
|
||||
func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage) (string, error) {
|
||||
// Relation discovery tool
|
||||
if name == "find_related_observations" {
|
||||
return s.handleFindRelatedObservations(ctx, args)
|
||||
}
|
||||
|
||||
// Original search-based tools
|
||||
var params search.SearchParams
|
||||
if err := json.Unmarshal(args, ¶ms); err != nil {
|
||||
return "", fmt.Errorf("invalid arguments: %w", err)
|
||||
@@ -537,6 +586,72 @@ func (s *Server) handleTimelineByQuery(ctx context.Context, args json.RawMessage
|
||||
return s.handleTimeline(ctx, args)
|
||||
}
|
||||
|
||||
// handleFindRelatedObservations finds observations related to a given observation ID.
|
||||
func (s *Server) handleFindRelatedObservations(ctx context.Context, args json.RawMessage) (string, error) {
|
||||
var params struct {
|
||||
ID int64 `json:"id"`
|
||||
MinConfidence float64 `json:"min_confidence"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
if err := json.Unmarshal(args, ¶ms); err != nil {
|
||||
return "", fmt.Errorf("invalid arguments: %w", err)
|
||||
}
|
||||
|
||||
if params.ID == 0 {
|
||||
return "", fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
if params.MinConfidence == 0 {
|
||||
params.MinConfidence = 0.5
|
||||
}
|
||||
|
||||
if params.Limit == 0 {
|
||||
params.Limit = 20
|
||||
}
|
||||
if params.Limit > 100 {
|
||||
params.Limit = 100
|
||||
}
|
||||
|
||||
// Get related observation IDs with confidence filter
|
||||
relatedIDs, err := s.relationStore.GetRelatedObservationIDs(ctx, params.ID, params.MinConfidence)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get related observations: %w", err)
|
||||
}
|
||||
|
||||
if relatedIDs == nil {
|
||||
relatedIDs = []int64{}
|
||||
}
|
||||
|
||||
// Limit results
|
||||
if len(relatedIDs) > params.Limit {
|
||||
relatedIDs = relatedIDs[:params.Limit]
|
||||
}
|
||||
|
||||
// Fetch full observations
|
||||
observations := make([]*models.Observation, 0, len(relatedIDs))
|
||||
for _, id := range relatedIDs {
|
||||
obs, err := s.observationStore.GetObservationByID(ctx, id)
|
||||
if err != nil {
|
||||
continue // Skip errors for individual observations
|
||||
}
|
||||
if obs != nil {
|
||||
observations = append(observations, obs)
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"observations": observations,
|
||||
"count": len(observations),
|
||||
}
|
||||
|
||||
output, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal response: %w", err)
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
// sendResponse sends a JSON-RPC response.
|
||||
func (s *Server) sendResponse(resp *Response) {
|
||||
data, err := json.Marshal(resp)
|
||||
|
||||
+20
-20
@@ -24,7 +24,7 @@ func TestServerSuite(t *testing.T) {
|
||||
|
||||
// TestNewServer tests server creation.
|
||||
func (s *ServerSuite) TestNewServer() {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
s.NotNil(server)
|
||||
s.Nil(server.searchMgr)
|
||||
s.Equal("1.0.0", server.version)
|
||||
@@ -293,7 +293,7 @@ func TestTimelineParams(t *testing.T) {
|
||||
|
||||
// TestHandleInitialize tests the initialize handler.
|
||||
func TestHandleInitialize(t *testing.T) {
|
||||
server := NewServer(nil, "1.2.3")
|
||||
server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -320,7 +320,7 @@ func TestHandleInitialize(t *testing.T) {
|
||||
|
||||
// TestHandleToolsList tests the tools/list handler.
|
||||
func TestHandleToolsList(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -361,7 +361,7 @@ func TestHandleToolsList(t *testing.T) {
|
||||
|
||||
// TestHandleRequest tests request routing.
|
||||
func TestHandleRequest(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
@@ -423,7 +423,7 @@ func TestHandleRequest(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_InvalidParams tests tools/call with invalid params.
|
||||
func TestHandleToolsCall_InvalidParams(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -442,7 +442,7 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) {
|
||||
|
||||
// TestCallTool_UnknownTool tests callTool with unknown tool name.
|
||||
func TestCallTool_UnknownTool(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`))
|
||||
@@ -452,7 +452,7 @@ func TestCallTool_UnknownTool(t *testing.T) {
|
||||
|
||||
// TestCallTool_InvalidArgs tests callTool with invalid arguments.
|
||||
func TestCallTool_InvalidArgs(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`))
|
||||
@@ -574,7 +574,7 @@ func TestJSONRPCErrorCodes(t *testing.T) {
|
||||
|
||||
// TestToolListContainsExpectedSchemas tests that tool schemas are valid.
|
||||
func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -600,7 +600,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
|
||||
func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -620,7 +620,7 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||||
func TestCallTool_ToolNameRecognition(t *testing.T) {
|
||||
// Note: This test verifies tool routing logic, not execution (which requires searchMgr)
|
||||
// All valid tool names should be in the handleToolsList response
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -782,7 +782,7 @@ func TestResponseIDTypes(t *testing.T) {
|
||||
|
||||
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
|
||||
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty query should error
|
||||
@@ -793,7 +793,7 @@ func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
|
||||
func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
|
||||
@@ -803,7 +803,7 @@ func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||||
|
||||
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
|
||||
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
|
||||
@@ -813,7 +813,7 @@ func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
|
||||
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// No anchor_id and no query should return empty result
|
||||
@@ -825,7 +825,7 @@ func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
|
||||
func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// With anchor_id but no before/after, defaults should be applied
|
||||
@@ -839,7 +839,7 @@ func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||||
|
||||
// TestServerFields tests Server struct fields.
|
||||
func TestServerFields(t *testing.T) {
|
||||
server := NewServer(nil, "2.0.0")
|
||||
server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, "2.0.0", server.version)
|
||||
assert.Nil(t, server.searchMgr)
|
||||
@@ -891,7 +891,7 @@ func TestErrorWithNilData(t *testing.T) {
|
||||
|
||||
// TestToolInputSchema tests that tool input schemas have required fields.
|
||||
func TestToolInputSchema(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -960,7 +960,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) {
|
||||
|
||||
// TestCallTool_UnknownToolName tests callTool with various unknown tool names.
|
||||
func TestCallTool_UnknownToolName(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
unknownTools := []string{
|
||||
@@ -1009,7 +1009,7 @@ func TestTimelineParams_Validation(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error.
|
||||
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -1031,7 +1031,7 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_EmptyParams tests tools/call with empty params.
|
||||
func TestHandleToolsCall_EmptyParams(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -39,8 +39,8 @@ type PatternSyncFunc func(pattern *models.Pattern)
|
||||
// Detector detects and tracks recurring patterns across observations.
|
||||
type Detector struct {
|
||||
config DetectorConfig
|
||||
patternStore *sqlite.PatternStore
|
||||
observationStore *sqlite.ObservationStore
|
||||
patternStore *gorm.PatternStore
|
||||
observationStore *gorm.ObservationStore
|
||||
|
||||
// Vector sync callback
|
||||
syncFunc PatternSyncFunc
|
||||
@@ -71,7 +71,7 @@ type candidatePattern struct {
|
||||
}
|
||||
|
||||
// NewDetector creates a new pattern detector.
|
||||
func NewDetector(patternStore *sqlite.PatternStore, observationStore *sqlite.ObservationStore, config DetectorConfig) *Detector {
|
||||
func NewDetector(patternStore *gorm.PatternStore, observationStore *gorm.ObservationStore, config DetectorConfig) *Detector {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Detector{
|
||||
config: config,
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
@@ -15,8 +15,8 @@ func TestNewDetector(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
@@ -34,8 +34,8 @@ func TestDetector_StartStop(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
config.AnalysisInterval = 100 * time.Millisecond // Short interval for testing
|
||||
|
||||
@@ -58,8 +58,8 @@ func TestDetector_AnalyzeObservation_NewCandidate(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
config.MinFrequencyForPattern = 2
|
||||
|
||||
@@ -88,8 +88,8 @@ func TestDetector_AnalyzeObservation_PromoteToPattern(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
config.MinFrequencyForPattern = 2
|
||||
|
||||
@@ -127,8 +127,8 @@ func TestDetector_AnalyzeObservation_MatchExisting(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
@@ -149,7 +149,7 @@ func TestDetector_AnalyzeObservation_MatchExisting(t *testing.T) {
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
patternStore.StorePattern(ctx, pattern)
|
||||
_, _ = patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Create observation with similar signature
|
||||
obs := createTestObservation(10, "Nil check", []string{"nil", "error-handling"})
|
||||
@@ -175,8 +175,8 @@ func TestDetector_AnalyzeObservation_NoMatch(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
config.MinMatchScore = 0.5 // Higher threshold
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestDetector_AnalyzeObservation_NoMatch(t *testing.T) {
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
patternStore.StorePattern(ctx, pattern)
|
||||
_, _ = patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Create observation with completely different signature
|
||||
obs := createTestObservation(10, "UI Component", []string{"frontend", "react", "component"})
|
||||
@@ -218,8 +218,8 @@ func TestDetector_CandidateCleanup(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
config.MinFrequencyForPattern = 3 // Higher threshold
|
||||
|
||||
@@ -265,8 +265,8 @@ func TestDetector_GetPatternInsight(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
config := DefaultConfig()
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
@@ -388,7 +388,7 @@ func TestFormatPatternInsight(t *testing.T) {
|
||||
|
||||
// Helper functions
|
||||
|
||||
func setupTestStore(t *testing.T) *sqlite.Store {
|
||||
func setupTestStore(t *testing.T) *gorm.Store {
|
||||
t.Helper()
|
||||
|
||||
// Create temp database file
|
||||
@@ -402,10 +402,9 @@ func setupTestStore(t *testing.T) *sqlite.Store {
|
||||
os.Remove(tmpFile.Name())
|
||||
})
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: tmpFile.Name(),
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check if this is an FTS5 related error
|
||||
|
||||
@@ -297,19 +297,19 @@ func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, err
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
||||
}
|
||||
defer inputIdsTensor.Destroy()
|
||||
defer func() { _ = inputIdsTensor.Destroy() }()
|
||||
|
||||
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
||||
}
|
||||
defer attentionMaskTensor.Destroy()
|
||||
defer func() { _ = attentionMaskTensor.Destroy() }()
|
||||
|
||||
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
||||
}
|
||||
defer tokenTypeIdsTensor.Destroy()
|
||||
defer func() { _ = tokenTypeIdsTensor.Destroy() }()
|
||||
|
||||
// Cross-encoder outputs [batch, 1] logits
|
||||
outputShape := ort.NewShape(int64(batchSize), 1)
|
||||
@@ -317,7 +317,7 @@ func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, err
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create output tensor: %w", err)
|
||||
}
|
||||
defer outputTensor.Destroy()
|
||||
defer func() { _ = outputTensor.Destroy() }()
|
||||
|
||||
// Run inference
|
||||
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
@@ -183,4 +183,4 @@ func (r *Recalculator) GetStats() Stats {
|
||||
}
|
||||
|
||||
// Ensure ObservationStore satisfies the interface
|
||||
var _ ObservationStore = (*sqlite.ObservationStore)(nil)
|
||||
var _ ObservationStore = (*gorm.ObservationStore)(nil)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -45,8 +45,8 @@ func hasFTS5(t *testing.T) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
// testStore creates a gorm.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*gorm.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
if !hasFTS5(t) {
|
||||
@@ -58,10 +58,9 @@ func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -76,12 +75,12 @@ func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
// SearchIntegrationSuite tests search with real SQLite stores.
|
||||
type SearchIntegrationSuite struct {
|
||||
suite.Suite
|
||||
store *sqlite.Store
|
||||
store *gorm.Store
|
||||
cleanup func()
|
||||
manager *Manager
|
||||
obsStore *sqlite.ObservationStore
|
||||
sumStore *sqlite.SummaryStore
|
||||
prmStore *sqlite.PromptStore
|
||||
obsStore *gorm.ObservationStore
|
||||
sumStore *gorm.SummaryStore
|
||||
prmStore *gorm.PromptStore
|
||||
}
|
||||
|
||||
func (s *SearchIntegrationSuite) SetupTest() {
|
||||
@@ -92,9 +91,9 @@ func (s *SearchIntegrationSuite) SetupTest() {
|
||||
s.store, s.cleanup = testStore(s.T())
|
||||
|
||||
// Create real stores backed by SQLite
|
||||
s.obsStore = sqlite.NewObservationStore(s.store)
|
||||
s.sumStore = sqlite.NewSummaryStore(s.store)
|
||||
s.prmStore = sqlite.NewPromptStore(s.store)
|
||||
s.obsStore = gorm.NewObservationStore(s.store, nil, nil, nil)
|
||||
s.sumStore = gorm.NewSummaryStore(s.store)
|
||||
s.prmStore = gorm.NewPromptStore(s.store, nil)
|
||||
|
||||
// Create search manager with real stores (no vector client for now)
|
||||
s.manager = NewManager(s.obsStore, s.sumStore, s.prmStore, nil)
|
||||
@@ -491,7 +490,7 @@ func (s *SearchIntegrationSuite) TestSummaryToResult_FullFormat() {
|
||||
func (s *SearchIntegrationSuite) TestPromptToResult_FullFormat() {
|
||||
// First create a session
|
||||
ctx := context.Background()
|
||||
sessionStore := sqlite.NewSessionStore(s.store)
|
||||
sessionStore := gorm.NewSessionStore(s.store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "sdk-prompt-test", "test-project", "initial prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
||||
@@ -5,24 +5,24 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// Manager provides unified search across SQLite and sqlite-vec.
|
||||
type Manager struct {
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
promptStore *sqlite.PromptStore
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
vectorClient *sqlitevec.Client
|
||||
}
|
||||
|
||||
// NewManager creates a new search manager.
|
||||
func NewManager(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
observationStore *gorm.ObservationStore,
|
||||
summaryStore *gorm.SummaryStore,
|
||||
promptStore *gorm.PromptStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
|
||||
@@ -501,3 +501,254 @@ func TestClient_DeleteDocuments_NonExistent(t *testing.T) {
|
||||
err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Count_Empty(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := client.Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestClient_Count_WithVectors(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add some documents
|
||||
docs := []Document{
|
||||
{ID: "doc-1", Content: "test content 1"},
|
||||
{ID: "doc-2", Content: "test content 2"},
|
||||
{ID: "doc-3", Content: "test content 3"},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := client.Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestClient_ModelVersion(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
version := client.ModelVersion()
|
||||
assert.NotEmpty(t, version)
|
||||
// Should match the embedding service version
|
||||
assert.Equal(t, embedSvc.Version(), version)
|
||||
}
|
||||
|
||||
func TestClient_NeedsRebuild_EmptyDatabase(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
needsRebuild, reason := client.NeedsRebuild(context.Background())
|
||||
assert.True(t, needsRebuild)
|
||||
assert.Equal(t, "empty", reason)
|
||||
}
|
||||
|
||||
func TestClient_NeedsRebuild_ModelMismatch(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert vectors with wrong model version
|
||||
embedding := make([]float32, 384)
|
||||
for i := range embedding {
|
||||
embedding[i] = 0.1
|
||||
}
|
||||
embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, "doc-1", embeddingBytes, "old-model-v1", 1, "observation", "content", "test", "project")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, "doc-2", embeddingBytes, "old-model-v1", 2, "observation", "content", "test", "project")
|
||||
require.NoError(t, err)
|
||||
|
||||
needsRebuild, reason := client.NeedsRebuild(context.Background())
|
||||
assert.True(t, needsRebuild)
|
||||
assert.Contains(t, reason, "model_mismatch:2")
|
||||
}
|
||||
|
||||
func TestClient_NeedsRebuild_CurrentModel(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add documents with current model version
|
||||
docs := []Document{
|
||||
{ID: "doc-1", Content: "test content 1"},
|
||||
{ID: "doc-2", Content: "test content 2"},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
needsRebuild, reason := client.NeedsRebuild(context.Background())
|
||||
assert.False(t, needsRebuild)
|
||||
assert.Empty(t, reason)
|
||||
}
|
||||
|
||||
func TestClient_GetStaleVectors_Empty(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
stale, err := client.GetStaleVectors(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
func TestClient_GetStaleVectors_WithMismatch(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert vectors with wrong model version
|
||||
embedding := make([]float32, 384)
|
||||
embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, "doc-1", embeddingBytes, "old-model", 1, "observation", "content", "project-1", "project")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, "doc-2", embeddingBytes, embedSvc.Version(), 2, "observation", "title", "project-1", "project")
|
||||
require.NoError(t, err)
|
||||
|
||||
stale, err := client.GetStaleVectors(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stale, 1)
|
||||
assert.Equal(t, "doc-1", stale[0].DocID)
|
||||
assert.Equal(t, int64(1), stale[0].SQLiteID)
|
||||
assert.Equal(t, "observation", stale[0].DocType)
|
||||
assert.Equal(t, "content", stale[0].FieldType)
|
||||
assert.Equal(t, "project-1", stale[0].Project)
|
||||
assert.Equal(t, "project", stale[0].Scope)
|
||||
}
|
||||
|
||||
func TestClient_DeleteVectorsByDocIDs_Empty(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleting empty slice should not error
|
||||
err = client.DeleteVectorsByDocIDs(context.Background(), []string{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_DeleteVectorsByDocIDs_Success(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add documents
|
||||
docs := []Document{
|
||||
{ID: "doc-1", Content: "test 1"},
|
||||
{ID: "doc-2", Content: "test 2"},
|
||||
{ID: "doc-3", Content: "test 3"},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify 3 documents exist
|
||||
count, err := client.Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
|
||||
// Delete doc-1 and doc-3
|
||||
err = client.DeleteVectorsByDocIDs(context.Background(), []string{"doc-1", "doc-3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have 1 document remaining
|
||||
count, err = client.Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// Verify doc-2 still exists
|
||||
var exists int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-2").Scan(&exists)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, exists)
|
||||
}
|
||||
|
||||
func TestClient_DeleteVectorsByDocIDs_NonExistent(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleting non-existent IDs should not error
|
||||
err = client.DeleteVectorsByDocIDs(context.Background(), []string{"non-existent-1", "non-existent-2"})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -572,3 +572,119 @@ func TestExtractedIDs_Empty(t *testing.T) {
|
||||
assert.Nil(t, ids.SummaryIDs)
|
||||
assert.Nil(t, ids.PromptIDs)
|
||||
}
|
||||
|
||||
func TestFilterByThreshold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
results []QueryResult
|
||||
threshold float64
|
||||
maxResults int
|
||||
expectedLen int
|
||||
expectedIDs []string
|
||||
}{
|
||||
{
|
||||
name: "empty_results",
|
||||
results: []QueryResult{},
|
||||
threshold: 0.5,
|
||||
maxResults: 0,
|
||||
expectedLen: 0,
|
||||
},
|
||||
{
|
||||
name: "all_above_threshold",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.9},
|
||||
{ID: "doc-2", Similarity: 0.8},
|
||||
{ID: "doc-3", Similarity: 0.7},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 0,
|
||||
expectedLen: 3,
|
||||
expectedIDs: []string{"doc-1", "doc-2", "doc-3"},
|
||||
},
|
||||
{
|
||||
name: "some_below_threshold",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.9},
|
||||
{ID: "doc-2", Similarity: 0.4},
|
||||
{ID: "doc-3", Similarity: 0.7},
|
||||
{ID: "doc-4", Similarity: 0.3},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 0,
|
||||
expectedLen: 2,
|
||||
expectedIDs: []string{"doc-1", "doc-3"},
|
||||
},
|
||||
{
|
||||
name: "all_below_threshold",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.3},
|
||||
{ID: "doc-2", Similarity: 0.2},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 0,
|
||||
expectedLen: 0,
|
||||
},
|
||||
{
|
||||
name: "max_results_limit",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.9},
|
||||
{ID: "doc-2", Similarity: 0.8},
|
||||
{ID: "doc-3", Similarity: 0.7},
|
||||
{ID: "doc-4", Similarity: 0.6},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 2,
|
||||
expectedLen: 2,
|
||||
expectedIDs: []string{"doc-1", "doc-2"},
|
||||
},
|
||||
{
|
||||
name: "max_results_with_threshold",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.9},
|
||||
{ID: "doc-2", Similarity: 0.3},
|
||||
{ID: "doc-3", Similarity: 0.8},
|
||||
{ID: "doc-4", Similarity: 0.2},
|
||||
{ID: "doc-5", Similarity: 0.7},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 2,
|
||||
expectedLen: 2,
|
||||
expectedIDs: []string{"doc-1", "doc-3"},
|
||||
},
|
||||
{
|
||||
name: "exact_threshold_included",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.5},
|
||||
{ID: "doc-2", Similarity: 0.4},
|
||||
},
|
||||
threshold: 0.5,
|
||||
maxResults: 0,
|
||||
expectedLen: 1,
|
||||
expectedIDs: []string{"doc-1"},
|
||||
},
|
||||
{
|
||||
name: "zero_threshold",
|
||||
results: []QueryResult{
|
||||
{ID: "doc-1", Similarity: 0.1},
|
||||
{ID: "doc-2", Similarity: 0.0},
|
||||
},
|
||||
threshold: 0.0,
|
||||
maxResults: 0,
|
||||
expectedLen: 2,
|
||||
expectedIDs: []string{"doc-1", "doc-2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filtered := FilterByThreshold(tt.results, tt.threshold, tt.maxResults)
|
||||
assert.Len(t, filtered, tt.expectedLen)
|
||||
|
||||
if tt.expectedLen > 0 {
|
||||
for i, id := range tt.expectedIDs {
|
||||
assert.Equal(t, id, filtered[i].ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
@@ -486,7 +486,7 @@ func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
|
||||
// handleGetObservations returns recent observations.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultObservationsLimit)
|
||||
limit := gorm.ParseLimitParam(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
@@ -535,7 +535,7 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
@@ -582,7 +582,7 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
@@ -743,7 +743,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
limit := sqlite.ParseLimitParam(r, DefaultSearchLimit)
|
||||
limit := gorm.ParseLimitParam(r, DefaultSearchLimit)
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
@@ -68,6 +68,7 @@ func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Reque
|
||||
if err := observationStore.UpdateImportanceScore(r.Context(), id, newScore); err != nil {
|
||||
// Log but don't fail - feedback was recorded
|
||||
// Score will be updated on next recalculation cycle
|
||||
_ = err // Explicitly ignore - non-critical operation
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,6 +264,7 @@ func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Reque
|
||||
if recalculator != nil {
|
||||
if err := recalculator.RefreshConceptWeights(r.Context()); err != nil {
|
||||
// Log but don't fail - weight was saved
|
||||
_ = err // Explicitly ignore - non-critical operation
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,6 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
|
||||
go func() {
|
||||
if err := recalculator.RecalculateNow(r.Context()); err != nil {
|
||||
// Log error but don't block response
|
||||
_ = err // Explicitly ignore - background operation
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -349,6 +352,7 @@ func (s *Service) incrementRetrievalCounts(ids []int64) {
|
||||
|
||||
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
|
||||
// Log but don't fail - this is a background operation
|
||||
_ = err // Explicitly ignore - background operation
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
|
||||
@@ -32,10 +32,10 @@ func testService(t *testing.T) (*Service, func()) {
|
||||
store, dbCleanup := testStore(t)
|
||||
|
||||
// Create store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
sessionStore := gorm.NewSessionStore(store)
|
||||
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
|
||||
summaryStore := gorm.NewSummaryStore(store)
|
||||
promptStore := gorm.NewPromptStore(store, nil)
|
||||
|
||||
// Create domain services
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
@@ -83,7 +83,7 @@ func testService(t *testing.T) (*Service, func()) {
|
||||
}
|
||||
|
||||
// createTestObservation creates a test observation in the database.
|
||||
func createTestObservation(t *testing.T, store *sqlite.ObservationStore, project, title, narrative string, concepts []string) int64 {
|
||||
func createTestObservation(t *testing.T, store *gorm.ObservationStore, project, title, narrative string, concepts []string) int64 {
|
||||
t.Helper()
|
||||
|
||||
obs := &models.ParsedObservation{
|
||||
@@ -530,7 +530,7 @@ func TestRequireReadyMiddleware_Blocks(t *testing.T) {
|
||||
|
||||
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
_, _ = w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
@@ -549,7 +549,7 @@ func TestRequireReadyMiddleware_Allows(t *testing.T) {
|
||||
|
||||
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
_, _ = w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
@@ -669,9 +669,9 @@ func TestHandleGetProjects(t *testing.T) {
|
||||
|
||||
// Create sessions for different projects
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-1", "project-alpha", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-2", "project-beta", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-3", "project-gamma", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "session-1", "project-alpha", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "session-2", "project-beta", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "session-3", "project-gamma", "")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/projects", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -761,7 +761,7 @@ func TestHandleGetPrompts(t *testing.T) {
|
||||
|
||||
// Create sessions and prompts
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-test", "project-x", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-test", "project-x", "")
|
||||
|
||||
// Save prompts
|
||||
for i := 0; i < 5; i++ {
|
||||
@@ -958,7 +958,7 @@ func TestHandleSessionInit_DuplicatePrompt(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec1.Code)
|
||||
var resp1 SessionInitResponse
|
||||
json.Unmarshal(rec1.Body.Bytes(), &resp1)
|
||||
_ = json.Unmarshal(rec1.Body.Bytes(), &resp1)
|
||||
|
||||
// Second request with same prompt (duplicate)
|
||||
body2, _ := json.Marshal(reqBody)
|
||||
@@ -969,7 +969,7 @@ func TestHandleSessionInit_DuplicatePrompt(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
var resp2 SessionInitResponse
|
||||
json.Unmarshal(rec2.Body.Bytes(), &resp2)
|
||||
_ = json.Unmarshal(rec2.Body.Bytes(), &resp2)
|
||||
|
||||
// Should return same prompt number (duplicate detected)
|
||||
assert.Equal(t, resp1.PromptNumber, resp2.PromptNumber)
|
||||
@@ -1095,7 +1095,7 @@ func TestHandleObservation_WithExistingSession(t *testing.T) {
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-obs-test", "test-project", "test prompt")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-obs-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := ObservationRequest{
|
||||
ClaudeSessionID: "claude-obs-test",
|
||||
@@ -1190,7 +1190,7 @@ func TestHandleGetSummaries_DefaultLimit(t *testing.T) {
|
||||
// Create more than default limit
|
||||
for i := 0; i < 60; i++ {
|
||||
parsed := &models.ParsedSummary{Request: "Request " + strconv.Itoa(i)}
|
||||
svc.summaryStore.StoreSummary(ctx, "sdk-"+strconv.Itoa(i), "project-sum", parsed, i+1, 100)
|
||||
_, _, _ = svc.summaryStore.StoreSummary(ctx, "sdk-"+strconv.Itoa(i), "project-sum", parsed, i+1, 100)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
|
||||
@@ -1212,11 +1212,11 @@ func TestHandleGetPrompts_DefaultLimit(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts", "project-prompts", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-prompts", "project-prompts", "")
|
||||
|
||||
// Create more than default limit
|
||||
for i := 0; i < 120; i++ {
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts", i+1, "Prompt "+strconv.Itoa(i), 0)
|
||||
_, _ = svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts", i+1, "Prompt "+strconv.Itoa(i), 0)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
|
||||
@@ -1475,7 +1475,7 @@ func TestHandleGetSessionByClaudeID(t *testing.T) {
|
||||
|
||||
// Create a session
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1")
|
||||
|
||||
// Test with valid claudeSessionId
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=claude-test-123", nil)
|
||||
@@ -1870,7 +1870,7 @@ func TestHandleSubagentComplete_WithSession(t *testing.T) {
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt")
|
||||
|
||||
reqBody := SubagentCompleteRequest{
|
||||
ClaudeSessionID: "subagent-claude-123",
|
||||
@@ -2063,7 +2063,7 @@ func TestHandleObservation_WithFullData(t *testing.T) {
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := ObservationRequest{
|
||||
ClaudeSessionID: "obs-full-test",
|
||||
@@ -2118,7 +2118,7 @@ func TestHandleGetSummaries_NoProject(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
parsed := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
|
||||
svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100)
|
||||
_, _, _ = svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
|
||||
@@ -2143,11 +2143,11 @@ func TestHandleGetPrompts_NoProject(t *testing.T) {
|
||||
|
||||
// Create sessions and prompts in different projects
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "")
|
||||
_, _ = svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "")
|
||||
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0)
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0)
|
||||
_, _ = svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0)
|
||||
_, _ = svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -2798,13 +2798,17 @@ func TestHandleUpdateApply_NoUpdateAvailable(t *testing.T) {
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
// Update check may succeed or fail - both are valid behaviors
|
||||
assert.NotNil(t, response)
|
||||
// Update check may succeed (200) or fail (500) depending on network/GitHub availability
|
||||
// Both are valid in test environment
|
||||
if rec.Code == http.StatusOK {
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
} else {
|
||||
// If it fails, that's also acceptable in test environment (no network/GitHub access)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_WithQuery tests observations with query parameter.
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -33,8 +33,8 @@ type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
type Processor struct {
|
||||
claudePath string
|
||||
model string
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
broadcastFunc BroadcastFunc
|
||||
syncObservationFunc SyncObservationFunc
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
@@ -69,7 +69,7 @@ func (p *Processor) broadcast(event map[string]interface{}) {
|
||||
const MaxConcurrentCLICalls = 4
|
||||
|
||||
// NewProcessor creates a new SDK processor.
|
||||
func NewProcessor(observationStore *sqlite.ObservationStore, summaryStore *sqlite.SummaryStore) (*Processor, error) {
|
||||
func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.SummaryStore) (*Processor, error) {
|
||||
cfg := config.Get()
|
||||
|
||||
// Find Claude Code CLI
|
||||
|
||||
+37
-45
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/pattern"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
@@ -63,14 +63,14 @@ type Service struct {
|
||||
config *config.Config
|
||||
|
||||
// Database
|
||||
store *sqlite.Store
|
||||
sessionStore *sqlite.SessionStore
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
promptStore *sqlite.PromptStore
|
||||
conflictStore *sqlite.ConflictStore
|
||||
patternStore *sqlite.PatternStore
|
||||
relationStore *sqlite.RelationStore
|
||||
store *gorm.Store
|
||||
sessionStore *gorm.SessionStore
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
conflictStore *gorm.ConflictStore
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
|
||||
// Pattern detection
|
||||
patternDetector *pattern.Detector
|
||||
@@ -182,10 +182,10 @@ func (s *Service) initializeAsync() {
|
||||
}
|
||||
|
||||
// Initialize database (this includes migrations - can be slow)
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: s.config.DBPath,
|
||||
MaxConns: s.config.MaxConns,
|
||||
WALMode: true,
|
||||
// WALMode is enabled automatically by GORM
|
||||
})
|
||||
if err != nil {
|
||||
s.setInitError(fmt.Errorf("init database: %w", err))
|
||||
@@ -193,19 +193,15 @@ func (s *Service) initializeAsync() {
|
||||
}
|
||||
|
||||
// Create store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
conflictStore := sqlite.NewConflictStore(store)
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
relationStore := sqlite.NewRelationStore(store)
|
||||
sessionStore := gorm.NewSessionStore(store)
|
||||
summaryStore := gorm.NewSummaryStore(store)
|
||||
promptStore := gorm.NewPromptStore(store, nil)
|
||||
conflictStore := gorm.NewConflictStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
relationStore := gorm.NewRelationStore(store)
|
||||
|
||||
// Enable conflict detection by linking stores
|
||||
observationStore.SetConflictStore(conflictStore)
|
||||
|
||||
// Enable relation detection by linking stores
|
||||
observationStore.SetRelationStore(relationStore)
|
||||
// Create observation store with conflict and relation stores for automatic detection
|
||||
observationStore := gorm.NewObservationStore(store, nil, conflictStore, relationStore)
|
||||
|
||||
// Create session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
@@ -224,7 +220,7 @@ func (s *Service) initializeAsync() {
|
||||
embedSvc = emb
|
||||
// Create sqlite-vec client using the same DB connection
|
||||
client, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
DB: store.DB(),
|
||||
DB: store.GetRawDB(),
|
||||
}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sqlite-vec client creation failed - vector search disabled")
|
||||
@@ -519,10 +515,10 @@ func (s *Service) reinitializeDatabase() {
|
||||
}
|
||||
|
||||
// Create new database
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: s.config.DBPath,
|
||||
MaxConns: s.config.MaxConns,
|
||||
WALMode: true,
|
||||
// WALMode is enabled automatically by GORM
|
||||
})
|
||||
if err != nil {
|
||||
s.setInitError(fmt.Errorf("reinit database: %w", err))
|
||||
@@ -530,19 +526,15 @@ func (s *Service) reinitializeDatabase() {
|
||||
}
|
||||
|
||||
// Create new store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
conflictStore := sqlite.NewConflictStore(store)
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
relationStore := sqlite.NewRelationStore(store)
|
||||
sessionStore := gorm.NewSessionStore(store)
|
||||
summaryStore := gorm.NewSummaryStore(store)
|
||||
promptStore := gorm.NewPromptStore(store, nil)
|
||||
conflictStore := gorm.NewConflictStore(store)
|
||||
patternStore := gorm.NewPatternStore(store)
|
||||
relationStore := gorm.NewRelationStore(store)
|
||||
|
||||
// Enable conflict detection by linking stores
|
||||
observationStore.SetConflictStore(conflictStore)
|
||||
|
||||
// Enable relation detection by linking stores
|
||||
observationStore.SetRelationStore(relationStore)
|
||||
// Create observation store with conflict and relation stores for automatic detection
|
||||
observationStore := gorm.NewObservationStore(store, nil, conflictStore, relationStore)
|
||||
|
||||
// Create new session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
@@ -560,7 +552,7 @@ func (s *Service) reinitializeDatabase() {
|
||||
} else {
|
||||
embedSvc = emb
|
||||
client, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
DB: store.DB(),
|
||||
DB: store.GetRawDB(),
|
||||
}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sqlite-vec client creation failed after reinit")
|
||||
@@ -805,9 +797,9 @@ func (s *Service) processStaleQueue() {
|
||||
// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts.
|
||||
// Called when the vectors table is empty (e.g., after migration 20 drops all vectors).
|
||||
func (s *Service) rebuildAllVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
observationStore *gorm.ObservationStore,
|
||||
summaryStore *gorm.SummaryStore,
|
||||
promptStore *gorm.PromptStore,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
defer s.wg.Done()
|
||||
@@ -877,9 +869,9 @@ func (s *Service) rebuildAllVectors(
|
||||
// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions.
|
||||
// This is more efficient than rebuilding all vectors when only some need updating.
|
||||
func (s *Service) rebuildStaleVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
observationStore *gorm.ObservationStore,
|
||||
summaryStore *gorm.SummaryStore,
|
||||
promptStore *gorm.PromptStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
@@ -25,10 +25,9 @@ func hasFTS5(t *testing.T) bool {
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tmpDir) }()
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: tmpDir + "/check.db",
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -37,8 +36,8 @@ func hasFTS5(t *testing.T) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
// testStore creates a gorm.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*gorm.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
if !hasFTS5(t) {
|
||||
@@ -50,10 +49,9 @@ func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -68,8 +66,8 @@ func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
// SessionIntegrationSuite tests session manager with real SQLite stores.
|
||||
type SessionIntegrationSuite struct {
|
||||
suite.Suite
|
||||
store *sqlite.Store
|
||||
sessionStore *sqlite.SessionStore
|
||||
store *gorm.Store
|
||||
sessionStore *gorm.SessionStore
|
||||
cleanup func()
|
||||
manager *Manager
|
||||
}
|
||||
@@ -80,7 +78,7 @@ func (s *SessionIntegrationSuite) SetupTest() {
|
||||
}
|
||||
|
||||
s.store, s.cleanup = testStore(s.T())
|
||||
s.sessionStore = sqlite.NewSessionStore(s.store)
|
||||
s.sessionStore = gorm.NewSessionStore(s.store)
|
||||
s.manager = NewManager(s.sessionStore)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ const CleanupInterval = 5 * time.Minute
|
||||
|
||||
// Manager manages active session lifecycles.
|
||||
type Manager struct {
|
||||
sessionStore *sqlite.SessionStore
|
||||
sessionStore *gorm.SessionStore
|
||||
sessions map[int64]*ActiveSession
|
||||
mu sync.RWMutex
|
||||
onCreated func(int64)
|
||||
@@ -82,7 +82,7 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager.
|
||||
func NewManager(sessionStore *sqlite.SessionStore) *Manager {
|
||||
func NewManager(sessionStore *gorm.SessionStore) *Manager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m := &Manager{
|
||||
sessionStore: sessionStore,
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
// Uses sqlite.NewStore which runs migrations (requires FTS5).
|
||||
// testStore creates a gorm.Store with a temporary database for testing.
|
||||
// Uses gorm.NewStore which runs migrations (requires FTS5).
|
||||
// Skips the test if FTS5 is not available.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
func testStore(t *testing.T) (*gorm.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
// First check if FTS5 is available
|
||||
@@ -27,10 +27,9 @@ func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
store, err := gorm.NewStore(gorm.Config{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -143,3 +143,38 @@ func RunHook[T any](hookName string, handler HookHandler[T]) {
|
||||
|
||||
WriteResponse(hookName, true)
|
||||
}
|
||||
|
||||
// StatuslineHandler is a function that handles statusline-specific logic.
|
||||
// It receives input and port, returns formatted status string.
|
||||
// No context injection or worker startup - just display.
|
||||
type StatuslineHandler[T any] func(input *T, port int) string
|
||||
|
||||
// RunStatuslineHook executes a statusline hook with minimal overhead.
|
||||
// Unlike RunHook, this:
|
||||
// - Does NOT check CLAUDE_MNEMONIC_INTERNAL (statuslines always run)
|
||||
// - Uses GetWorkerPort() instead of EnsureWorkerRunning() (no startup)
|
||||
// - Prints output directly to stdout (no JSON wrapping)
|
||||
// This keeps statusline fast (<100ms requirement).
|
||||
func RunStatuslineHook[T any](handler StatuslineHandler[T]) {
|
||||
// Read input from stdin
|
||||
inputData, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
// On error, handler receives nil and should return offline status
|
||||
fmt.Println(handler(nil, 0))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse input
|
||||
var input T
|
||||
if err := json.Unmarshal(inputData, &input); err != nil {
|
||||
// On parse error, handler receives nil and should return offline status
|
||||
fmt.Println(handler(nil, 0))
|
||||
return
|
||||
}
|
||||
|
||||
// Get worker port (does NOT start worker)
|
||||
port := GetWorkerPort()
|
||||
|
||||
// Run handler and print result
|
||||
fmt.Println(handler(&input, port))
|
||||
}
|
||||
|
||||
+12
-12
@@ -34,7 +34,7 @@ func TestIsWorkerRunning(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/health" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ready"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ready"})
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func TestGetWorkerVersion(t *testing.T) {
|
||||
name: "returns version from server",
|
||||
serverResponse: func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/version" {
|
||||
json.NewEncoder(w).Encode(map[string]string{"version": "1.2.3"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"version": "1.2.3"})
|
||||
}
|
||||
},
|
||||
expectedResult: "1.2.3",
|
||||
@@ -83,7 +83,7 @@ func TestGetWorkerVersion(t *testing.T) {
|
||||
{
|
||||
name: "returns empty on invalid JSON",
|
||||
serverResponse: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("not json"))
|
||||
_, _ = w.Write([]byte("not json"))
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
@@ -332,7 +332,7 @@ func TestPOST(t *testing.T) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
|
||||
},
|
||||
body: map[string]string{"key": "value"},
|
||||
expectError: false,
|
||||
@@ -358,7 +358,7 @@ func TestPOST(t *testing.T) {
|
||||
name: "POST with non-JSON response",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("not json"))
|
||||
_, _ = w.Write([]byte("not json"))
|
||||
},
|
||||
body: map[string]string{"key": "value"},
|
||||
expectError: false,
|
||||
@@ -403,7 +403,7 @@ func TestGET(t *testing.T) {
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"})
|
||||
},
|
||||
expectError: false,
|
||||
expectedResult: map[string]interface{}{"data": "test"},
|
||||
@@ -419,7 +419,7 @@ func TestGET(t *testing.T) {
|
||||
name: "GET with invalid JSON",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("not valid json"))
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
@@ -666,7 +666,7 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/version" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"version": "v1.2.3"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"version": "v1.2.3"})
|
||||
}
|
||||
},
|
||||
expectedResult: "v1.2.3",
|
||||
@@ -682,7 +682,7 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
|
||||
name: "returns empty on invalid JSON",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("not json"))
|
||||
_, _ = w.Write([]byte("not json"))
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
@@ -690,7 +690,7 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
|
||||
name: "returns empty on missing version field",
|
||||
serverHandler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"other": "field"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"other": "field"})
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
@@ -1082,7 +1082,7 @@ func TestPOST_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -1101,7 +1101,7 @@ func TestGET_WithQueryParams(t *testing.T) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
assert.Equal(t, "/test?foo=bar", r.URL.String())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"$schema": "https://anthropic.com/claude-code/marketplace.schema.json",
|
||||
"name": "claude-mnemonic",
|
||||
"version": "1.0.0",
|
||||
"description": "Persistent memory system for Claude Code - stores observations, session summaries, and user prompts with semantic search",
|
||||
"owner": {
|
||||
"name": "lukaszraczylo",
|
||||
"email": "lukaszraczylo@users.noreply.github.com"
|
||||
},
|
||||
"plugins": [
|
||||
{
|
||||
"name": "claude-mnemonic",
|
||||
"description": "Persistent memory system for Claude Code - Go implementation with SQLite and ChromaDB",
|
||||
"version": "1.0.0",
|
||||
"author": {
|
||||
"name": "lukaszraczylo"
|
||||
},
|
||||
"source": "./",
|
||||
"category": "productivity"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"name": "claude-mnemonic",
|
||||
"version": "1.0.0",
|
||||
"description": "Persistent memory system for Claude Code - Go implementation with SQLite and ChromaDB",
|
||||
"author": {
|
||||
"name": "lukaszraczylo",
|
||||
"email": "lukaszraczylo@users.noreply.github.com"
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,14 @@ if [ -n "$GORELEASER_CURRENT_TAG" ]; then
|
||||
VERSION="${GORELEASER_CURRENT_TAG#v}"
|
||||
echo "Using version from GORELEASER_CURRENT_TAG: $VERSION"
|
||||
else
|
||||
# Fallback for local testing
|
||||
VERSION="0.0.0-dev"
|
||||
echo "GORELEASER_CURRENT_TAG not set, using fallback version: $VERSION"
|
||||
# Fallback: Use latest git tag instead of 0.0.0-dev
|
||||
# This prevents version mismatch when Claude installs from GitHub
|
||||
LATEST_TAG=$(git tag --sort=-v:refname | head -1 || echo "v0.0.0-dev")
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
LATEST_TAG="v0.0.0-dev"
|
||||
fi
|
||||
VERSION="${LATEST_TAG#v}"
|
||||
echo "GORELEASER_CURRENT_TAG not set, using latest git tag: $VERSION"
|
||||
fi
|
||||
|
||||
# Source and destination directories
|
||||
|
||||
@@ -186,6 +186,34 @@ function Register-Plugin {
|
||||
$Marketplaces | Add-Member -NotePropertyName "claude-mnemonic" -NotePropertyValue $MarketplaceEntry -Force
|
||||
$Marketplaces | ConvertTo-Json -Depth 10 | Out-File -Encoding UTF8 $MarketplacesFile
|
||||
Write-Success "Marketplace registered in known_marketplaces.json"
|
||||
|
||||
# Register MCP server in settings.json
|
||||
$McpBinary = Join-Path $InstallDir "mcp-server.exe"
|
||||
if (Test-Path $McpBinary) {
|
||||
Write-Info "Registering MCP server in settings.json..."
|
||||
|
||||
# Reload settings to include any previous updates
|
||||
$Settings = Get-Content $SettingsFile -Raw | ConvertFrom-Json
|
||||
|
||||
# Ensure mcpServers object exists
|
||||
if (-not $Settings.mcpServers) {
|
||||
$Settings | Add-Member -NotePropertyName "mcpServers" -NotePropertyValue @{} -Force
|
||||
}
|
||||
|
||||
# Add MCP server entry
|
||||
$McpEntry = @{
|
||||
command = $McpBinary
|
||||
args = @("--project", "`${CLAUDE_PROJECT}")
|
||||
env = @{}
|
||||
}
|
||||
|
||||
$Settings.mcpServers | Add-Member -NotePropertyName "claude-mnemonic" -NotePropertyValue $McpEntry -Force
|
||||
|
||||
$Settings | ConvertTo-Json -Depth 10 | Out-File -Encoding UTF8 $SettingsFile
|
||||
Write-Success "MCP server registered successfully"
|
||||
} else {
|
||||
Write-Warn "MCP server binary not found at $McpBinary, skipping MCP registration"
|
||||
}
|
||||
} catch {
|
||||
Write-Warn "Plugin registration encountered an error: $_"
|
||||
}
|
||||
@@ -282,6 +310,10 @@ function Uninstall-ClaudeMnemonic {
|
||||
if ($Settings.statusLine -and $Settings.statusLine.command -match "claude-mnemonic") {
|
||||
$Settings.PSObject.Properties.Remove("statusLine")
|
||||
}
|
||||
# Remove MCP server entry
|
||||
if ($Settings.mcpServers) {
|
||||
$Settings.mcpServers.PSObject.Properties.Remove("claude-mnemonic")
|
||||
}
|
||||
$Settings | ConvertTo-Json -Depth 10 | Out-File -Encoding UTF8 $SettingsFile
|
||||
}
|
||||
if (Test-Path $MarketplacesFile) {
|
||||
|
||||
+35
-2
@@ -297,6 +297,37 @@ EOF
|
||||
&& mv "${MARKETPLACES_FILE}.tmp" "$MARKETPLACES_FILE"
|
||||
|
||||
success "Marketplace registered in known_marketplaces.json"
|
||||
|
||||
# Register MCP server in settings.json
|
||||
local mcp_binary="$INSTALL_DIR/mcp-server"
|
||||
if [[ -f "$mcp_binary" ]]; then
|
||||
info "Registering MCP server in settings.json..."
|
||||
|
||||
# MCP server entry - note the escaped ${CLAUDE_PROJECT}
|
||||
local mcp_entry
|
||||
mcp_entry=$(cat <<'EOF'
|
||||
{
|
||||
"command": "MCP_BINARY_PLACEHOLDER",
|
||||
"args": ["--project", "${CLAUDE_PROJECT}"],
|
||||
"env": {}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
# Replace placeholder with actual path
|
||||
mcp_entry=$(echo "$mcp_entry" | sed "s|MCP_BINARY_PLACEHOLDER|$mcp_binary|g")
|
||||
|
||||
# Add or update mcpServers field
|
||||
if jq --arg key "claude-mnemonic" --argjson entry "$mcp_entry" \
|
||||
'.mcpServers //= {} | .mcpServers[$key] = $entry' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp"; then
|
||||
mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE"
|
||||
success "MCP server registered successfully"
|
||||
else
|
||||
warn "Failed to register MCP server (jq error)"
|
||||
rm -f "${SETTINGS_FILE}.tmp"
|
||||
fi
|
||||
else
|
||||
warn "MCP server binary not found at $mcp_binary, skipping MCP registration"
|
||||
fi
|
||||
}
|
||||
|
||||
# Start the worker service
|
||||
@@ -479,8 +510,10 @@ if [[ "${1:-}" == "--uninstall" ]]; then
|
||||
jq 'del(.plugins["'"$PLUGIN_KEY"'"])' "$PLUGINS_FILE" > "${PLUGINS_FILE}.tmp" && mv "${PLUGINS_FILE}.tmp" "$PLUGINS_FILE"
|
||||
fi
|
||||
if [[ -f "$SETTINGS_FILE" ]]; then
|
||||
# Remove plugin from enabled plugins and remove statusline if it's ours
|
||||
jq 'del(.enabledPlugins["'"$PLUGIN_KEY"'"]) | if .statusLine.command | test("claude-mnemonic") then del(.statusLine) else . end' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp" && mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE"
|
||||
# Remove plugin from enabled plugins, remove statusline if it's ours, and remove MCP server entry
|
||||
jq 'del(.enabledPlugins["'"$PLUGIN_KEY"'"]) |
|
||||
if .statusLine.command | test("claude-mnemonic") then del(.statusLine) else . end |
|
||||
del(.mcpServers["claude-mnemonic"])' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp" && mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE"
|
||||
fi
|
||||
if [[ -f "$MARKETPLACES_FILE" ]]; then
|
||||
jq 'del(.["claude-mnemonic"])' "$MARKETPLACES_FILE" > "${MARKETPLACES_FILE}.tmp" && mv "${MARKETPLACES_FILE}.tmp" "$MARKETPLACES_FILE"
|
||||
|
||||
@@ -107,6 +107,37 @@ EOF
|
||||
&& mv "${MARKETPLACES_FILE}.tmp" "$MARKETPLACES_FILE"
|
||||
|
||||
echo "Marketplace registered in known_marketplaces.json"
|
||||
|
||||
# Register MCP server in settings.json
|
||||
MCP_BINARY="$MARKETPLACE_PATH/mcp-server"
|
||||
if [ -f "$MCP_BINARY" ]; then
|
||||
echo "Registering MCP server in settings.json..."
|
||||
|
||||
# MCP server entry - note the escaped ${CLAUDE_PROJECT}
|
||||
MCP_ENTRY=$(cat <<'EOF'
|
||||
{
|
||||
"command": "MCP_BINARY_PLACEHOLDER",
|
||||
"args": ["--project", "${CLAUDE_PROJECT}"],
|
||||
"env": {}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
# Replace placeholder with actual path
|
||||
MCP_ENTRY=$(echo "$MCP_ENTRY" | sed "s|MCP_BINARY_PLACEHOLDER|$MCP_BINARY|g")
|
||||
|
||||
# Add or update mcpServers field
|
||||
if jq --arg key "claude-mnemonic" --argjson entry "$MCP_ENTRY" \
|
||||
'.mcpServers //= {} | .mcpServers[$key] = $entry' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp"; then
|
||||
mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE"
|
||||
echo "MCP server registered successfully"
|
||||
else
|
||||
echo "Warning: Failed to register MCP server (jq error)"
|
||||
rm -f "${SETTINGS_FILE}.tmp"
|
||||
fi
|
||||
else
|
||||
echo "MCP server binary not found at $MCP_BINARY, skipping MCP registration"
|
||||
fi
|
||||
|
||||
echo "Plugin registered successfully using jq"
|
||||
else
|
||||
echo "ERROR: jq is required for plugin registration"
|
||||
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "0ddacaa-dirty",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "0ddacaa-dirty",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"dependencies": {
|
||||
"vis-data": "^7.1.9",
|
||||
"vis-network": "^9.1.9",
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "0ddacaa-dirty",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
Reference in New Issue
Block a user