Release dec 2025 (#15)

* Resolves issue #13

- Switched model to bge-small-en-v1.5
- Added lazy re-embedding
- Added model version tracking per vector
- Added conversion of vectors to the new model
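
The lazy re-embedding and per-vector model tracking listed above could look roughly like the sketch below. This is not the code from this commit: the `Vector` type, the `Embedder` function type, and `EnsureCurrent` are hypothetical names chosen for illustration. Only the model name and the idea of storing a model version alongside each vector come from the message.

```go
package main

import "fmt"

// CurrentModel is the model version that up-to-date vectors are tagged with.
const CurrentModel = "bge-small-en-v1.5"

// Vector is a hypothetical stored embedding, tagged with the model that produced it.
type Vector struct {
	DocID        string
	Text         string
	Embedding    []float32
	ModelVersion string
}

// Embedder is a hypothetical function that turns text into an embedding.
type Embedder func(text string) []float32

// EnsureCurrent re-embeds v only when it was produced by a different model.
// It reports whether a re-embedding happened.
func EnsureCurrent(v *Vector, embed Embedder) bool {
	if v.ModelVersion == CurrentModel {
		return false
	}
	v.Embedding = embed(v.Text)
	v.ModelVersion = CurrentModel
	return true
}

func main() {
	// Stand-in embedder; real code would call the actual embedding model.
	fake := func(text string) []float32 { return []float32{float32(len(text))} }

	v := &Vector{DocID: "obs-1", Text: "fix auth", ModelVersion: "all-MiniLM-L6-v2"}
	fmt.Println(EnsureCurrent(v, fake), v.ModelVersion) // first access converts the vector
	fmt.Println(EnsureCurrent(v, fake), v.ModelVersion) // second access is a no-op
}
```

Doing the conversion on access, rather than in one batch, spreads the embedding cost over time at the price of mixed model versions in the store until every vector has been touched.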

* Add lfs support to the workflow.

* Implements importance scoring with decay + voting #6
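
As a rough illustration of decay plus voting, the sketch below combines a base weight with exponential recency decay and a thumbs-up/down vote. The formula, the half-life parameter, and the 0.5 vote factor are assumptions made for this example; the commit's actual scoring function is not shown in this section.

```go
package main

import (
	"fmt"
	"math"
)

// Score multiplies a base weight by an exponential recency decay and a vote factor.
// halfLifeDays and the 0.5 vote factor are illustrative assumptions.
// feedback is expected to be -1 (thumbs down), 0 (neutral), or 1 (thumbs up).
func Score(base, ageDays, halfLifeDays float64, feedback int) float64 {
	decay := math.Pow(0.5, ageDays/halfLifeDays)
	vote := 1.0 + 0.5*float64(feedback)
	return base * decay * vote
}

func main() {
	fmt.Printf("%.2f\n", Score(1.0, 0, 30, 1))   // fresh and upvoted
	fmt.Printf("%.2f\n", Score(1.0, 30, 30, 0))  // one half-life old, neutral
	fmt.Printf("%.2f\n", Score(1.0, 30, 30, -1)) // one half-life old, downvoted
}
```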

* Resolves issue #5 by marking observations as superseded and scheduling them for deletion
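
The retention rule can be sketched as a pure function. `SupersededRetentionDays = 3` matches the constant added in this commit; `DueForDeletion` and `msPerDay` are hypothetical helpers, and the sketch uses fixed 24-hour days instead of the calendar-based `AddDate` call in the real cleanup code.

```go
package main

import (
	"fmt"
	"time"
)

// SupersededRetentionDays mirrors the constant introduced in this commit.
const SupersededRetentionDays = 3

// msPerDay is the number of milliseconds in a fixed 24-hour day.
const msPerDay = int64(24 * time.Hour / time.Millisecond)

// DueForDeletion reports whether a conflict detected at detectedAtEpochMs
// is older than the retention window, measured from nowEpochMs.
func DueForDeletion(detectedAtEpochMs, nowEpochMs int64) bool {
	cutoff := nowEpochMs - SupersededRetentionDays*msPerDay
	return detectedAtEpochMs < cutoff
}

func main() {
	now := time.Now().UnixMilli()
	fmt.Println(DueForDeletion(now-4*msPerDay, now)) // superseded four days ago
	fmt.Println(DueForDeletion(now-1*msPerDay, now)) // superseded yesterday
}
```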

* Implement pattern detection #7

* Improve injections and observations accuracy

- Session start: Recent observations for project context (recency-based)
- User prompt: Semantically relevant observations (similarity-based with threshold)
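
The similarity-based path for user prompts amounts to a threshold filter followed by an optional cap. The sketch below assumes a hypothetical `Hit` type and `FilterHits` helper. The threshold and cap correspond to the `ContextRelevanceThreshold` and `ContextMaxPromptResults` settings added in this commit, where a cap of 0 means "threshold only".

```go
package main

import (
	"fmt"
	"sort"
)

// Hit is a hypothetical search result with a similarity score in [0, 1].
type Hit struct {
	ID         int64
	Similarity float64
}

// FilterHits keeps hits at or above threshold, orders them by descending
// similarity, and truncates to maxResults when maxResults > 0.
func FilterHits(hits []Hit, threshold float64, maxResults int) []Hit {
	kept := make([]Hit, 0, len(hits))
	for _, h := range hits {
		if h.Similarity >= threshold {
			kept = append(kept, h)
		}
	}
	sort.SliceStable(kept, func(i, j int) bool {
		return kept[i].Similarity > kept[j].Similarity
	})
	if maxResults > 0 && len(kept) > maxResults {
		kept = kept[:maxResults]
	}
	return kept
}

func main() {
	hits := []Hit{{1, 0.9}, {2, 0.2}, {3, 0.5}}
	fmt.Println(FilterHits(hits, 0.3, 10)) // drops the hit below the threshold
	fmt.Println(FilterHits(hits, 0.3, 1))  // additionally caps to one result
}
```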

* Added two stage retrieval with bi and cross encoder #8
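
The second stage can be sketched as a score blend followed by a re-sort. The blend `alpha*rerank + (1-alpha)*original` and the pure mode come from the settings documented in this commit; the `Candidate` type and the `Blend` and `Rerank` functions are hypothetical, and the cross-encoder scores are assumed to be precomputed.

```go
package main

import (
	"fmt"
	"sort"
)

// Candidate is a hypothetical first-stage result carrying both scores.
type Candidate struct {
	ID       string
	Original float64 // bi-encoder (embedding) similarity
	Rerank   float64 // cross-encoder relevance, assumed precomputed
	Final    float64
}

// Blend combines the two scores; in pure mode only the cross-encoder score counts.
func Blend(original, rerank, alpha float64, pure bool) float64 {
	if pure {
		return rerank
	}
	return alpha*rerank + (1-alpha)*original
}

// Rerank scores a copy of cands, sorts by descending final score,
// and keeps at most results entries when results > 0.
func Rerank(cands []Candidate, alpha float64, pure bool, results int) []Candidate {
	out := make([]Candidate, len(cands))
	copy(out, cands)
	for i := range out {
		out[i].Final = Blend(out[i].Original, out[i].Rerank, alpha, pure)
	}
	sort.SliceStable(out, func(i, j int) bool { return out[i].Final > out[j].Final })
	if results > 0 && len(out) > results {
		out = out[:results]
	}
	return out
}

func main() {
	cands := []Candidate{
		{ID: "a", Original: 0.9, Rerank: 0.1},
		{ID: "b", Original: 0.5, Rerank: 0.8},
	}
	for _, c := range Rerank(cands, 0.7, false, 10) {
		fmt.Printf("%s %.2f\n", c.ID, c.Final)
	}
}
```

With the default alpha of 0.7, a candidate that the cross-encoder rates highly can overtake one that only looked close in embedding space, which is the point of the second stage.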

* Implement query expansion and reformulation #9

* Knowledge graph and relationships (resolves #4)

- File Overlap Detection: Detects relationships when observations modify/read the same files
- Concept Overlap Detection: Detects relationships based on shared semantic concepts
- Type Progression Detection: Infers relationships from natural observation type progressions (e.g., discovery → bugfix = "fixes")
- Temporal Proximity Detection: Detects relationships between observations in the same session within 5 minutes
- Narrative Mention Detection: Detects explicit relationship language in narratives (e.g., "fixes", "depends on", "supersedes")
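
Two of the detectors above are simple enough to sketch. The 5-minute window and the discovery → bugfix = "fixes" rule come from the list; the `Observation` type, the function names, and the inclusive boundary are assumptions for illustration.

```go
package main

import "fmt"

// Observation is a hypothetical, trimmed-down observation record.
type Observation struct {
	ID          int64
	SessionID   string
	Type        string
	CreatedAtMs int64 // creation time in Unix milliseconds
}

// proximityWindowMs is the 5-minute window expressed in milliseconds.
const proximityWindowMs = 5 * 60 * 1000

// TemporallyClose reports whether a and b share a session and were
// created within the proximity window of each other (inclusive).
func TemporallyClose(a, b Observation) bool {
	if a.SessionID != b.SessionID {
		return false
	}
	delta := a.CreatedAtMs - b.CreatedAtMs
	if delta < 0 {
		delta = -delta
	}
	return delta <= proximityWindowMs
}

// progressions maps (earlier type, later type) to a relationship label.
// Only the discovery -> bugfix entry is taken from the commit message.
var progressions = map[[2]string]string{
	{"discovery", "bugfix"}: "fixes",
}

// ProgressionRelation returns the inferred relationship, if any.
func ProgressionRelation(earlier, later Observation) (string, bool) {
	rel, ok := progressions[[2]string{earlier.Type, later.Type}]
	return rel, ok
}

func main() {
	a := Observation{ID: 1, SessionID: "s1", Type: "discovery", CreatedAtMs: 0}
	b := Observation{ID: 2, SessionID: "s1", Type: "bugfix", CreatedAtMs: 120000}
	fmt.Println(TemporallyClose(a, b))
	fmt.Println(ProgressionRelation(a, b))
}
```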

* Add visualisation of the relations to the dashboard.

* fixup! Add visualisation of the relations to the dashboard.

* Update documentation with new settings and screenshots.
2025-12-19 17:57:11 +00:00
committed by GitHub
parent 48957a6c81
commit f79782a008
69 changed files with 43967 additions and 194 deletions
+1
@@ -20,4 +20,5 @@ jobs:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: ">=1.24"
lfs: true
secrets: inherit
+44 -5
@@ -12,7 +12,21 @@ Claude Code forgets everything when your session ends. Claude Mnemonic fixes tha
 It captures what Claude learns during your coding sessions - bug fixes, architecture decisions, patterns that work - and brings that knowledge back in future conversations. No more re-explaining your codebase.
-## What's New in v0.6
+![Claude Mnemonic Dashboard](docs/public/claude-mnemonic.jpg)
+## What's New in v0.7
+- **Two-Stage Retrieval** - Cross-encoder reranking for dramatically improved search relevance
+- **Knowledge Graph** - Automatic relationship detection between observations with visual graph in dashboard
+![Knowledge Graph](docs/public/observation-relation-graph.jpg)
+- **Pattern Detection** - Identifies recurring patterns across sessions and projects
+- **Importance Scoring** - Time decay and voting system to surface the most valuable memories
+- **Query Expansion** - Reformulates searches to find semantically related content
+- **Conflict Detection** - Identifies and resolves contradictory observations
+- **Observation Lifecycle** - Memories can be superseded when better information arrives
+<details>
+<summary>Previous: v0.6</summary>
 - **Auto-Updates** - Automatically stays up-to-date with the latest version
 - **Slash Command: `/restart`** - Restart the worker directly from Claude Code
@@ -20,6 +34,7 @@ It captures what Claude learns during your coding sessions - bug fixes, architec
 - **Async Queue Processing** - Non-blocking observation capture for faster sessions
 - **Smarter Storage** - Filters out system/agent summaries to keep knowledge relevant
 - **Improved Reliability** - Better handling of connectivity issues and dead connections
+</details>
 ## Requirements
@@ -110,16 +125,39 @@ Config file: `~/.claude-mnemonic/settings.json`
 {
 "CLAUDE_MNEMONIC_WORKER_PORT": 37777,
 "CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
-"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25
+"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25,
+"CLAUDE_MNEMONIC_RERANKING_ENABLED": true
 }
 ```
+### Core Settings
 | Variable | Default | What it does |
 |----------|---------|--------------|
 | `WORKER_PORT` | `37777` | Dashboard & API port |
 | `CONTEXT_OBSERVATIONS` | `100` | Max memories per session |
-| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed to title only) |
+| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed) |
 | `CONTEXT_SESSION_COUNT` | `10` | Recent sessions to reference |
+| `CONTEXT_RELEVANCE_THRESHOLD` | `0.3` | Minimum similarity score (0.0-1.0) for inclusion |
+| `CONTEXT_MAX_PROMPT_RESULTS` | `10` | Max results per prompt search |
+### Reranking Settings (Two-Stage Retrieval)
+| Variable | Default | What it does |
+|----------|---------|--------------|
+| `RERANKING_ENABLED` | `true` | Enable cross-encoder reranking |
+| `RERANKING_CANDIDATES` | `100` | Candidates to retrieve before reranking |
+| `RERANKING_RESULTS` | `10` | Final results after reranking |
+| `RERANKING_ALPHA` | `0.7` | Score blend: alpha×rerank + (1-alpha)×original |
+| `RERANKING_PURE_MODE` | `false` | Use pure cross-encoder scores only |
+### Embedding Settings
+| Variable | Default | What it does |
+|----------|---------|--------------|
+| `EMBEDDING_MODEL` | `bge-v1.5` | Embedding model for semantic search |
+All variables are prefixed with `CLAUDE_MNEMONIC_` in the config file.
 ## Project vs Global scope
@@ -202,10 +240,11 @@ curl -sSL https://raw.githubusercontent.com/lukaszraczylo/claude-mnemonic/main/s
 - **SQLite + FTS5** - Full-text search for exact matches
 - **sqlite-vec** - Vector database embedded in SQLite
-- **all-MiniLM-L6-v2** - Local embedding model (384 dimensions) via ONNX Runtime
+- **Two-Stage Retrieval** - Bi-encoder (embedding) + cross-encoder (reranking) for high accuracy
+- **Local Models** - all-MiniLM-L6-v2 for embeddings, BGE reranker for relevance scoring
 - **Go** - Single binary, no external dependencies
-Everything runs locally. No Python. No external vector database. No API calls for embeddings.
+Everything runs locally. No Python. No external vector database. No API calls.
 ## Platform support
+2 -6
@@ -61,12 +61,8 @@ func main() {
 searchResult, _ := hooks.GET(port, searchURL)
 if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
-// Limit to top 5 most relevant observations
-maxObs := 5
-if len(observations) < maxObs {
-maxObs = len(observations)
-}
-observations = observations[:maxObs]
+// Results are already filtered by relevance threshold and capped by max_results
+// from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults)
 observationCount = len(observations)
 // Build context from search results
Binary file not shown (image, 252 KiB).
Binary file not shown (image, 123 KiB).
+37 -7
@@ -29,6 +29,21 @@
subtitle="Stop re-explaining your codebase. Claude Mnemonic captures bug fixes, architecture decisions, and coding patterns - then brings them back exactly when you need them."
/>
<!-- Dashboard Preview -->
<section class="py-12 lg:py-16 px-4 sm:px-6 relative">
<div class="max-w-6xl mx-auto">
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-amber-500/5">
<div class="absolute inset-0 bg-gradient-to-t from-slate-950 via-transparent to-transparent pointer-events-none z-10"></div>
<img
src="/claude-mnemonic.jpg"
alt="Claude Mnemonic Dashboard"
class="w-full h-auto"
/>
</div>
<p class="text-center text-slate-500 text-sm mt-4">The dashboard at localhost:37777 - browse, search, and manage your memories</p>
</div>
</section>
<!-- Problem Section -->
<section class="py-20 lg:py-28 px-4 sm:px-6 relative">
<div class="max-w-6xl mx-auto grid lg:grid-cols-2 gap-8 lg:gap-16 items-center">
@@ -85,6 +100,21 @@
</div>
</section>
<!-- Knowledge Graph Preview -->
<section class="py-16 lg:py-20 px-4 sm:px-6 relative">
<div class="max-w-5xl mx-auto">
<SectionHeader title="See how knowledge connects" subtitle="The knowledge graph reveals relationships between your memories automatically" />
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-purple-500/5">
<img
src="/observation-relation-graph.jpg"
alt="Knowledge Graph Visualization"
class="w-full h-auto"
/>
</div>
<p class="text-center text-slate-500 text-sm mt-4">Click any observation to explore its relationships - see what causes what, what fixes what, and how concepts evolve</p>
</div>
</section>
<!-- Before/After Section -->
<section class="py-20 lg:py-28 px-4 sm:px-6">
<div class="max-w-6xl mx-auto">
@@ -288,8 +318,8 @@
 <p class="text-slate-400 text-xs sm:text-sm">Embedded vector database. No external services required.</p>
 </div>
 <div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
-<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">MiniLM</div>
-<p class="text-slate-400 text-xs sm:text-sm">Local embeddings via ONNX. "Fix auth" finds "JWT issue" automatically.</p>
+<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">BGE</div>
+<p class="text-slate-400 text-xs sm:text-sm">Two-stage retrieval: bi-encoder embeddings + cross-encoder reranking for high accuracy.</p>
 </div>
 </div>
 </div>
@@ -386,10 +416,10 @@ const activeTab = ref('macos')
 const features = [
 { icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
+{ icon: 'fas fa-search', title: 'Two-stage retrieval', description: 'Cross-encoder reranking delivers highly relevant results. Finds what you need even with vague queries like "that auth thing".' },
+{ icon: 'fas fa-project-diagram', title: 'Knowledge graph', description: 'Automatically discovers relationships between memories. See how concepts connect in the visual graph dashboard.' },
 { icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
-{ icon: 'fas fa-globe', title: 'Shared best practices', description: 'Security patterns, performance tips, and universal learnings automatically available across all your projects.' },
-{ icon: 'fas fa-search', title: 'Finds what matters', description: 'Semantic search finds relevant memories even when you don\'t remember the exact words. "That auth thing" just works.' },
-{ icon: 'fas fa-chart-line', title: 'Live statusline', description: 'Real-time metrics right in Claude Code: memories served, searches performed, and project memory count at a glance.' },
+{ icon: 'fas fa-chart-line', title: 'Smart scoring', description: 'Importance decay, pattern detection, and conflict resolution ensure the most valuable memories surface first.' },
 { icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
 ]
@@ -415,8 +445,8 @@ const installCommands = {
 const configOptions = [
 { name: 'CLAUDE_MNEMONIC_WORKER_PORT', description: 'HTTP port for the worker service (default: 37777)', icon: 'fas fa-network-wired' },
 { name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
-{ name: 'CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT', description: 'Observations with full narrative detail, rest are condensed (default: 25)', icon: 'fas fa-expand' },
-{ name: 'CLAUDE_MNEMONIC_MODEL', description: 'Model for processing observations (default: haiku)', icon: 'fas fa-microchip' },
+{ name: 'CLAUDE_MNEMONIC_RERANKING_ENABLED', description: 'Enable cross-encoder reranking for improved search relevance (default: true)', icon: 'fas fa-sort-amount-down' },
+{ name: 'CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD', description: 'Minimum similarity score for inclusion, 0.0-1.0 (default: 0.3)', icon: 'fas fa-filter' },
 ]
 const requiredDeps = [
+74 -22
@@ -48,16 +48,29 @@ type Config struct {
 Model string `json:"model"`
 ClaudeCodePath string `json:"claude_code_path"`
+// Embedding settings
+EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
+// Reranking settings (cross-encoder)
+RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking
+RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100)
+RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10)
+RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7)
+RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank)
+RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false)
 // Context injection settings
 ContextObservations int `json:"context_observations"`
 ContextFullCount int `json:"context_full_count"`
 ContextSessionCount int `json:"context_session_count"`
 ContextShowReadTokens bool `json:"context_show_read_tokens"`
 ContextShowWorkTokens bool `json:"context_show_work_tokens"`
 ContextFullField string `json:"context_full_field"`
 ContextShowLastSummary bool `json:"context_show_last_summary"`
 ContextObsTypes []string `json:"context_obs_types"`
 ContextObsConcepts []string `json:"context_obs_concepts"`
+ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
+ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
 }
 var (
@@ -119,22 +132,33 @@ func EnsureAll() error {
 return nil
 }
+// DefaultEmbeddingModel is the default embedding model to use.
+const DefaultEmbeddingModel = "bge-v1.5"
 // Default returns a Config with default values.
 func Default() *Config {
 return &Config{
 WorkerPort: DefaultWorkerPort,
 DBPath: DBPath(),
 MaxConns: 4,
 Model: DefaultModel,
+EmbeddingModel: DefaultEmbeddingModel,
+RerankingEnabled: true, // Enable by default for improved relevance
+RerankingCandidates: 100, // Retrieve top 100 candidates
+RerankingResults: 10, // Return top 10 after reranking
+RerankingAlpha: 0.7, // Favor cross-encoder score
+RerankingMinImprovement: 0, // Always apply reranking
 ContextObservations: 100,
 ContextFullCount: 25,
 ContextSessionCount: 10,
 ContextShowReadTokens: true,
 ContextShowWorkTokens: true,
 ContextFullField: "narrative",
 ContextShowLastSummary: true,
 ContextObsTypes: DefaultObservationTypes,
 ContextObsConcepts: DefaultObservationConcepts,
+ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
+ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
 }
 }
@@ -166,6 +190,28 @@ func Load() (*Config, error) {
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
cfg.ClaudeCodePath = v
}
if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" {
cfg.EmbeddingModel = v
}
// Reranking settings
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ENABLED"].(bool); ok {
cfg.RerankingEnabled = v
}
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_CANDIDATES"].(float64); ok && v > 0 {
cfg.RerankingCandidates = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_RESULTS"].(float64); ok && v > 0 {
cfg.RerankingResults = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ALPHA"].(float64); ok && v >= 0 && v <= 1 {
cfg.RerankingAlpha = v
}
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_MIN_IMPROVEMENT"].(float64); ok && v >= 0 {
cfg.RerankingMinImprovement = v
}
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_PURE_MODE"].(bool); ok {
cfg.RerankingPureMode = v
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
cfg.ContextObservations = int(v)
}
@@ -181,6 +227,12 @@ func Load() (*Config, error) {
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS"].(string); ok && v != "" {
cfg.ContextObsConcepts = splitTrim(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD"].(float64); ok && v >= 0 && v <= 1 {
cfg.ContextRelevanceThreshold = v
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 {
cfg.ContextMaxPromptResults = int(v)
}
return cfg, nil
}
+276
@@ -0,0 +1,276 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// SupersededRetentionDays is the number of days to keep superseded observations before deletion.
const SupersededRetentionDays = 3
// ConflictStore provides conflict-related database operations.
type ConflictStore struct {
store *Store
}
// NewConflictStore creates a new conflict store.
func NewConflictStore(store *Store) *ConflictStore {
return &ConflictStore{store: store}
}
// StoreConflict stores a new observation conflict.
func (s *ConflictStore) StoreConflict(ctx context.Context, conflict *models.ObservationConflict) (int64, error) {
const query = `
INSERT INTO observation_conflicts
(newer_obs_id, older_obs_id, conflict_type, resolution, reason, detected_at, detected_at_epoch, resolved, resolved_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
conflict.NewerObsID, conflict.OlderObsID,
string(conflict.ConflictType), string(conflict.Resolution),
conflict.Reason, conflict.DetectedAt, conflict.DetectedAtEpoch,
conflict.Resolved, conflict.ResolvedAt,
)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
// MarkObservationSuperseded marks an observation as superseded.
func (s *ConflictStore) MarkObservationSuperseded(ctx context.Context, obsID int64) error {
const query = `UPDATE observations SET is_superseded = 1 WHERE id = ?`
_, err := s.store.ExecContext(ctx, query, obsID)
return err
}
// MarkObservationsSuperseded marks multiple observations as superseded.
func (s *ConflictStore) MarkObservationsSuperseded(ctx context.Context, obsIDs []int64) error {
if len(obsIDs) == 0 {
return nil
}
query := `UPDATE observations SET is_superseded = 1 WHERE id IN (?` + repeatPlaceholders(len(obsIDs)-1) + `)` // #nosec G202 -- uses parameterized placeholders
args := int64SliceToInterface(obsIDs)
_, err := s.store.db.ExecContext(ctx, query, args...)
return err
}
// GetConflictsByObservationID retrieves all conflicts involving an observation.
func (s *ConflictStore) GetConflictsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationConflict, error) {
const query = `
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
detected_at, detected_at_epoch, resolved, resolved_at
FROM observation_conflicts
WHERE newer_obs_id = ? OR older_obs_id = ?
ORDER BY detected_at_epoch DESC
`
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanConflictRows(rows)
}
// GetUnresolvedConflicts retrieves all unresolved conflicts.
func (s *ConflictStore) GetUnresolvedConflicts(ctx context.Context, limit int) ([]*models.ObservationConflict, error) {
const query = `
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
detected_at, detected_at_epoch, resolved, resolved_at
FROM observation_conflicts
WHERE resolved = 0
ORDER BY detected_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanConflictRows(rows)
}
// GetSupersededObservationIDs returns IDs of all observations that have been superseded.
func (s *ConflictStore) GetSupersededObservationIDs(ctx context.Context, project string) ([]int64, error) {
const query = `
SELECT DISTINCT older_obs_id
FROM observation_conflicts oc
JOIN observations o ON o.id = oc.older_obs_id
WHERE oc.resolution = 'prefer_newer'
AND (o.project = ? OR o.scope = 'global')
`
rows, err := s.store.QueryContext(ctx, query, project)
if err != nil {
return nil, err
}
defer rows.Close()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, rows.Err()
}
// ResolveConflict marks a conflict as resolved.
func (s *ConflictStore) ResolveConflict(ctx context.Context, conflictID int64, resolution models.ConflictResolution) error {
now := time.Now().Format(time.RFC3339)
const query = `
UPDATE observation_conflicts
SET resolved = 1, resolved_at = ?, resolution = ?
WHERE id = ?
`
_, err := s.store.ExecContext(ctx, query, now, string(resolution), conflictID)
return err
}
// DeleteConflictsByObservationID deletes all conflicts involving an observation.
// Called when an observation is deleted.
func (s *ConflictStore) DeleteConflictsByObservationID(ctx context.Context, obsID int64) error {
const query = `DELETE FROM observation_conflicts WHERE newer_obs_id = ? OR older_obs_id = ?`
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
return err
}
// ConflictWithDetails contains a conflict with its observation details.
type ConflictWithDetails struct {
Conflict *models.ObservationConflict
NewerObsTitle string
OlderObsTitle string
}
// CleanupSupersededObservations deletes observations that have been superseded for longer than
// SupersededRetentionDays. Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, project string) ([]int64, error) {
// Calculate cutoff time (3 days ago in milliseconds)
cutoffEpoch := time.Now().AddDate(0, 0, -SupersededRetentionDays).UnixMilli()
// First, find the IDs that will be deleted
// We delete observations that:
// 1. Are marked as superseded
// 2. Have a conflict record where they are the older observation
// 3. The conflict was detected more than 3 days ago
const selectQuery = `
SELECT DISTINCT o.id FROM observations o
JOIN observation_conflicts oc ON o.id = oc.older_obs_id
WHERE o.is_superseded = 1
AND o.project = ?
AND oc.detected_at_epoch < ?
`
rows, err := s.store.QueryContext(ctx, selectQuery, project, cutoffEpoch)
if err != nil {
return nil, err
}
defer rows.Close()
var toDelete []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
toDelete = append(toDelete, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(toDelete) == 0 {
return nil, nil
}
// Delete the conflict records first (due to foreign key constraints)
for _, obsID := range toDelete {
if err := s.DeleteConflictsByObservationID(ctx, obsID); err != nil {
return nil, err
}
}
// Delete the observations
deleteQuery := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(toDelete)-1) + `)` // #nosec G202 -- uses parameterized placeholders
args := int64SliceToInterface(toDelete)
_, err = s.store.db.ExecContext(ctx, deleteQuery, args...)
if err != nil {
return nil, err
}
return toDelete, nil
}
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
const query = `
SELECT oc.id, oc.newer_obs_id, oc.older_obs_id, oc.conflict_type, oc.resolution, oc.reason,
oc.detected_at, oc.detected_at_epoch, oc.resolved, oc.resolved_at,
COALESCE(newer.title, '') as newer_title,
COALESCE(older.title, '') as older_title
FROM observation_conflicts oc
JOIN observations newer ON newer.id = oc.newer_obs_id
JOIN observations older ON older.id = oc.older_obs_id
WHERE newer.project = ? OR older.project = ?
ORDER BY oc.detected_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var results []*ConflictWithDetails
for rows.Next() {
var c models.ObservationConflict
var cwd ConflictWithDetails
if err := rows.Scan(
&c.ID, &c.NewerObsID, &c.OlderObsID,
&c.ConflictType, &c.Resolution, &c.Reason,
&c.DetectedAt, &c.DetectedAtEpoch,
&c.Resolved, &c.ResolvedAt,
&cwd.NewerObsTitle, &cwd.OlderObsTitle,
); err != nil {
return nil, err
}
cwd.Conflict = &c
results = append(results, &cwd)
}
return results, rows.Err()
}
// scanConflictRows scans multiple conflicts from rows.
func (s *ConflictStore) scanConflictRows(rows interface {
Next() bool
Scan(...interface{}) error
Err() error
}) ([]*models.ObservationConflict, error) {
var conflicts []*models.ObservationConflict
for rows.Next() {
var c models.ObservationConflict
if err := rows.Scan(
&c.ID, &c.NewerObsID, &c.OlderObsID,
&c.ConflictType, &c.Resolution, &c.Reason,
&c.DetectedAt, &c.DetectedAtEpoch,
&c.Resolved, &c.ResolvedAt,
); err != nil {
return nil, err
}
conflicts = append(conflicts, &c)
}
return conflicts, rows.Err()
}
+207
@@ -283,6 +283,213 @@ var Migrations = []Migration{
ON user_prompts(claude_session_id, prompt_number);
`,
},
{
Version: 19,
Name: "vectors_with_model_version",
SQL: `
-- Drop old vectors table (virtual tables cannot be altered)
DROP TABLE IF EXISTS vectors;
-- Recreate vectors table with model_version column
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
doc_id TEXT PRIMARY KEY,
embedding float[384],
sqlite_id INTEGER,
doc_type TEXT,
field_type TEXT,
project TEXT,
scope TEXT,
model_version TEXT
);
`,
},
{
Version: 20,
Name: "importance_scoring",
SQL: `
-- Importance scoring system for observations
-- Implements multi-factor scoring: type weight, recency decay, user feedback, concept weights, retrieval boost
-- Cached importance score (recalculated periodically)
ALTER TABLE observations ADD COLUMN importance_score REAL DEFAULT 1.0;
-- User feedback: -1 = thumbs down, 0 = neutral, 1 = thumbs up
ALTER TABLE observations ADD COLUMN user_feedback INTEGER DEFAULT 0;
-- Retrieval tracking: how many times this observation was returned in searches
ALTER TABLE observations ADD COLUMN retrieval_count INTEGER DEFAULT 0;
-- Last time this observation was retrieved (for analytics)
ALTER TABLE observations ADD COLUMN last_retrieved_at_epoch INTEGER;
-- Timestamp of last score recalculation
ALTER TABLE observations ADD COLUMN score_updated_at_epoch INTEGER;
-- Index for importance-based sorting (primary ordering strategy)
CREATE INDEX IF NOT EXISTS idx_observations_importance
ON observations(importance_score DESC, created_at_epoch DESC);
-- Index for finding observations needing score recalculation
CREATE INDEX IF NOT EXISTS idx_observations_score_updated
ON observations(score_updated_at_epoch);
-- Configurable concept weights table
-- Allows runtime tuning of how much each concept contributes to importance
CREATE TABLE IF NOT EXISTS concept_weights (
concept TEXT PRIMARY KEY,
weight REAL NOT NULL DEFAULT 0.1,
updated_at TEXT NOT NULL
);
-- Seed default concept weights (security highest, tooling lowest)
INSERT OR IGNORE INTO concept_weights (concept, weight, updated_at) VALUES
('security', 0.30, datetime('now')),
('gotcha', 0.25, datetime('now')),
('best-practice', 0.20, datetime('now')),
('anti-pattern', 0.20, datetime('now')),
('architecture', 0.15, datetime('now')),
('performance', 0.15, datetime('now')),
('error-handling', 0.15, datetime('now')),
('pattern', 0.10, datetime('now')),
('testing', 0.10, datetime('now')),
('debugging', 0.10, datetime('now')),
('workflow', 0.05, datetime('now')),
('tooling', 0.05, datetime('now'));
`,
},
{
Version: 21,
Name: "observation_conflicts",
SQL: `
-- Observation conflicts table for tracking contradictions and superseded observations
-- Implements Issue #5: Contradiction & Obsolescence Detection
CREATE TABLE IF NOT EXISTS observation_conflicts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
newer_obs_id INTEGER NOT NULL,
older_obs_id INTEGER NOT NULL,
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
reason TEXT,
detected_at TEXT NOT NULL,
detected_at_epoch INTEGER NOT NULL,
resolved INTEGER DEFAULT 0,
resolved_at TEXT,
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
);
-- Index for looking up conflicts by observation ID
CREATE INDEX IF NOT EXISTS idx_conflicts_newer ON observation_conflicts(newer_obs_id);
CREATE INDEX IF NOT EXISTS idx_conflicts_older ON observation_conflicts(older_obs_id);
CREATE INDEX IF NOT EXISTS idx_conflicts_unresolved ON observation_conflicts(resolved, detected_at_epoch DESC);
-- Add is_superseded column to observations for quick filtering
-- Set to 1 when this observation has been superseded by a newer one
ALTER TABLE observations ADD COLUMN is_superseded INTEGER DEFAULT 0;
-- Index for filtering out superseded observations in queries
CREATE INDEX IF NOT EXISTS idx_observations_superseded ON observations(is_superseded, importance_score DESC);
`,
},
{
Version: 22,
Name: "patterns_table",
SQL: `
-- Pattern Recognition Engine (Issue #7)
-- Tracks recurring patterns detected across observations
-- Enables Claude to reference historical insights: "I've encountered this pattern 12 times."
CREATE TABLE IF NOT EXISTS patterns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
description TEXT,
signature TEXT, -- JSON array of keywords/concepts for detection
recommendation TEXT, -- What works for this pattern
frequency INTEGER DEFAULT 1, -- How many times encountered
projects TEXT, -- JSON array of projects where seen
observation_ids TEXT, -- JSON array of source observation IDs
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
merged_into_id INTEGER, -- If status is 'merged', which pattern it merged into
confidence REAL DEFAULT 0.5, -- Detection confidence (0.0-1.0)
last_seen_at TEXT NOT NULL,
last_seen_at_epoch INTEGER NOT NULL,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
);
-- Indexes for efficient pattern queries
CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(type);
CREATE INDEX IF NOT EXISTS idx_patterns_status ON patterns(status);
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency DESC);
CREATE INDEX IF NOT EXISTS idx_patterns_confidence ON patterns(confidence DESC);
CREATE INDEX IF NOT EXISTS idx_patterns_last_seen ON patterns(last_seen_at_epoch DESC);
-- FTS5 virtual table for pattern search
CREATE VIRTUAL TABLE IF NOT EXISTS patterns_fts USING fts5(
name, description, recommendation,
content='patterns',
content_rowid='id'
);
-- Triggers for FTS5 sync
CREATE TRIGGER IF NOT EXISTS patterns_ai AFTER INSERT ON patterns BEGIN
INSERT INTO patterns_fts(rowid, name, description, recommendation)
VALUES (new.id, new.name, new.description, new.recommendation);
END;
CREATE TRIGGER IF NOT EXISTS patterns_ad AFTER DELETE ON patterns BEGIN
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
VALUES('delete', old.id, old.name, old.description, old.recommendation);
END;
CREATE TRIGGER IF NOT EXISTS patterns_au AFTER UPDATE ON patterns BEGIN
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
VALUES('delete', old.id, old.name, old.description, old.recommendation);
INSERT INTO patterns_fts(rowid, name, description, recommendation)
VALUES (new.id, new.name, new.description, new.recommendation);
END;
`,
},
{
Version: 23,
Name: "observation_relations",
SQL: `
-- Knowledge Graph: Observation Relations (Issue #4)
-- Tracks explicit relationships between observations for knowledge graph navigation.
-- Enables queries like "What caused this bug?" or "What depends on this decision?"
CREATE TABLE IF NOT EXISTS observation_relations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
relation_type TEXT NOT NULL CHECK(relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from')),
confidence REAL NOT NULL DEFAULT 0.5,
detection_source TEXT NOT NULL CHECK(detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression')),
reason TEXT,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(source_id) REFERENCES observations(id) ON DELETE CASCADE,
FOREIGN KEY(target_id) REFERENCES observations(id) ON DELETE CASCADE,
UNIQUE(source_id, target_id, relation_type)
);
-- Index for finding relations by source observation
CREATE INDEX IF NOT EXISTS idx_relations_source ON observation_relations(source_id);
-- Index for finding relations by target observation
CREATE INDEX IF NOT EXISTS idx_relations_target ON observation_relations(target_id);
-- Index for relation type queries
CREATE INDEX IF NOT EXISTS idx_relations_type ON observation_relations(relation_type);
-- Index for confidence-based filtering
CREATE INDEX IF NOT EXISTS idx_relations_confidence ON observation_relations(confidence DESC);
-- Composite index for looking up relations between a specific source/target pair
CREATE INDEX IF NOT EXISTS idx_relations_both ON observation_relations(source_id, target_id);
`,
},
}
// MigrationManager handles database schema migrations.
+229 -40
@@ -11,14 +11,27 @@ import (
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// observationColumns is the standard list of columns to select for observations.
// This ensures consistency across all observation queries and includes importance scoring fields.
const observationColumns = `id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
title, subtitle, facts, narrative, concepts, files_read, files_modified, file_mtimes,
prompt_number, discovery_tokens, created_at, created_at_epoch,
COALESCE(importance_score, 1.0) as importance_score,
COALESCE(user_feedback, 0) as user_feedback,
COALESCE(retrieval_count, 0) as retrieval_count,
last_retrieved_at_epoch, score_updated_at_epoch,
COALESCE(is_superseded, 0) as is_superseded`
// CleanupFunc is a callback for when observations are cleaned up.
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
// ObservationStore provides observation-related database operations.
type ObservationStore struct {
store *Store
cleanupFunc CleanupFunc
conflictStore *ConflictStore
relationStore *RelationStore
}
// NewObservationStore creates a new observation store.
@@ -31,6 +44,16 @@ func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
s.cleanupFunc = fn
}
// SetConflictStore sets the conflict store for conflict detection.
func (s *ObservationStore) SetConflictStore(conflictStore *ConflictStore) {
s.conflictStore = conflictStore
}
// SetRelationStore sets the relation store for relationship detection.
func (s *ObservationStore) SetRelationStore(relationStore *RelationStore) {
s.relationStore = relationStore
}
// StoreObservation stores a new observation.
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
now := time.Now()
@@ -86,9 +109,112 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
}(project)
}
// Detect conflicts with existing observations (async to not block handler)
if s.conflictStore != nil && project != "" {
go func(newObsID int64, proj string, parsedObs *models.ParsedObservation) {
conflictCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
s.detectAndStoreConflicts(conflictCtx, newObsID, proj, parsedObs)
}(id, project, obs)
}
// Detect relationships with existing observations (async to not block handler)
if s.relationStore != nil && project != "" {
go func(newObsID int64, proj string) {
relationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
s.detectAndStoreRelations(relationCtx, newObsID, proj)
}(id, project)
}
return id, nowEpoch, nil
}
// detectAndStoreConflicts detects conflicts between a new observation and existing ones.
func (s *ObservationStore) detectAndStoreConflicts(ctx context.Context, newObsID int64, project string, parsedObs *models.ParsedObservation) {
// Fetch the newly stored observation
newObs, err := s.GetObservationByID(ctx, newObsID)
if err != nil || newObs == nil {
return
}
// Fetch up to 50 observations for the project (plus global ones) to check for conflicts.
// GetRecentObservations ranks by importance first, then recency.
existing, err := s.GetRecentObservations(ctx, project, 50)
if err != nil {
return
}
// Detect conflicts
conflicts := models.DetectConflictsWithExisting(newObs, existing)
// Store conflicts and mark superseded observations
var supersededIDs []int64
for _, result := range conflicts {
for _, olderID := range result.OlderObsIDs {
conflict := models.NewObservationConflict(
newObsID, olderID,
result.Type, result.Resolution, result.Reason,
)
if _, err := s.conflictStore.StoreConflict(ctx, conflict); err != nil {
continue
}
// If resolution is to prefer newer, mark older as superseded
if result.Resolution == models.ResolutionPreferNewer {
supersededIDs = append(supersededIDs, olderID)
}
}
}
// Mark superseded observations
if len(supersededIDs) > 0 {
_ = s.conflictStore.MarkObservationsSuperseded(ctx, supersededIDs)
}
// Cleanup old superseded observations (older than 3 days)
deletedIDs, _ := s.conflictStore.CleanupSupersededObservations(ctx, project)
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
s.cleanupFunc(ctx, deletedIDs)
}
}
// MinRelationConfidence is the minimum confidence threshold for storing relations.
const MinRelationConfidence = 0.4
// detectAndStoreRelations detects relationships between a new observation and existing ones.
func (s *ObservationStore) detectAndStoreRelations(ctx context.Context, newObsID int64, project string) {
// Fetch the newly stored observation
newObs, err := s.GetObservationByID(ctx, newObsID)
if err != nil || newObs == nil {
return
}
// Fetch up to 50 observations for the project (plus global ones) to check for relations.
// GetRecentObservations ranks by importance first, then recency.
existing, err := s.GetRecentObservations(ctx, project, 50)
if err != nil {
return
}
// Detect relationships using the models package detection logic
results := models.DetectRelationsWithExisting(newObs, existing, MinRelationConfidence)
if len(results) == 0 {
return
}
// Convert detection results to relation objects
relations := make([]*models.ObservationRelation, len(results))
for i, r := range results {
relations[i] = models.NewObservationRelation(
r.SourceID, r.TargetID,
r.RelationType, r.Confidence,
r.DetectionSource, r.Reason,
)
}
// Store all relations
_ = s.relationStore.StoreRelations(ctx, relations)
}
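`MinRelationConfidence` is passed to the detector as a cutoff. The sketch below shows one plausible reading of that threshold: keep only candidates whose confidence is at least the cutoff. The `relation` type and `filterByConfidence` are hypothetical stand-ins, and whether the real detector uses `>=` or `>` is an assumption.

```go
package main

import "fmt"

// relation is a hypothetical stand-in for a detection result.
type relation struct {
	SourceID, TargetID int64
	Type               string
	Confidence         float64
}

// filterByConfidence keeps relations whose confidence is at least threshold.
// The inclusive comparison is an assumption, not taken from the source.
func filterByConfidence(in []relation, threshold float64) []relation {
	out := make([]relation, 0, len(in))
	for _, r := range in {
		if r.Confidence >= threshold {
			out = append(out, r)
		}
	}
	return out
}

func main() {
	candidates := []relation{
		{SourceID: 2, TargetID: 1, Type: "fixes", Confidence: 0.8},
		{SourceID: 2, TargetID: 3, Type: "relates_to", Confidence: 0.3},
		{SourceID: 2, TargetID: 4, Type: "depends_on", Confidence: 0.4},
	}
	for _, r := range filterByConfidence(candidates, 0.4) {
		fmt.Printf("%d -> %d (%s)\n", r.SourceID, r.TargetID, r.Type)
	}
}
```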
// ensureSessionExists creates a session if it doesn't exist.
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
@@ -96,13 +222,7 @@ func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID
// GetObservationByID retrieves an observation by ID.
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
- const query = `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
- concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
- created_at, created_at_epoch
- FROM observations
- WHERE id = ?
- `
+ query := `SELECT ` + observationColumns + ` FROM observations WHERE id = ?`
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
if err == sql.ErrNoRows {
@@ -112,6 +232,7 @@ func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*m
}
// GetObservationsByIDs retrieves observations by a list of IDs.
// Results are ordered by importance_score DESC by default, with created_at_epoch as secondary sort.
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
if len(ids) == 0 {
return nil, nil
@@ -119,18 +240,22 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
// Build query with placeholders
// #nosec G202 -- query uses parameterized placeholders, not user input
- query := `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
- concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
- created_at, created_at_epoch
+ query := `SELECT ` + observationColumns + `
FROM observations
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
- ORDER BY created_at_epoch `
- if orderBy == "date_asc" {
- query += "ASC"
- } else {
- query += "DESC"
+ ORDER BY `
+ // Default to importance-based ordering
+ switch orderBy {
+ case "date_asc":
+ query += "created_at_epoch ASC"
+ case "date_desc":
+ query += "created_at_epoch DESC"
+ case "importance":
+ query += "importance_score DESC, created_at_epoch DESC"
+ default:
+ // Default: importance first, then recency
+ query += "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
}
if limit > 0 {
@@ -154,14 +279,56 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
// GetRecentObservations retrieves recent observations for a project.
// This includes project-scoped observations for the specified project AND global observations.
+ // Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
- const query = `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
- concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
- created_at, created_at_epoch
+ query := `SELECT ` + observationColumns + `
FROM observations
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
OR scope = 'global'
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetActiveObservations retrieves recent non-superseded observations for a project.
// This excludes observations that have been marked as superseded by newer ones.
// Use this for context injection where you want to avoid outdated/contradicted advice.
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
query := `SELECT ` + observationColumns + `
FROM observations
WHERE ((project = ? AND (scope IS NULL OR scope = 'project'))
OR scope = 'global')
AND COALESCE(is_superseded, 0) = 0
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetSupersededObservations retrieves observations that have been superseded by newer ones.
// Use this for verification/debugging to see which observations were marked as outdated.
// Results are ordered by created_at_epoch DESC.
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
query := `SELECT ` + observationColumns + `
FROM observations
WHERE project = ?
AND COALESCE(is_superseded, 0) = 1
ORDER BY created_at_epoch DESC
LIMIT ?
`
@@ -178,14 +345,12 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
// GetObservationsByProjectStrict retrieves observations strictly for a specific project.
// Unlike GetRecentObservations, this does NOT include global observations from other projects.
// Use this for dashboard filtering where the user expects to see only that project's data.
+ // Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
- const query = `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
- concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
- created_at, created_at_epoch
+ query := `SELECT ` + observationColumns + `
FROM observations
WHERE project = ?
- ORDER BY created_at_epoch DESC
+ ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
@@ -210,13 +375,11 @@ func (s *ObservationStore) GetObservationCount(ctx context.Context, project stri
}
// GetAllRecentObservations retrieves recent observations across all projects.
+ // Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
- const query = `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
- concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
- created_at, created_at_epoch
+ query := `SELECT ` + observationColumns + `
FROM observations
- ORDER BY created_at_epoch DESC
+ ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
@@ -229,7 +392,24 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
return scanObservationRows(rows)
}
// GetAllObservations retrieves all observations (for vector rebuild).
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
query := `SELECT ` + observationColumns + `
FROM observations
ORDER BY id
`
rows, err := s.store.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// SearchObservationsFTS performs full-text search on observations.
// Results are ordered by FTS rank (relevance), then by importance_score.
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
if limit <= 0 {
limit = 10
@@ -245,15 +425,21 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
ftsTerms := strings.Join(keywords, " OR ")
// Use FTS5 to search title, subtitle, and narrative
- const ftsQuery = `
+ // Include importance scoring columns and order by rank then importance
+ ftsQuery := `
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
- o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch
+ o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
+ COALESCE(o.importance_score, 1.0) as importance_score,
+ COALESCE(o.user_feedback, 0) as user_feedback,
+ COALESCE(o.retrieval_count, 0) as retrieval_count,
+ o.last_retrieved_at_epoch, o.score_updated_at_epoch,
+ COALESCE(o.is_superseded, 0) as is_superseded
FROM observations o
JOIN observations_fts fts ON o.id = fts.rowid
WHERE observations_fts MATCH ?
AND (o.project = ? OR o.scope = 'global')
- ORDER BY rank
+ ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
LIMIT ?
`
@@ -278,6 +464,7 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
}
// searchObservationsLike performs fallback LIKE search on observations.
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
if len(keywords) == 0 {
return nil, nil
@@ -294,14 +481,11 @@ func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords
}
// #nosec G202 -- query uses parameterized placeholders, not user input
- query := `
- SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
- title, subtitle, facts, narrative, concepts, files_read, files_modified,
- file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch
+ query := `SELECT ` + observationColumns + `
FROM observations
WHERE (` + strings.Join(conditions, " OR ") + `)
AND (project = ? OR scope = 'global')
- ORDER BY created_at_epoch DESC
+ ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
args = append(args, project, limit)
@@ -445,6 +629,11 @@ func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.O
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
&obs.PromptNumber, &obs.DiscoveryTokens,
&obs.CreatedAt, &obs.CreatedAtEpoch,
// Importance scoring fields
&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
&obs.LastRetrievedAt, &obs.ScoreUpdatedAt,
// Conflict detection fields
&obs.IsSuperseded,
); err != nil {
return nil, err
}
+370
@@ -0,0 +1,370 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// patternColumns is the standard list of columns to select for patterns.
const patternColumns = `id, name, type, description, signature, recommendation,
frequency, projects, observation_ids, status, merged_into_id, confidence,
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch`
// PatternCleanupFunc is a callback for when patterns are deleted.
type PatternCleanupFunc func(ctx context.Context, deletedIDs []int64)
// PatternStore provides pattern-related database operations.
type PatternStore struct {
store *Store
cleanupFunc PatternCleanupFunc
}
// NewPatternStore creates a new pattern store.
func NewPatternStore(store *Store) *PatternStore {
return &PatternStore{store: store}
}
// SetCleanupFunc sets the callback for when patterns are deleted.
func (s *PatternStore) SetCleanupFunc(fn PatternCleanupFunc) {
s.cleanupFunc = fn
}
// StorePattern stores a new pattern.
func (s *PatternStore) StorePattern(ctx context.Context, pattern *models.Pattern) (int64, error) {
signatureJSON, _ := json.Marshal(pattern.Signature)
projectsJSON, _ := json.Marshal(pattern.Projects)
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
const query = `
INSERT INTO patterns
(name, type, description, signature, recommendation, frequency, projects,
observation_ids, status, merged_into_id, confidence,
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
pattern.Name, string(pattern.Type),
nullString(pattern.Description.String), string(signatureJSON),
nullString(pattern.Recommendation.String),
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
string(pattern.Status), nullInt64(pattern.MergedIntoID),
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
pattern.CreatedAt, pattern.CreatedAtEpoch,
)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
// UpdatePattern updates an existing pattern.
func (s *PatternStore) UpdatePattern(ctx context.Context, pattern *models.Pattern) error {
signatureJSON, _ := json.Marshal(pattern.Signature)
projectsJSON, _ := json.Marshal(pattern.Projects)
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
const query = `
UPDATE patterns SET
name = ?, type = ?, description = ?, signature = ?, recommendation = ?,
frequency = ?, projects = ?, observation_ids = ?, status = ?,
merged_into_id = ?, confidence = ?, last_seen_at = ?, last_seen_at_epoch = ?
WHERE id = ?
`
_, err := s.store.ExecContext(ctx, query,
pattern.Name, string(pattern.Type),
nullString(pattern.Description.String), string(signatureJSON),
nullString(pattern.Recommendation.String),
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
string(pattern.Status), nullInt64(pattern.MergedIntoID),
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
pattern.ID,
)
return err
}
// GetPatternByID retrieves a pattern by ID.
func (s *PatternStore) GetPatternByID(ctx context.Context, id int64) (*models.Pattern, error) {
query := `SELECT ` + patternColumns + ` FROM patterns WHERE id = ?`
row := s.store.QueryRowContext(ctx, query, id)
return scanPattern(row)
}
// GetPatternByName retrieves a pattern by name.
func (s *PatternStore) GetPatternByName(ctx context.Context, name string) (*models.Pattern, error) {
query := `SELECT ` + patternColumns + ` FROM patterns WHERE name = ? AND status = 'active'`
row := s.store.QueryRowContext(ctx, query, name)
pattern, err := scanPattern(row)
if err == sql.ErrNoRows {
return nil, nil
}
return pattern, err
}
// GetActivePatterns retrieves all active patterns.
func (s *PatternStore) GetActivePatterns(ctx context.Context, limit int) ([]*models.Pattern, error) {
query := `SELECT ` + patternColumns + `
FROM patterns
WHERE status = 'active'
ORDER BY frequency DESC, confidence DESC
LIMIT ?`
rows, err := s.store.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanPatternRows(rows)
}
// GetPatternsByType retrieves patterns of a specific type.
func (s *PatternStore) GetPatternsByType(ctx context.Context, patternType models.PatternType, limit int) ([]*models.Pattern, error) {
query := `SELECT ` + patternColumns + `
FROM patterns
WHERE type = ? AND status = 'active'
ORDER BY frequency DESC, confidence DESC
LIMIT ?`
rows, err := s.store.QueryContext(ctx, query, string(patternType), limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanPatternRows(rows)
}
// GetPatternsByProject retrieves patterns that have been observed in a specific project.
func (s *PatternStore) GetPatternsByProject(ctx context.Context, project string, limit int) ([]*models.Pattern, error) {
// Use json_each to expand the projects JSON array and match the project name
query := `SELECT ` + patternColumns + `
FROM patterns
WHERE status = 'active'
AND EXISTS (
SELECT 1 FROM json_each(projects)
WHERE json_each.value = ?
)
ORDER BY frequency DESC, confidence DESC
LIMIT ?`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanPatternRows(rows)
}
// FindMatchingPatterns searches for patterns that match a given signature.
func (s *PatternStore) FindMatchingPatterns(ctx context.Context, signature []string, minScore float64) ([]*models.Pattern, error) {
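// Load the top 100 active patterns (by frequency, then confidence) and filter
// by signature match in Go. This is simpler than complex SQL for JSON array
// matching, but active patterns beyond the first 100 are never considered.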
patterns, err := s.GetActivePatterns(ctx, 100)
if err != nil {
return nil, err
}
var matches []*models.Pattern
for _, pattern := range patterns {
score := models.CalculateMatchScore(signature, pattern.Signature)
if score >= minScore {
matches = append(matches, pattern)
}
}
return matches, nil
}
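`FindMatchingPatterns` delegates scoring to `models.CalculateMatchScore`, which is not shown in this diff. One common way to score keyword-signature overlap is Jaccard similarity. The sketch below is an assumption about how such a score could work, not the repository's implementation, and `jaccardScore` is a hypothetical name.

```go
package main

import (
	"fmt"
	"strings"
)

// jaccardScore (hypothetical) returns |A ∩ B| / |A ∪ B| over the two keyword
// lists, treated as case-insensitive sets. It returns 0 when both are empty.
func jaccardScore(a, b []string) float64 {
	setA := make(map[string]struct{}, len(a))
	for _, k := range a {
		setA[strings.ToLower(k)] = struct{}{}
	}
	setB := make(map[string]struct{}, len(b))
	for _, k := range b {
		setB[strings.ToLower(k)] = struct{}{}
	}
	if len(setA) == 0 && len(setB) == 0 {
		return 0
	}
	intersection := 0
	for k := range setA {
		if _, ok := setB[k]; ok {
			intersection++
		}
	}
	union := len(setA) + len(setB) - intersection
	return float64(intersection) / float64(union)
}

func main() {
	observed := []string{"nil", "pointer", "panic"}
	stored := []string{"Nil", "panic", "timeout"}
	// Two shared keywords out of four distinct ones.
	fmt.Println(jaccardScore(observed, stored))
}
```

With a score in the 0.0 to 1.0 range, `minScore` acts as a simple cutoff on how much of the signature must overlap.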
// MarkPatternDeprecated marks a pattern as deprecated.
func (s *PatternStore) MarkPatternDeprecated(ctx context.Context, id int64) error {
const query = `UPDATE patterns SET status = 'deprecated' WHERE id = ?`
_, err := s.store.ExecContext(ctx, query, id)
return err
}
// MergePatterns merges a source pattern into a target pattern.
func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int64) error {
// Get both patterns
source, err := s.GetPatternByID(ctx, sourceID)
if err != nil {
return err
}
target, err := s.GetPatternByID(ctx, targetID)
if err != nil {
return err
}
// Merge source into target
target.Frequency += source.Frequency
for _, proj := range source.Projects {
found := false
for _, existing := range target.Projects {
if existing == proj {
found = true
break
}
}
if !found {
target.Projects = append(target.Projects, proj)
}
}
for _, obsID := range source.ObservationIDs {
found := false
for _, existing := range target.ObservationIDs {
if existing == obsID {
found = true
break
}
}
if !found {
target.ObservationIDs = append(target.ObservationIDs, obsID)
}
}
// Update target
if err := s.UpdatePattern(ctx, target); err != nil {
return err
}
// Mark source as merged
source.Status = models.PatternStatusMerged
source.MergedIntoID = sql.NullInt64{Int64: targetID, Valid: true}
return s.UpdatePattern(ctx, source)
}
// DeletePattern deletes a pattern by ID.
func (s *PatternStore) DeletePattern(ctx context.Context, id int64) error {
const query = `DELETE FROM patterns WHERE id = ?`
_, err := s.store.ExecContext(ctx, query, id)
if err == nil && s.cleanupFunc != nil {
s.cleanupFunc(ctx, []int64{id})
}
return err
}
// SearchPatternsFTS performs full-text search on patterns.
func (s *PatternStore) SearchPatternsFTS(ctx context.Context, searchQuery string, limit int) ([]*models.Pattern, error) {
query := `SELECT p.` + patternColumns + `
FROM patterns p
JOIN patterns_fts fts ON p.id = fts.rowid
WHERE patterns_fts MATCH ?
AND p.status = 'active'
ORDER BY rank
LIMIT ?`
rows, err := s.store.QueryContext(ctx, query, searchQuery, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanPatternRows(rows)
}
// GetPatternStats returns statistics about patterns.
func (s *PatternStore) GetPatternStats(ctx context.Context) (*PatternStats, error) {
const query = `
SELECT
COUNT(*) as total,
COUNT(CASE WHEN status = 'active' THEN 1 END) as active,
COUNT(CASE WHEN status = 'deprecated' THEN 1 END) as deprecated,
COUNT(CASE WHEN status = 'merged' THEN 1 END) as merged,
COALESCE(SUM(frequency), 0) as total_occurrences,
COALESCE(AVG(confidence), 0) as avg_confidence,
COUNT(CASE WHEN type = 'bug' THEN 1 END) as bugs,
COUNT(CASE WHEN type = 'refactor' THEN 1 END) as refactors,
COUNT(CASE WHEN type = 'architecture' THEN 1 END) as architectures,
COUNT(CASE WHEN type = 'anti-pattern' THEN 1 END) as anti_patterns,
COUNT(CASE WHEN type = 'best-practice' THEN 1 END) as best_practices
FROM patterns
`
var stats PatternStats
// Return nil stats on failure so callers never read a partially scanned struct.
if err := s.store.QueryRowContext(ctx, query).Scan(
&stats.Total, &stats.Active, &stats.Deprecated, &stats.Merged,
&stats.TotalOccurrences, &stats.AvgConfidence,
&stats.Bugs, &stats.Refactors, &stats.Architectures,
&stats.AntiPatterns, &stats.BestPractices,
); err != nil {
return nil, err
}
return &stats, nil
}
// PatternStats contains aggregate statistics about patterns.
type PatternStats struct {
Total int `json:"total"`
Active int `json:"active"`
Deprecated int `json:"deprecated"`
Merged int `json:"merged"`
TotalOccurrences int `json:"total_occurrences"`
AvgConfidence float64 `json:"avg_confidence"`
Bugs int `json:"bugs"`
Refactors int `json:"refactors"`
Architectures int `json:"architectures"`
AntiPatterns int `json:"anti_patterns"`
BestPractices int `json:"best_practices"`
}
// scanPattern scans a single pattern from a row scanner.
func scanPattern(scanner interface{ Scan(...interface{}) error }) (*models.Pattern, error) {
var pattern models.Pattern
if err := scanner.Scan(
&pattern.ID, &pattern.Name, &pattern.Type,
&pattern.Description, &pattern.Signature, &pattern.Recommendation,
&pattern.Frequency, &pattern.Projects, &pattern.ObservationIDs,
&pattern.Status, &pattern.MergedIntoID, &pattern.Confidence,
&pattern.LastSeenAt, &pattern.LastSeenEpoch,
&pattern.CreatedAt, &pattern.CreatedAtEpoch,
); err != nil {
return nil, err
}
return &pattern, nil
}
// scanPatternRows scans multiple patterns from rows.
func scanPatternRows(rows *sql.Rows) ([]*models.Pattern, error) {
var patterns []*models.Pattern
for rows.Next() {
pattern, err := scanPattern(rows)
if err != nil {
return nil, err
}
patterns = append(patterns, pattern)
}
return patterns, rows.Err()
}
// nullInt64 converts sql.NullInt64 to the value needed for database insertion.
func nullInt64(n sql.NullInt64) interface{} {
if n.Valid {
return n.Int64
}
return nil
}
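// IncrementPatternFrequency records a new occurrence of a pattern and updates last_seen.
// Note: this is a read-modify-write (GetPatternByID followed by UpdatePattern), not an
// atomic SQL increment, so concurrent callers can overwrite each other's updates.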
func (s *PatternStore) IncrementPatternFrequency(ctx context.Context, id int64, project string, observationID int64) error {
now := time.Now()
// Get current pattern
pattern, err := s.GetPatternByID(ctx, id)
if err != nil {
return err
}
// Add occurrence
pattern.AddOccurrence(project, observationID)
pattern.LastSeenAt = now.Format(time.RFC3339)
pattern.LastSeenEpoch = now.UnixMilli()
return s.UpdatePattern(ctx, pattern)
}
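Reviewer note: the two nested loops in `MergePatterns` union the project and observation ID slices by linear search. A minimal sketch of the same union using a set is shown below; `mergeUnique` is a hypothetical helper name that does not exist in this diff, and it assumes Go 1.18+ generics.

```go
package main

import "fmt"

// mergeUnique appends to dst every element of src that dst does not already
// contain, preserving order. A set makes each membership check O(1).
// Hypothetical helper for illustration; not part of the PatternStore API.
func mergeUnique[T comparable](dst, src []T) []T {
	seen := make(map[T]struct{}, len(dst)+len(src))
	for _, v := range dst {
		seen[v] = struct{}{}
	}
	for _, v := range src {
		if _, ok := seen[v]; ok {
			continue
		}
		seen[v] = struct{}{}
		dst = append(dst, v)
	}
	return dst
}

func main() {
	// Same data as TestPatternStore_MergePatterns: target first, source second.
	projects := mergeUnique([]string{"proj2", "proj3"}, []string{"proj1", "proj2"})
	fmt.Println(projects) // [proj2 proj3 proj1]

	ids := mergeUnique([]int64{4, 5}, []int64{1, 2, 3})
	fmt.Println(ids) // [4 5 1 2 3]
}
```

For the small slices a pattern typically carries, the nested loops are fine; the set version mainly removes the duplicated loop body.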
@@ -0,0 +1,507 @@
package sqlite
import (
"context"
"database/sql"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// setupPatternTestStore creates a test store with patterns table.
func setupPatternTestStore(t *testing.T) *Store {
t.Helper()
db, _, cleanup := testDB(t)
t.Cleanup(cleanup)
createBaseTables(t, db)
return newStoreFromDB(db)
}
func TestPatternStore_StoreAndGet(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create test pattern
pattern := &models.Pattern{
Name: "Test Pattern",
Type: models.PatternTypeBug,
Description: sql.NullString{String: "A test pattern", Valid: true},
Signature: []string{"nil", "error"},
Recommendation: sql.NullString{String: "Always check for nil", Valid: true},
Frequency: 1,
Projects: []string{"project1"},
ObservationIDs: []int64{1, 2},
Status: models.PatternStatusActive,
Confidence: 0.5,
LastSeenAt: time.Now().Format(time.RFC3339),
LastSeenEpoch: time.Now().UnixMilli(),
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
// Store pattern
id, err := patternStore.StorePattern(ctx, pattern)
if err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
if id <= 0 {
t.Errorf("Expected positive ID, got %d", id)
}
// Get pattern by ID
retrieved, err := patternStore.GetPatternByID(ctx, id)
if err != nil {
t.Fatalf("GetPatternByID() error = %v", err)
}
if retrieved.Name != pattern.Name {
t.Errorf("Expected name %s, got %s", pattern.Name, retrieved.Name)
}
if retrieved.Type != pattern.Type {
t.Errorf("Expected type %s, got %s", pattern.Type, retrieved.Type)
}
if len(retrieved.Signature) != len(pattern.Signature) {
t.Errorf("Expected %d signature elements, got %d",
len(pattern.Signature), len(retrieved.Signature))
}
}
func TestPatternStore_GetByName(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
pattern := createTestPattern("Unique Name Pattern")
_, err := patternStore.StorePattern(ctx, pattern)
if err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
// Get by name
retrieved, err := patternStore.GetPatternByName(ctx, "Unique Name Pattern")
if err != nil {
t.Fatalf("GetPatternByName() error = %v", err)
}
if retrieved == nil {
t.Fatal("Expected pattern, got nil")
}
if retrieved.Name != "Unique Name Pattern" {
t.Errorf("Expected name 'Unique Name Pattern', got '%s'", retrieved.Name)
}
// Get non-existent pattern
nonExistent, err := patternStore.GetPatternByName(ctx, "Non Existent")
if err != nil {
t.Fatalf("GetPatternByName() error = %v", err)
}
if nonExistent != nil {
t.Errorf("Expected nil for non-existent pattern, got %v", nonExistent)
}
}
func TestPatternStore_GetActivePatterns(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create multiple patterns with different statuses
active1 := createTestPattern("Active 1")
active1.Frequency = 5
active2 := createTestPattern("Active 2")
active2.Frequency = 3
deprecated := createTestPattern("Deprecated")
deprecated.Status = models.PatternStatusDeprecated
patternStore.StorePattern(ctx, active1)
patternStore.StorePattern(ctx, active2)
patternStore.StorePattern(ctx, deprecated)
// Get active patterns
patterns, err := patternStore.GetActivePatterns(ctx, 10)
if err != nil {
t.Fatalf("GetActivePatterns() error = %v", err)
}
if len(patterns) != 2 {
t.Errorf("Expected 2 active patterns, got %d", len(patterns))
}
// Check order (should be by frequency descending)
if len(patterns) >= 2 {
if patterns[0].Frequency < patterns[1].Frequency {
t.Errorf("Patterns not ordered by frequency descending")
}
}
}
func TestPatternStore_GetPatternsByType(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create patterns of different types
bugPattern := createTestPattern("Bug Pattern")
bugPattern.Type = models.PatternTypeBug
refactorPattern := createTestPattern("Refactor Pattern")
refactorPattern.Type = models.PatternTypeRefactor
patternStore.StorePattern(ctx, bugPattern)
patternStore.StorePattern(ctx, refactorPattern)
// Get by type
bugs, err := patternStore.GetPatternsByType(ctx, models.PatternTypeBug, 10)
if err != nil {
t.Fatalf("GetPatternsByType() error = %v", err)
}
if len(bugs) != 1 {
t.Errorf("Expected 1 bug pattern, got %d", len(bugs))
}
refactors, err := patternStore.GetPatternsByType(ctx, models.PatternTypeRefactor, 10)
if err != nil {
t.Fatalf("GetPatternsByType() error = %v", err)
}
if len(refactors) != 1 {
t.Errorf("Expected 1 refactor pattern, got %d", len(refactors))
}
}
func TestPatternStore_GetPatternsByProject(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create patterns with different projects
pattern1 := createTestPattern("Pattern 1")
pattern1.Projects = []string{"project-a", "project-b"}
pattern2 := createTestPattern("Pattern 2")
pattern2.Projects = []string{"project-b", "project-c"}
patternStore.StorePattern(ctx, pattern1)
patternStore.StorePattern(ctx, pattern2)
// Get by project
projectA, err := patternStore.GetPatternsByProject(ctx, "project-a", 10)
if err != nil {
t.Fatalf("GetPatternsByProject() error = %v", err)
}
if len(projectA) != 1 {
t.Errorf("Expected 1 pattern for project-a, got %d", len(projectA))
}
projectB, err := patternStore.GetPatternsByProject(ctx, "project-b", 10)
if err != nil {
t.Fatalf("GetPatternsByProject() error = %v", err)
}
if len(projectB) != 2 {
t.Errorf("Expected 2 patterns for project-b, got %d", len(projectB))
}
}
func TestPatternStore_UpdatePattern(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create and store pattern
pattern := createTestPattern("Original Name")
id, _ := patternStore.StorePattern(ctx, pattern)
// Update pattern
pattern.ID = id
pattern.Name = "Updated Name"
pattern.Frequency = 10
pattern.Confidence = 0.9
err := patternStore.UpdatePattern(ctx, pattern)
if err != nil {
t.Fatalf("UpdatePattern() error = %v", err)
}
// Verify update
updated, err := patternStore.GetPatternByID(ctx, id)
if err != nil {
t.Fatalf("GetPatternByID() error = %v", err)
}
if updated.Name != "Updated Name" {
t.Errorf("Expected name 'Updated Name', got '%s'", updated.Name)
}
if updated.Frequency != 10 {
t.Errorf("Expected frequency 10, got %d", updated.Frequency)
}
if updated.Confidence != 0.9 {
t.Errorf("Expected confidence 0.9, got %f", updated.Confidence)
}
}
func TestPatternStore_DeletePattern(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create and store pattern
pattern := createTestPattern("To Delete")
id, _ := patternStore.StorePattern(ctx, pattern)
// Delete pattern
err := patternStore.DeletePattern(ctx, id)
if err != nil {
t.Fatalf("DeletePattern() error = %v", err)
}
// Verify deletion
deleted, err := patternStore.GetPatternByID(ctx, id)
if err != sql.ErrNoRows {
t.Errorf("Expected ErrNoRows, got %v", err)
}
if deleted != nil {
t.Errorf("Expected nil for deleted pattern")
}
}
func TestPatternStore_MarkPatternDeprecated(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create and store pattern
pattern := createTestPattern("To Deprecate")
id, _ := patternStore.StorePattern(ctx, pattern)
// Mark as deprecated
err := patternStore.MarkPatternDeprecated(ctx, id)
if err != nil {
t.Fatalf("MarkPatternDeprecated() error = %v", err)
}
// Verify status
deprecated, err := patternStore.GetPatternByID(ctx, id)
if err != nil {
t.Fatalf("GetPatternByID() error = %v", err)
}
if deprecated.Status != models.PatternStatusDeprecated {
t.Errorf("Expected status 'deprecated', got '%s'", deprecated.Status)
}
}
func TestPatternStore_MergePatterns(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create source and target patterns
source := createTestPattern("Source Pattern")
source.Frequency = 3
source.Projects = []string{"proj1", "proj2"}
source.ObservationIDs = []int64{1, 2, 3}
target := createTestPattern("Target Pattern")
target.Frequency = 2
target.Projects = []string{"proj2", "proj3"}
target.ObservationIDs = []int64{4, 5}
sourceID, _ := patternStore.StorePattern(ctx, source)
targetID, _ := patternStore.StorePattern(ctx, target)
// Merge
err := patternStore.MergePatterns(ctx, sourceID, targetID)
if err != nil {
t.Fatalf("MergePatterns() error = %v", err)
}
// Verify source is marked as merged
mergedSource, _ := patternStore.GetPatternByID(ctx, sourceID)
if mergedSource.Status != models.PatternStatusMerged {
t.Errorf("Expected source status 'merged', got '%s'", mergedSource.Status)
}
if !mergedSource.MergedIntoID.Valid || mergedSource.MergedIntoID.Int64 != targetID {
t.Errorf("Expected source merged_into_id to be %d", targetID)
}
// Verify target has combined data
mergedTarget, _ := patternStore.GetPatternByID(ctx, targetID)
expectedFrequency := 5 // 3 + 2
if mergedTarget.Frequency != expectedFrequency {
t.Errorf("Expected merged frequency %d, got %d", expectedFrequency, mergedTarget.Frequency)
}
// Should have 3 unique projects: proj1, proj2, proj3
if len(mergedTarget.Projects) != 3 {
t.Errorf("Expected 3 projects after merge, got %d", len(mergedTarget.Projects))
}
// Should have 5 observation IDs
if len(mergedTarget.ObservationIDs) != 5 {
t.Errorf("Expected 5 observation IDs after merge, got %d", len(mergedTarget.ObservationIDs))
}
}
func TestPatternStore_FindMatchingPatterns(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create patterns with known signatures
pattern1 := createTestPattern("Pattern 1")
pattern1.Signature = []string{"nil", "error", "handling"}
pattern2 := createTestPattern("Pattern 2")
pattern2.Signature = []string{"nil", "pointer", "check"}
pattern3 := createTestPattern("Pattern 3")
pattern3.Signature = []string{"refactor", "extract", "method"}
patternStore.StorePattern(ctx, pattern1)
patternStore.StorePattern(ctx, pattern2)
patternStore.StorePattern(ctx, pattern3)
// Find patterns matching "nil" related signature
matches, err := patternStore.FindMatchingPatterns(ctx, []string{"nil", "error"}, 0.3)
if err != nil {
t.Fatalf("FindMatchingPatterns() error = %v", err)
}
if len(matches) < 1 {
t.Errorf("Expected at least 1 match, got %d", len(matches))
}
// Verify no match for unrelated signature
noMatches, err := patternStore.FindMatchingPatterns(ctx, []string{"completely", "different"}, 0.5)
if err != nil {
t.Fatalf("FindMatchingPatterns() error = %v", err)
}
if len(noMatches) != 0 {
t.Errorf("Expected 0 matches for unrelated signature, got %d", len(noMatches))
}
}
func TestPatternStore_IncrementPatternFrequency(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create pattern
pattern := createTestPattern("Frequency Test")
pattern.Frequency = 1
pattern.Projects = []string{"proj1"}
pattern.ObservationIDs = []int64{1}
id, _ := patternStore.StorePattern(ctx, pattern)
// Increment frequency
err := patternStore.IncrementPatternFrequency(ctx, id, "proj2", 2)
if err != nil {
t.Fatalf("IncrementPatternFrequency() error = %v", err)
}
// Verify
updated, _ := patternStore.GetPatternByID(ctx, id)
if updated.Frequency != 2 {
t.Errorf("Expected frequency 2, got %d", updated.Frequency)
}
if len(updated.Projects) != 2 {
t.Errorf("Expected 2 projects, got %d", len(updated.Projects))
}
if len(updated.ObservationIDs) != 2 {
t.Errorf("Expected 2 observation IDs, got %d", len(updated.ObservationIDs))
}
}
func TestPatternStore_GetPatternStats(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
// Create patterns with different types and statuses
bug := createTestPattern("Bug")
bug.Type = models.PatternTypeBug
bug.Frequency = 5
refactor := createTestPattern("Refactor")
refactor.Type = models.PatternTypeRefactor
refactor.Frequency = 3
deprecated := createTestPattern("Deprecated")
deprecated.Type = models.PatternTypeArchitecture
deprecated.Status = models.PatternStatusDeprecated
patternStore.StorePattern(ctx, bug)
patternStore.StorePattern(ctx, refactor)
patternStore.StorePattern(ctx, deprecated)
// Get stats
stats, err := patternStore.GetPatternStats(ctx)
if err != nil {
t.Fatalf("GetPatternStats() error = %v", err)
}
if stats.Total != 3 {
t.Errorf("Expected total 3, got %d", stats.Total)
}
if stats.Active != 2 {
t.Errorf("Expected 2 active, got %d", stats.Active)
}
if stats.Deprecated != 1 {
t.Errorf("Expected 1 deprecated, got %d", stats.Deprecated)
}
if stats.Bugs != 1 {
t.Errorf("Expected 1 bug, got %d", stats.Bugs)
}
if stats.Refactors != 1 {
t.Errorf("Expected 1 refactor, got %d", stats.Refactors)
}
if stats.TotalOccurrences != 9 { // 5 + 3 + 1
t.Errorf("Expected 9 total occurrences, got %d", stats.TotalOccurrences)
}
}
func TestPatternStore_CleanupCallback(t *testing.T) {
store := setupPatternTestStore(t)
patternStore := NewPatternStore(store)
ctx := context.Background()
var deletedIDs []int64
patternStore.SetCleanupFunc(func(ctx context.Context, ids []int64) {
deletedIDs = ids
})
// Create and delete pattern
pattern := createTestPattern("Cleanup Test")
id, _ := patternStore.StorePattern(ctx, pattern)
patternStore.DeletePattern(ctx, id)
if len(deletedIDs) != 1 || deletedIDs[0] != id {
t.Errorf("Expected cleanup callback with ID %d, got %v", id, deletedIDs)
}
}
// Helper function to create a test pattern
func createTestPattern(name string) *models.Pattern {
now := time.Now()
return &models.Pattern{
Name: name,
Type: models.PatternTypeBug,
Description: sql.NullString{String: "Test description", Valid: true},
Signature: []string{"test", "pattern"},
Recommendation: sql.NullString{String: "Test recommendation", Valid: true},
Frequency: 1,
Projects: []string{"test-project"},
ObservationIDs: []int64{1},
Status: models.PatternStatusActive,
Confidence: 0.5,
LastSeenAt: now.Format(time.RFC3339),
LastSeenEpoch: now.UnixMilli(),
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
}
}
@@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
return scanPromptWithSessionRows(rows)
}
// GetAllPrompts retrieves all user prompts (for vector rebuild).
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
const query = `
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
COALESCE(up.matched_observations, 0) as matched_observations,
up.created_at, up.created_at_epoch,
COALESCE(s.project, '') as project,
COALESCE(s.sdk_session_id, '') as sdk_session_id
FROM user_prompts up
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
ORDER BY up.id
`
rows, err := s.store.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanPromptWithSessionRows(rows)
}
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
// This is used to detect duplicate hook invocations.
// Returns (promptID, promptNumber, found).
@@ -0,0 +1,377 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// RelationStore provides relation-related database operations.
type RelationStore struct {
store *Store
}
// NewRelationStore creates a new relation store.
func NewRelationStore(store *Store) *RelationStore {
return &RelationStore{store: store}
}
// StoreRelation stores a new observation relation.
// Uses INSERT OR IGNORE to handle duplicate (source_id, target_id, relation_type) combinations.
// Returns 0 with a nil error when the relation already exists and nothing was inserted.
func (s *RelationStore) StoreRelation(ctx context.Context, relation *models.ObservationRelation) (int64, error) {
const query = `
INSERT OR IGNORE INTO observation_relations
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
relation.SourceID, relation.TargetID,
string(relation.RelationType), relation.Confidence,
string(relation.DetectionSource), relation.Reason,
relation.CreatedAt, relation.CreatedAtEpoch,
)
if err != nil {
return 0, err
}
// An ignored insert affects zero rows; LastInsertId would then report a stale
// rowid from an earlier insert on this connection, so return 0 instead.
affected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if affected == 0 {
return 0, nil
}
return result.LastInsertId()
}
// StoreRelations stores multiple relations in a single transaction.
func (s *RelationStore) StoreRelations(ctx context.Context, relations []*models.ObservationRelation) error {
if len(relations) == 0 {
return nil
}
tx, err := s.store.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
const query = `
INSERT OR IGNORE INTO observation_relations
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return err
}
defer stmt.Close()
for _, rel := range relations {
_, err = stmt.ExecContext(ctx,
rel.SourceID, rel.TargetID,
string(rel.RelationType), rel.Confidence,
string(rel.DetectionSource), rel.Reason,
rel.CreatedAt, rel.CreatedAtEpoch,
)
if err != nil {
return err
}
}
return tx.Commit()
}
// GetRelationsByObservationID retrieves all relations involving an observation (as source or target).
func (s *RelationStore) GetRelationsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
const query = `
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
created_at, created_at_epoch
FROM observation_relations
WHERE source_id = ? OR target_id = ?
ORDER BY confidence DESC, created_at_epoch DESC
`
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanRelationRows(rows)
}
// GetOutgoingRelations retrieves relations where the observation is the source.
func (s *RelationStore) GetOutgoingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
const query = `
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
created_at, created_at_epoch
FROM observation_relations
WHERE source_id = ?
ORDER BY confidence DESC, created_at_epoch DESC
`
rows, err := s.store.QueryContext(ctx, query, obsID)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanRelationRows(rows)
}
// GetIncomingRelations retrieves relations where the observation is the target.
func (s *RelationStore) GetIncomingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
const query = `
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
created_at, created_at_epoch
FROM observation_relations
WHERE target_id = ?
ORDER BY confidence DESC, created_at_epoch DESC
`
rows, err := s.store.QueryContext(ctx, query, obsID)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanRelationRows(rows)
}
// GetRelationsByType retrieves all relations of a specific type.
func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType models.RelationType, limit int) ([]*models.ObservationRelation, error) {
const query = `
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
created_at, created_at_epoch
FROM observation_relations
WHERE relation_type = ?
ORDER BY confidence DESC, created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, string(relationType), limit)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanRelationRows(rows)
}
// GetRelationsWithDetails retrieves relations with observation titles for display.
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
const query = `
SELECT r.id, r.source_id, r.target_id, r.relation_type, r.confidence, r.detection_source, r.reason,
r.created_at, r.created_at_epoch,
COALESCE(src.title, '') as source_title,
COALESCE(tgt.title, '') as target_title,
src.type as source_type,
tgt.type as target_type
FROM observation_relations r
JOIN observations src ON src.id = r.source_id
JOIN observations tgt ON tgt.id = r.target_id
WHERE r.source_id = ? OR r.target_id = ?
ORDER BY r.confidence DESC, r.created_at_epoch DESC
`
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
if err != nil {
return nil, err
}
defer rows.Close()
var results []*models.RelationWithDetails
for rows.Next() {
var r models.ObservationRelation
var rwd models.RelationWithDetails
var reason sql.NullString
if err := rows.Scan(
&r.ID, &r.SourceID, &r.TargetID,
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
&r.CreatedAt, &r.CreatedAtEpoch,
&rwd.SourceTitle, &rwd.TargetTitle,
&rwd.SourceType, &rwd.TargetType,
); err != nil {
return nil, err
}
if reason.Valid {
r.Reason = reason.String
}
rwd.Relation = &r
results = append(results, &rwd)
}
return results, rows.Err()
}
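// GetRelationGraph retrieves a relation graph centered on an observation.
// It returns the relations reachable within maxDepth hops of the center; a maxDepth
// of 1 or less yields only the center's direct relations. Lookups that fail for
// non-center observations are skipped, so the returned graph may be partial.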
func (s *RelationStore) GetRelationGraph(ctx context.Context, centerID int64, maxDepth int) (*models.RelationGraph, error) {
// Get all relations involving the center observation
relations, err := s.GetRelationsWithDetails(ctx, centerID)
if err != nil {
return nil, err
}
graph := &models.RelationGraph{
CenterID: centerID,
Relations: relations,
}
// If depth > 1, recursively get relations for connected observations
if maxDepth > 1 {
visited := map[int64]bool{centerID: true}
toVisit := make([]int64, 0)
// Collect IDs of directly connected observations
for _, r := range relations {
if !visited[r.Relation.SourceID] {
toVisit = append(toVisit, r.Relation.SourceID)
visited[r.Relation.SourceID] = true
}
if !visited[r.Relation.TargetID] {
toVisit = append(toVisit, r.Relation.TargetID)
visited[r.Relation.TargetID] = true
}
}
// Get relations for connected observations (depth - 1)
for depth := 1; depth < maxDepth && len(toVisit) > 0; depth++ {
nextLevel := make([]int64, 0)
for _, obsID := range toVisit {
moreRelations, err := s.GetRelationsWithDetails(ctx, obsID)
if err != nil {
continue
}
for _, r := range moreRelations {
// Avoid duplicates
exists := false
for _, existing := range graph.Relations {
if existing.Relation.ID == r.Relation.ID {
exists = true
break
}
}
if !exists {
graph.Relations = append(graph.Relations, r)
}
// Queue next level
if !visited[r.Relation.SourceID] {
nextLevel = append(nextLevel, r.Relation.SourceID)
visited[r.Relation.SourceID] = true
}
if !visited[r.Relation.TargetID] {
nextLevel = append(nextLevel, r.Relation.TargetID)
visited[r.Relation.TargetID] = true
}
}
}
toVisit = nextLevel
}
}
return graph, nil
}
// DeleteRelationsByObservationID deletes all relations involving an observation.
// Called when an observation is deleted.
func (s *RelationStore) DeleteRelationsByObservationID(ctx context.Context, obsID int64) error {
const query = `DELETE FROM observation_relations WHERE source_id = ? OR target_id = ?`
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
return err
}
// GetRelationCount returns the count of relations for an observation.
func (s *RelationStore) GetRelationCount(ctx context.Context, obsID int64) (int, error) {
const query = `
SELECT COUNT(*) FROM observation_relations
WHERE source_id = ? OR target_id = ?
`
var count int
err := s.store.QueryRowContext(ctx, query, obsID, obsID).Scan(&count)
return count, err
}
// GetTotalRelationCount returns the total count of all relations.
func (s *RelationStore) GetTotalRelationCount(ctx context.Context) (int, error) {
const query = `SELECT COUNT(*) FROM observation_relations`
var count int
err := s.store.QueryRowContext(ctx, query).Scan(&count)
return count, err
}
// GetHighConfidenceRelations retrieves relations with confidence above threshold.
func (s *RelationStore) GetHighConfidenceRelations(ctx context.Context, minConfidence float64, limit int) ([]*models.ObservationRelation, error) {
const query = `
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
created_at, created_at_epoch
FROM observation_relations
WHERE confidence >= ?
ORDER BY confidence DESC, created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, minConfidence, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return s.scanRelationRows(rows)
}
// UpdateRelationConfidence updates the confidence of a relation.
func (s *RelationStore) UpdateRelationConfidence(ctx context.Context, relationID int64, newConfidence float64) error {
const query = `UPDATE observation_relations SET confidence = ? WHERE id = ?`
_, err := s.store.ExecContext(ctx, query, newConfidence, relationID)
return err
}
// GetRelatedObservationIDs returns IDs of observations related to the given one.
// This is useful for expanding search results.
func (s *RelationStore) GetRelatedObservationIDs(ctx context.Context, obsID int64, minConfidence float64) ([]int64, error) {
const query = `
SELECT DISTINCT CASE WHEN source_id = ? THEN target_id ELSE source_id END as related_id
FROM observation_relations
WHERE (source_id = ? OR target_id = ?) AND confidence >= ?
`
rows, err := s.store.QueryContext(ctx, query, obsID, obsID, obsID, minConfidence)
if err != nil {
return nil, err
}
defer rows.Close()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, rows.Err()
}
// scanRelationRows scans multiple relations from rows.
func (s *RelationStore) scanRelationRows(rows *sql.Rows) ([]*models.ObservationRelation, error) {
var relations []*models.ObservationRelation
for rows.Next() {
var r models.ObservationRelation
var reason sql.NullString
if err := rows.Scan(
&r.ID, &r.SourceID, &r.TargetID,
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
&r.CreatedAt, &r.CreatedAtEpoch,
); err != nil {
return nil, err
}
if reason.Valid {
r.Reason = reason.String
}
relations = append(relations, &r)
}
return relations, rows.Err()
}
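Reviewer note: `GetRelationGraph` is a breadth-first expansion with a hop limit, and it deduplicates relations by scanning the accumulated slice. The sketch below shows the same traversal over an in-memory edge list, with a map keyed by relation ID for deduplication. The `edge` type and `neighborhood` function are hypothetical and exist only for this illustration; unlike `GetRelationGraph`, a `maxDepth` below 1 returns nothing here.

```go
package main

import "fmt"

// edge is a hypothetical stand-in for a stored relation row.
type edge struct {
	ID, Source, Target int64
}

// neighborhood returns the edges reachable within maxDepth hops of center.
// maxDepth 1 yields only edges that touch the center directly.
func neighborhood(edges []edge, center int64, maxDepth int) []edge {
	// Index edges by both endpoints so each hop is a map lookup.
	byNode := make(map[int64][]edge)
	for _, e := range edges {
		byNode[e.Source] = append(byNode[e.Source], e)
		if e.Target != e.Source {
			byNode[e.Target] = append(byNode[e.Target], e)
		}
	}

	visited := map[int64]bool{center: true}
	seenEdge := make(map[int64]bool)
	frontier := []int64{center}
	var result []edge

	for depth := 0; depth < maxDepth && len(frontier) > 0; depth++ {
		var next []int64
		for _, node := range frontier {
			for _, e := range byNode[node] {
				// Record each edge once, even if reached from both endpoints.
				if !seenEdge[e.ID] {
					seenEdge[e.ID] = true
					result = append(result, e)
				}
				// Queue endpoints that have not been expanded yet.
				for _, n := range []int64{e.Source, e.Target} {
					if !visited[n] {
						visited[n] = true
						next = append(next, n)
					}
				}
			}
		}
		frontier = next
	}
	return result
}

func main() {
	chain := []edge{{ID: 1, Source: 1, Target: 2}, {ID: 2, Source: 2, Target: 3}, {ID: 3, Source: 3, Target: 4}}
	fmt.Println(len(neighborhood(chain, 1, 1))) // 1
	fmt.Println(len(neighborhood(chain, 1, 2))) // 2
	fmt.Println(len(neighborhood(chain, 1, 3))) // 3
}
```

In the store itself each hop is one query per frontier node, so the map-based deduplication matters most when `maxDepth` is large or the graph is dense.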
@@ -0,0 +1,324 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// UpdateObservationFeedback updates the user feedback for an observation.
// Feedback values: -1 (thumbs down), 0 (neutral), 1 (thumbs up).
func (s *ObservationStore) UpdateObservationFeedback(ctx context.Context, id int64, feedback int) error {
const query = `
UPDATE observations
SET user_feedback = ?, score_updated_at_epoch = ?
WHERE id = ?
`
_, err := s.store.ExecContext(ctx, query, feedback, time.Now().UnixMilli(), id)
return err
}
// IncrementRetrievalCount increments the retrieval counter for the given observation IDs.
// This is called when observations are returned in search results.
func (s *ObservationStore) IncrementRetrievalCount(ctx context.Context, ids []int64) error {
if len(ids) == 0 {
return nil
}
now := time.Now().UnixMilli()
// Build query with placeholders
// #nosec G202 -- query uses parameterized placeholders, not user input
query := `
UPDATE observations
SET retrieval_count = COALESCE(retrieval_count, 0) + 1,
last_retrieved_at_epoch = ?
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
`
args := make([]interface{}, 0, len(ids)+1)
args = append(args, now)
for _, id := range ids {
args = append(args, id)
}
_, err := s.store.db.ExecContext(ctx, query, args...)
return err
}
// UpdateImportanceScore updates the importance score for a single observation.
func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, score float64) error {
const query = `
UPDATE observations
SET importance_score = ?, score_updated_at_epoch = ?
WHERE id = ?
`
_, err := s.store.ExecContext(ctx, query, score, time.Now().UnixMilli(), id)
return err
}
// UpdateImportanceScores bulk updates importance scores for multiple observations.
// This is more efficient than individual updates for batch recalculation.
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
if len(scores) == 0 {
return nil
}
tx, err := s.store.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
now := time.Now().UnixMilli()
stmt, err := tx.PrepareContext(ctx, `
UPDATE observations
SET importance_score = ?, score_updated_at_epoch = ?
WHERE id = ?
`)
if err != nil {
return err
}
defer stmt.Close()
for id, score := range scores {
if _, err := stmt.ExecContext(ctx, score, now, id); err != nil {
return err
}
}
return tx.Commit()
}
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
cutoff := time.Now().Add(-threshold).UnixMilli()
query := `SELECT ` + observationColumns + `
FROM observations
WHERE score_updated_at_epoch IS NULL OR score_updated_at_epoch < ?
ORDER BY created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, cutoff, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetConceptWeights returns all concept weights from the database.
func (s *ObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
const query = `SELECT concept, weight FROM concept_weights`
rows, err := s.store.QueryContext(ctx, query)
if err != nil {
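// Defensive only: QueryContext does not normally return sql.ErrNoRows.
// A missing table (older databases) surfaces as a driver error and is returned below.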
if err == sql.ErrNoRows {
return models.DefaultConceptWeights, nil
}
return nil, err
}
defer rows.Close()
weights := make(map[string]float64)
for rows.Next() {
var concept string
var weight float64
if err := rows.Scan(&concept, &weight); err != nil {
return nil, err
}
weights[concept] = weight
}
if err := rows.Err(); err != nil {
return nil, err
}
// If no weights found, use defaults
if len(weights) == 0 {
return models.DefaultConceptWeights, nil
}
return weights, nil
}
// UpdateConceptWeight updates a single concept weight.
func (s *ObservationStore) UpdateConceptWeight(ctx context.Context, concept string, weight float64) error {
const query = `
INSERT INTO concept_weights (concept, weight, updated_at)
VALUES (?, ?, datetime('now'))
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
`
_, err := s.store.ExecContext(ctx, query, concept, weight)
return err
}
// UpdateConceptWeights bulk updates multiple concept weights.
func (s *ObservationStore) UpdateConceptWeights(ctx context.Context, weights map[string]float64) error {
if len(weights) == 0 {
return nil
}
tx, err := s.store.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.PrepareContext(ctx, `
INSERT INTO concept_weights (concept, weight, updated_at)
VALUES (?, ?, datetime('now'))
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
`)
if err != nil {
return err
}
defer stmt.Close()
for concept, weight := range weights {
if _, err := stmt.ExecContext(ctx, concept, weight); err != nil {
return err
}
}
return tx.Commit()
}
// GetObservationFeedbackStats returns statistics about user feedback.
func (s *ObservationStore) GetObservationFeedbackStats(ctx context.Context, project string) (*FeedbackStats, error) {
var query string
var args []interface{}
if project == "" {
query = `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
FROM observations
`
} else {
query = `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
FROM observations
WHERE project = ? OR scope = 'global'
`
args = append(args, project)
}
var stats FeedbackStats
err := s.store.QueryRowContext(ctx, query, args...).Scan(
&stats.Total,
&stats.Positive,
&stats.Negative,
&stats.Neutral,
&stats.AvgScore,
&stats.AvgRetrieval,
)
if err != nil {
return nil, err
}
return &stats, nil
}
// FeedbackStats contains statistics about observation feedback and scoring.
type FeedbackStats struct {
Total int `json:"total"`
Positive int `json:"positive"`
Negative int `json:"negative"`
Neutral int `json:"neutral"`
AvgScore float64 `json:"avg_score"`
AvgRetrieval float64 `json:"avg_retrieval"`
}
// GetTopScoringObservations returns the highest-scoring observations.
func (s *ObservationStore) GetTopScoringObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
var query string
var args []interface{}
if project == "" {
query = `SELECT ` + observationColumns + `
FROM observations
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
args = append(args, limit)
} else {
query = `SELECT ` + observationColumns + `
FROM observations
WHERE project = ? OR scope = 'global'
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
LIMIT ?
`
args = append(args, project, limit)
}
rows, err := s.store.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetMostRetrievedObservations returns the most frequently retrieved observations.
func (s *ObservationStore) GetMostRetrievedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
var query string
var args []interface{}
if project == "" {
query = `SELECT ` + observationColumns + `
FROM observations
WHERE retrieval_count > 0
ORDER BY retrieval_count DESC, created_at_epoch DESC
LIMIT ?
`
args = append(args, limit)
} else {
query = `SELECT ` + observationColumns + `
FROM observations
WHERE (project = ? OR scope = 'global') AND retrieval_count > 0
ORDER BY retrieval_count DESC, created_at_epoch DESC
LIMIT ?
`
args = append(args, project, limit)
}
rows, err := s.store.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// ResetObservationScores resets all observation scores to their default values.
// This is useful for testing or when changing the scoring algorithm.
func (s *ObservationStore) ResetObservationScores(ctx context.Context) error {
const query = `
UPDATE observations
SET importance_score = 1.0, score_updated_at_epoch = NULL
`
_, err := s.store.ExecContext(ctx, query)
return err
}
+698
@@ -0,0 +1,698 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// testScoringObservationStore creates an ObservationStore with scoring columns for testing.
func testScoringObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createBaseTables(t, db)
createConceptWeightsTable(t, db)
// Create the importance index if it does not exist (the scoring columns are already created by createBaseTables)
if _, err := db.Exec(`CREATE INDEX IF NOT EXISTS idx_observations_importance ON observations(importance_score DESC, created_at_epoch DESC)`); err != nil {
t.Fatalf("create importance index: %v", err)
}
store := newStoreFromDB(db)
obsStore := NewObservationStore(store)
return obsStore, store, cleanup
}
// createConceptWeightsTable creates the concept_weights table for testing.
func createConceptWeightsTable(t *testing.T, db *sql.DB) {
t.Helper()
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS concept_weights (
concept TEXT PRIMARY KEY,
weight REAL NOT NULL DEFAULT 0.1,
updated_at TEXT NOT NULL
)
`)
if err != nil {
t.Fatalf("create concept_weights: %v", err)
}
}
// ScoringStoreSuite is a test suite for scoring-related database operations.
type ScoringStoreSuite struct {
suite.Suite
obsStore *ObservationStore
store *Store
cleanup func()
ctx context.Context
}
func (s *ScoringStoreSuite) SetupTest() {
s.obsStore, s.store, s.cleanup = testScoringObservationStore(s.T())
s.ctx = context.Background()
}
func (s *ScoringStoreSuite) TearDownTest() {
if s.cleanup != nil {
s.cleanup()
}
}
func TestScoringStoreSuite(t *testing.T) {
suite.Run(t, new(ScoringStoreSuite))
}
// =============================================================================
// FEEDBACK TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Positive() {
// Create observation
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
Title: "Test feedback",
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Update feedback to positive
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
s.NoError(err)
// Verify
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(1, retrieved.UserFeedback)
s.True(retrieved.ScoreUpdatedAt.Valid)
}
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Negative() {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(-1, retrieved.UserFeedback)
}
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Neutral() {
obs := &models.ParsedObservation{
Type: models.ObsTypeChange,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// First set to positive
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
s.NoError(err)
// Then reset to neutral
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 0)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(0, retrieved.UserFeedback)
}
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_NonExistent() {
// Updating non-existent observation should not fail (just no rows affected)
err := s.obsStore.UpdateObservationFeedback(s.ctx, 99999, 1)
s.NoError(err)
}
// =============================================================================
// RETRIEVAL COUNT TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Single() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(1, retrieved.RetrievalCount)
s.True(retrieved.LastRetrievedAt.Valid)
}
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Multiple() {
var ids []int64
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
ids = append(ids, id)
}
err := s.obsStore.IncrementRetrievalCount(s.ctx, ids)
s.NoError(err)
for _, id := range ids {
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(1, retrieved.RetrievalCount)
}
}
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Cumulative() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Increment multiple times
for i := 0; i < 5; i++ {
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
s.NoError(err)
}
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.Equal(5, retrieved.RetrievalCount)
}
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Empty() {
err := s.obsStore.IncrementRetrievalCount(s.ctx, []int64{})
s.NoError(err)
}
// =============================================================================
// IMPORTANCE SCORE TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestUpdateImportanceScore_Single() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.InDelta(1.5, retrieved.ImportanceScore, 0.001)
}
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Batch() {
var ids []int64
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
ids = append(ids, id)
}
scores := map[int64]float64{
ids[0]: 1.5,
ids[1]: 0.8,
ids[2]: 1.2,
ids[3]: 0.5,
ids[4]: 2.0,
}
err := s.obsStore.UpdateImportanceScores(s.ctx, scores)
s.NoError(err)
for id, expectedScore := range scores {
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.InDelta(expectedScore, retrieved.ImportanceScore, 0.001)
}
}
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Empty() {
err := s.obsStore.UpdateImportanceScores(s.ctx, map[int64]float64{})
s.NoError(err)
}
// =============================================================================
// OBSERVATIONS NEEDING SCORE UPDATE TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_NeverUpdated() {
// Observations without score_updated_at_epoch should need update
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
s.NoError(err)
s.Len(observations, 1)
s.Equal(id, observations[0].ID)
}
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_RecentlyUpdated() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Update score (this sets score_updated_at_epoch)
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
s.NoError(err)
// Should not need update (just updated)
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
s.NoError(err)
s.Empty(observations)
}
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_Limit() {
// Create 10 observations
for i := 0; i < 10; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
}
// Request only 5
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 5)
s.NoError(err)
s.Len(observations, 5)
}
// =============================================================================
// CONCEPT WEIGHTS TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestGetConceptWeights_Empty() {
weights, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
s.Equal(models.DefaultConceptWeights, weights)
}
func (s *ScoringStoreSuite) TestUpdateConceptWeight_NewConcept() {
err := s.obsStore.UpdateConceptWeight(s.ctx, "new-concept", 0.42)
s.NoError(err)
weights, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
s.Equal(0.42, weights["new-concept"])
}
func (s *ScoringStoreSuite) TestUpdateConceptWeight_UpdateExisting() {
// Insert first
err := s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.1)
s.NoError(err)
// Update
err = s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.9)
s.NoError(err)
weights, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
s.Equal(0.9, weights["test-concept"])
}
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Batch() {
weightsToSet := map[string]float64{
"security": 0.5,
"performance": 0.3,
"testing": 0.2,
}
err := s.obsStore.UpdateConceptWeights(s.ctx, weightsToSet)
s.NoError(err)
retrieved, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
for concept, expected := range weightsToSet {
s.Equal(expected, retrieved[concept])
}
}
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Empty() {
err := s.obsStore.UpdateConceptWeights(s.ctx, map[string]float64{})
s.NoError(err)
}
// =============================================================================
// FEEDBACK STATS TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_Empty() {
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
s.NoError(err)
s.Equal(0, stats.Total)
s.Equal(0, stats.Positive)
s.Equal(0, stats.Negative)
s.Equal(0, stats.Neutral)
}
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_WithData() {
// Create observations with different feedback
feedbacks := []int{1, 1, 1, -1, -1, 0, 0, 0, 0, 0}
for i, fb := range feedbacks {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
if fb != 0 {
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, fb)
s.NoError(err)
}
}
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
s.NoError(err)
s.Equal(10, stats.Total)
s.Equal(3, stats.Positive)
s.Equal(2, stats.Negative)
s.Equal(5, stats.Neutral)
}
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_ByProject() {
// Project A observations
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
}
// Project B observations
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeFeature,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
s.NoError(err)
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
}
// Check project A stats
statsA, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-a")
s.NoError(err)
s.Equal(5, statsA.Total)
s.Equal(5, statsA.Positive)
// Check project B stats
statsB, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-b")
s.NoError(err)
s.Equal(3, statsB.Total)
s.Equal(3, statsB.Negative)
}
// =============================================================================
// TOP SCORING OBSERVATIONS TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestGetTopScoringObservations() {
// Create observations with different scores
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
// Set different scores
err = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1)*0.5)
s.NoError(err)
}
// Get top 3
top, err := s.obsStore.GetTopScoringObservations(s.ctx, "", 3)
s.NoError(err)
s.Len(top, 3)
// Verify ordered by score descending
s.GreaterOrEqual(top[0].ImportanceScore, top[1].ImportanceScore)
s.GreaterOrEqual(top[1].ImportanceScore, top[2].ImportanceScore)
}
func (s *ScoringStoreSuite) TestGetTopScoringObservations_ByProject() {
// Project A with high scores
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 2.0)
}
// Project B with low scores
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeChange,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
s.NoError(err)
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.5)
}
// Get top for project A
topA, err := s.obsStore.GetTopScoringObservations(s.ctx, "project-a", 10)
s.NoError(err)
s.Len(topA, 3)
for _, obs := range topA {
s.Equal("project-a", obs.Project)
}
}
// =============================================================================
// MOST RETRIEVED OBSERVATIONS TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations() {
// Create observations with different retrieval counts
var ids []int64
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
ids = append(ids, id)
}
// Set different retrieval counts
for i := 0; i < 10; i++ {
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[0]}) // 10 retrievals
}
for i := 0; i < 5; i++ {
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[1]}) // 5 retrievals
}
for i := 0; i < 3; i++ {
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[2]}) // 3 retrievals
}
// ids[3] and ids[4] have 0 retrievals
// Get top 3
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 3)
s.NoError(err)
s.Len(most, 3)
// Verify ordered by retrieval count descending
s.Equal(10, most[0].RetrievalCount)
s.Equal(5, most[1].RetrievalCount)
s.Equal(3, most[2].RetrievalCount)
}
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations_NoRetrievals() {
// Create observations without any retrievals
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 10)
s.NoError(err)
s.Empty(most) // No observations with retrieval_count > 0
}
// =============================================================================
// RESET OBSERVATION SCORES TESTS
// =============================================================================
func (s *ScoringStoreSuite) TestResetObservationScores() {
// Create observations with various scores
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
s.NoError(err)
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1))
}
// Reset all scores
err := s.obsStore.ResetObservationScores(s.ctx)
s.NoError(err)
// Verify all scores are reset to 1.0
observations, err := s.obsStore.GetAllRecentObservations(s.ctx, 100)
s.NoError(err)
for _, obs := range observations {
s.InDelta(1.0, obs.ImportanceScore, 0.001)
s.False(obs.ScoreUpdatedAt.Valid)
}
}
// =============================================================================
// EDGE CASES
// =============================================================================
func (s *ScoringStoreSuite) TestScoring_ZeroScore() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Set score to 0
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.0)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.InDelta(0.0, retrieved.ImportanceScore, 0.001)
}
func (s *ScoringStoreSuite) TestScoring_NegativeScore() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Set negative score (calculator shouldn't produce this, but test DB handling)
err = s.obsStore.UpdateImportanceScore(s.ctx, id, -0.5)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.InDelta(-0.5, retrieved.ImportanceScore, 0.001)
}
func (s *ScoringStoreSuite) TestScoring_LargeScore() {
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
}
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
s.NoError(err)
// Set very large score
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 999.999)
s.NoError(err)
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
s.NoError(err)
s.InDelta(999.999, retrieved.ImportanceScore, 0.001)
}
func (s *ScoringStoreSuite) TestConceptWeight_ZeroWeight() {
err := s.obsStore.UpdateConceptWeight(s.ctx, "zero-concept", 0.0)
s.NoError(err)
weights, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
s.Equal(0.0, weights["zero-concept"])
}
func (s *ScoringStoreSuite) TestConceptWeight_ExactBoundary() {
err := s.obsStore.UpdateConceptWeight(s.ctx, "max-concept", 1.0)
s.NoError(err)
weights, err := s.obsStore.GetConceptWeights(s.ctx)
s.NoError(err)
s.Equal(1.0, weights["max-concept"])
}
// =============================================================================
// STANDALONE TESTS
// =============================================================================
func TestFeedbackStats_Structure(t *testing.T) {
stats := FeedbackStats{
Total: 100,
Positive: 30,
Negative: 10,
Neutral: 60,
AvgScore: 1.5,
AvgRetrieval: 5.0,
}
assert.Equal(t, 100, stats.Total)
assert.Equal(t, 30, stats.Positive)
assert.Equal(t, 10, stats.Negative)
assert.Equal(t, 60, stats.Neutral)
assert.Equal(t, 1.5, stats.AvgScore)
assert.Equal(t, 5.0, stats.AvgRetrieval)
}
func TestScoringStore_Integration(t *testing.T) {
obsStore, _, cleanup := testScoringObservationStore(t)
defer cleanup()
ctx := context.Background()
// Full integration test: store, feedback, retrieval, score update
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
Title: "Integration test observation",
Concepts: []string{"security"},
}
id, _, err := obsStore.StoreObservation(ctx, "session-int", "project-int", obs, 1, 100)
require.NoError(t, err)
// Add feedback
err = obsStore.UpdateObservationFeedback(ctx, id, 1)
require.NoError(t, err)
// Increment retrieval
err = obsStore.IncrementRetrievalCount(ctx, []int64{id})
require.NoError(t, err)
// Update score
err = obsStore.UpdateImportanceScore(ctx, id, 1.75)
require.NoError(t, err)
// Verify final state
retrieved, err := obsStore.GetObservationByID(ctx, id)
require.NoError(t, err)
assert.Equal(t, 1, retrieved.UserFeedback)
assert.Equal(t, 1, retrieved.RetrievalCount)
assert.InDelta(t, 1.75, retrieved.ImportanceScore, 0.001)
assert.True(t, retrieved.ScoreUpdatedAt.Valid)
assert.True(t, retrieved.LastRetrievedAt.Valid)
}
+18
@@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]
return scanSummaryRows(rows)
}
// GetAllSummaries retrieves all summaries (for vector rebuild).
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
const query = `
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries
ORDER BY id
`
rows, err := s.store.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanSummaryRows(rows)
}
+52
@@ -103,6 +103,12 @@ func createBaseTables(t *testing.T, db *sql.DB) {
discovery_tokens INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
importance_score REAL DEFAULT 1.0,
user_feedback INTEGER DEFAULT 0,
retrieval_count INTEGER DEFAULT 0,
last_retrieved_at_epoch INTEGER,
score_updated_at_epoch INTEGER,
is_superseded INTEGER DEFAULT 0,
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
)
`)
@@ -110,6 +116,27 @@ func createBaseTables(t *testing.T, db *sql.DB) {
t.Fatalf("create observations: %v", err)
}
// Create observation_conflicts table for conflict detection
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS observation_conflicts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
newer_obs_id INTEGER NOT NULL,
older_obs_id INTEGER NOT NULL,
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
reason TEXT,
detected_at TEXT NOT NULL,
detected_at_epoch INTEGER NOT NULL,
resolved INTEGER DEFAULT 0,
resolved_at TEXT,
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
)
`)
if err != nil {
t.Fatalf("create observation_conflicts: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS session_summaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -150,6 +177,31 @@ func createBaseTables(t *testing.T, db *sql.DB) {
t.Fatalf("create user_prompts: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS patterns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
description TEXT,
signature TEXT,
recommendation TEXT,
frequency INTEGER DEFAULT 1,
projects TEXT,
observation_ids TEXT,
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
merged_into_id INTEGER,
confidence REAL DEFAULT 0.5,
last_seen_at TEXT NOT NULL,
last_seen_at_epoch INTEGER NOT NULL,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
)
`)
if err != nil {
t.Fatalf("create patterns: %v", err)
}
indexes := []string{
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id)`,
+1
@@ -0,0 +1 @@
bge-small-en-v1.5
Binary file not shown.
+2 -16
@@ -1,21 +1,7 @@
{ {
"version": "1.0", "version": "1.0",
"truncation": { "truncation": null,
"direction": "Right", "padding": null,
"max_length": 128,
"strategy": "LongestFirst",
"stride": 0
},
"padding": {
"strategy": {
"Fixed": 128
},
"direction": "Right",
"pad_to_multiple_of": null,
"pad_id": 0,
"pad_type_id": 0,
"pad_token": "[PAD]"
},
"added_tokens": [ "added_tokens": [
{ {
"id": 0, "id": 0,
+157
@@ -0,0 +1,157 @@
// Package embedding provides text embedding generation with swappable models.
package embedding
import (
"fmt"
"sync"
)
// PoolingStrategy defines how to pool token embeddings into sentence embeddings.
type PoolingStrategy string
const (
// PoolingNone means the model already outputs sentence embeddings directly.
PoolingNone PoolingStrategy = "none"
// PoolingMean averages all token embeddings (weighted by attention mask).
PoolingMean PoolingStrategy = "mean"
// PoolingCLS uses only the [CLS] token embedding.
PoolingCLS PoolingStrategy = "cls"
)
// ONNXConfig describes ONNX-specific model configuration.
// This allows different models to specify their tensor names and pooling needs.
type ONNXConfig struct {
// InputNames are the ONNX input tensor names in order.
InputNames []string
// OutputNames are the ONNX output tensor names.
OutputNames []string
// Pooling specifies how to convert token embeddings to sentence embeddings.
// If PoolingNone, the model outputs sentence embeddings directly.
Pooling PoolingStrategy
// HiddenSize is the embedding dimension (used for pooling calculations).
HiddenSize int
}
// EmbeddingModel represents a text embedding model.
type EmbeddingModel interface {
// Name returns the human-readable model name (e.g., "bge-small-en-v1.5").
Name() string
// Version returns a short version string for storage (e.g., "bge-v1.5").
Version() string
// Dimensions returns the embedding vector size.
Dimensions() int
// Embed generates an embedding for a single text.
Embed(text string) ([]float32, error)
// EmbedBatch generates embeddings for multiple texts.
EmbedBatch(texts []string) ([][]float32, error)
// Close releases model resources.
Close() error
}
// ONNXConfigurer is an optional interface that models can implement
// to expose their ONNX configuration for introspection.
type ONNXConfigurer interface {
// ONNXConfig returns the model's ONNX configuration.
ONNXConfig() ONNXConfig
}
// ModelMetadata describes an embedding model for UI/config.
type ModelMetadata struct {
Name string `json:"name"` // Human-readable name
Version string `json:"version"` // Short ID for DB storage
Dimensions int `json:"dimensions"` // Vector size
Description string `json:"description"` // Brief description
Default bool `json:"default"` // Is this the default model?
}
// ModelFactory creates a new instance of an embedding model.
type ModelFactory func() (EmbeddingModel, error)
// ModelRegistry provides model lookup by version.
type ModelRegistry struct {
mu sync.RWMutex
models map[string]ModelFactory
metadata map[string]ModelMetadata
defaultModel string
}
// NewModelRegistry creates a new model registry.
func NewModelRegistry() *ModelRegistry {
return &ModelRegistry{
models: make(map[string]ModelFactory),
metadata: make(map[string]ModelMetadata),
}
}
// Register adds a model factory to the registry.
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.models[meta.Version] = factory
r.metadata[meta.Version] = meta
if meta.Default {
r.defaultModel = meta.Version
}
}
// Get creates a new instance of the model with the given version.
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
r.mu.RLock()
factory, ok := r.models[version]
r.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown model version: %s", version)
}
return factory()
}
// Default returns the default model version.
func (r *ModelRegistry) Default() string {
r.mu.RLock()
defer r.mu.RUnlock()
return r.defaultModel
}
// List returns metadata for all registered models.
func (r *ModelRegistry) List() []ModelMetadata {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]ModelMetadata, 0, len(r.metadata))
for _, meta := range r.metadata {
result = append(result, meta)
}
return result
}
// DefaultRegistry is the global model registry with all available models.
var DefaultRegistry = NewModelRegistry()
// RegisterModel adds a model to the default registry.
func RegisterModel(meta ModelMetadata, factory ModelFactory) {
DefaultRegistry.Register(meta, factory)
}
// GetModel creates a model instance from the default registry.
func GetModel(version string) (EmbeddingModel, error) {
return DefaultRegistry.Get(version)
}
// GetDefaultModel returns the default model version from the default registry.
func GetDefaultModel() string {
return DefaultRegistry.Default()
}
// ListModels returns metadata for all models in the default registry.
func ListModels() []ModelMetadata {
return DefaultRegistry.List()
}
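The registry above maps a version string to a factory and remembers which version is the default. The following is a minimal, self-contained sketch of that idea, not the code from this commit: the trimmed-down `Model` interface, the `Registry` type, the `fakeModel` type, and the `"fake-v0"` version string are all hypothetical stand-ins chosen for illustration.

```go
package main

import (
	"fmt"
	"sync"
)

// Model is a trimmed-down, hypothetical stand-in for the EmbeddingModel interface.
type Model interface {
	Version() string
	Dimensions() int
	Embed(text string) ([]float32, error)
}

// Factory builds a fresh model instance on demand (mirrors the ModelFactory idea).
type Factory func() (Model, error)

// Registry maps a version string to a factory, guarded by a RWMutex.
type Registry struct {
	mu           sync.RWMutex
	factories    map[string]Factory
	defaultModel string
}

// NewRegistry returns an empty registry.
func NewRegistry() *Registry {
	return &Registry{factories: make(map[string]Factory)}
}

// Register stores the factory; isDefault marks the version returned by Default.
func (r *Registry) Register(version string, isDefault bool, f Factory) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.factories[version] = f
	if isDefault {
		r.defaultModel = version
	}
}

// Get builds a new model instance or reports an unknown version.
func (r *Registry) Get(version string) (Model, error) {
	r.mu.RLock()
	f, ok := r.factories[version]
	r.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("unknown model version: %s", version)
	}
	return f()
}

// Default returns the version most recently registered as default.
func (r *Registry) Default() string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.defaultModel
}

// fakeModel is a hypothetical model that returns a zero vector of fixed size.
type fakeModel struct{ dim int }

func (m *fakeModel) Version() string { return "fake-v0" }
func (m *fakeModel) Dimensions() int { return m.dim }
func (m *fakeModel) Embed(string) ([]float32, error) {
	return make([]float32, m.dim), nil
}

func main() {
	reg := NewRegistry()
	reg.Register("fake-v0", true, func() (Model, error) { return &fakeModel{dim: 4}, nil })

	m, err := reg.Get(reg.Default())
	if err != nil {
		panic(err)
	}
	vec, _ := m.Embed("hello")
	fmt.Println(m.Version(), len(vec))

	// Looking up an unregistered version yields an error instead of a model.
	_, err = reg.Get("missing")
	fmt.Println(err)
}
```

Because `Get` calls the factory on every lookup, each caller receives its own model instance; a caller that wants to share one instance would need to cache it itself.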
+277 -56
@@ -1,4 +1,4 @@
// Package embedding provides text embedding generation using all-MiniLM-L6-v2. // Package embedding provides text embedding generation with swappable models.
package embedding package embedding
import ( import (
@@ -15,19 +15,50 @@ import (
ort "github.com/yalue/onnxruntime_go" ort "github.com/yalue/onnxruntime_go"
) )
// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2. // EmbeddingDim is the dimension of embeddings produced by the current model.
// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
const EmbeddingDim = 384 const EmbeddingDim = 384
// Service provides thread-safe text embedding generation. // Model version constants
type Service struct { const (
// BGEModelVersion is the version string for bge-small-en-v1.5
BGEModelVersion = "bge-v1.5"
// BGEModelName is the human-readable name for bge-small-en-v1.5
BGEModelName = "bge-small-en-v1.5"
// DefaultModelVersion is the default model to use
DefaultModelVersion = BGEModelVersion
)
// MaxSequenceLength is the maximum token sequence length for the model.
const MaxSequenceLength = 512
// bgeONNXConfig defines the ONNX configuration for BGE models.
// BGE outputs last_hidden_state and requires mean pooling.
var bgeONNXConfig = ONNXConfig{
InputNames: []string{"input_ids", "attention_mask", "token_type_ids"},
OutputNames: []string{"last_hidden_state"},
Pooling: PoolingMean,
HiddenSize: EmbeddingDim,
}
// bgeModel is the ONNX-based embedding model implementation.
// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
type bgeModel struct {
tk *tokenizer.Tokenizer tk *tokenizer.Tokenizer
session *ort.DynamicAdvancedSession session *ort.DynamicAdvancedSession
mu sync.Mutex mu sync.Mutex
libDir string // temp directory containing extracted libraries libDir string // temp directory containing extracted libraries
config ONNXConfig // ONNX configuration for this model
} }
// NewService creates a new embedding service using bundled ONNX runtime and model. // Compile-time check that bgeModel implements EmbeddingModel
func NewService() (*Service, error) { var _ EmbeddingModel = (*bgeModel)(nil)
// Compile-time check that bgeModel implements ONNXConfigurer
var _ ONNXConfigurer = (*bgeModel)(nil)
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
func newBGEModel() (EmbeddingModel, error) {
// Extract ONNX runtime library to temp directory // Extract ONNX runtime library to temp directory
libDir, err := extractONNXLibrary() libDir, err := extractONNXLibrary()
if err != nil { if err != nil {
@@ -49,22 +80,41 @@ func NewService() (*Service, error) {
return nil, fmt.Errorf("load tokenizer: %w", err) return nil, fmt.Errorf("load tokenizer: %w", err)
} }
// Create ONNX session with embedded model // Create ONNX session using model-specific configuration
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"} config := bgeONNXConfig
outputNames := []string{"sentence_embedding"} session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("create ONNX session: %w", err) return nil, fmt.Errorf("create ONNX session: %w", err)
} }
return &Service{ return &bgeModel{
tk: tk, tk: tk,
session: session, session: session,
libDir: libDir, libDir: libDir,
config: config,
}, nil }, nil
} }
// ONNXConfig returns the model's ONNX configuration.
func (m *bgeModel) ONNXConfig() ONNXConfig {
return m.config
}
// Name returns the human-readable model name.
func (m *bgeModel) Name() string {
return BGEModelName
}
// Version returns the short version string for storage.
func (m *bgeModel) Version() string {
return BGEModelVersion
}
// Dimensions returns the embedding vector size.
func (m *bgeModel) Dimensions() int {
return EmbeddingDim
}
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory. // extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
// Uses content hash to avoid re-extracting if already present. // Uses content hash to avoid re-extracting if already present.
func extractONNXLibrary() (string, error) { func extractONNXLibrary() (string, error) {
@@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) {
// Embed generates an embedding for a single text. // Embed generates an embedding for a single text.
// Returns a 384-dimensional float32 vector. // Returns a 384-dimensional float32 vector.
func (s *Service) Embed(text string) ([]float32, error) { func (m *bgeModel) Embed(text string) ([]float32, error) {
s.mu.Lock() m.mu.Lock()
defer s.mu.Unlock() defer m.mu.Unlock()
if text == "" { if text == "" {
return make([]float32, EmbeddingDim), nil return make([]float32, EmbeddingDim), nil
} }
results, err := s.computeBatch([]string{text}) results, err := m.computeBatch([]string{text})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) {
// EmbedBatch generates embeddings for multiple texts. // EmbedBatch generates embeddings for multiple texts.
// Returns slice of 384-dimensional float32 vectors. // Returns slice of 384-dimensional float32 vectors.
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
if len(texts) == 0 { if len(texts) == 0 {
return nil, nil return nil, nil
} }
s.mu.Lock() m.mu.Lock()
defer s.mu.Unlock() defer m.mu.Unlock()
// Filter out empty texts and track indices // Filter out empty texts and track indices
nonEmpty := make([]string, 0, len(texts)) nonEmpty := make([]string, 0, len(texts))
@@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
} }
// Compute embeddings for non-empty texts // Compute embeddings for non-empty texts
embeddings, err := s.computeBatch(nonEmpty) embeddings, err := m.computeBatch(nonEmpty)
if err != nil { if err != nil {
return nil, fmt.Errorf("compute batch embeddings: %w", err) return nil, fmt.Errorf("compute batch embeddings: %w", err)
} }
@@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
} }
// computeBatch runs inference on a batch of texts. Must be called with lock held. // computeBatch runs inference on a batch of texts. Must be called with lock held.
func (s *Service) computeBatch(sentences []string) ([][]float32, error) { func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
if len(sentences) == 0 { if len(sentences) == 0 {
return nil, nil return nil, nil
} }
@@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent)) inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
} }
encodings, err := s.tk.EncodeBatch(inputBatch, true) encodings, err := m.tk.EncodeBatch(inputBatch, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("tokenize: %w", err) return nil, fmt.Errorf("tokenize: %w", err)
} }
batchSize := len(encodings) batchSize := len(encodings)
seqLength := len(encodings[0].Ids) hiddenSize := m.config.HiddenSize
hiddenSize := EmbeddingDim
// Find max sequence length across all encodings (tokenizer may not pad uniformly)
// Also enforce MaxSequenceLength to prevent model errors
seqLength := 0
for _, enc := range encodings {
if len(enc.Ids) > seqLength {
seqLength = len(enc.Ids)
}
}
// Truncate to max model sequence length
if seqLength > MaxSequenceLength {
seqLength = MaxSequenceLength
}
inputShape := ort.NewShape(int64(batchSize), int64(seqLength)) inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
// Create input tensors // Create input tensors (pre-filled with zeros for padding)
inputIdsData := make([]int64, batchSize*seqLength) inputIdsData := make([]int64, batchSize*seqLength)
attentionMaskData := make([]int64, batchSize*seqLength) attentionMaskData := make([]int64, batchSize*seqLength)
tokenTypeIdsData := make([]int64, batchSize*seqLength) tokenTypeIdsData := make([]int64, batchSize*seqLength)
for b := 0; b < batchSize; b++ { for b := 0; b < batchSize; b++ {
for i, id := range encodings[b].Ids { // Copy actual token data (rest remains 0 as padding)
inputIdsData[b*seqLength+i] = int64(id) // Truncate to seqLength to handle long inputs
copyLen := len(encodings[b].Ids)
if copyLen > seqLength {
copyLen = seqLength
} }
for i, mask := range encodings[b].AttentionMask { for i := 0; i < copyLen; i++ {
attentionMaskData[b*seqLength+i] = int64(mask) inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
} }
for i, typeId := range encodings[b].TypeIds { copyLen = len(encodings[b].AttentionMask)
tokenTypeIdsData[b*seqLength+i] = int64(typeId) if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
}
copyLen = len(encodings[b].TypeIds)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
} }
} }
@@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
} }
defer tokenTypeIdsTensor.Destroy() defer tokenTypeIdsTensor.Destroy()
sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize)) // Create output tensor based on pooling strategy
sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape) var outputShape ort.Shape
switch m.config.Pooling {
case PoolingNone:
// Direct sentence embedding output: [batch, hidden]
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
case PoolingMean, PoolingCLS:
// Token-level output: [batch, seq_len, hidden]
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
default:
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
}
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
if err != nil { if err != nil {
return nil, fmt.Errorf("create output tensor: %w", err) return nil, fmt.Errorf("create output tensor: %w", err)
} }
defer sentenceOutputTensor.Destroy() defer outputTensor.Destroy()
// Run inference // Run inference
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor} inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
outputTensors := []ort.Value{sentenceOutputTensor} outputTensors := []ort.Value{outputTensor}
if err := s.session.Run(inputTensors, outputTensors); err != nil { if err := m.session.Run(inputTensors, outputTensors); err != nil {
return nil, fmt.Errorf("run inference: %w", err) return nil, fmt.Errorf("run inference: %w", err)
} }
// Extract results // Extract and pool results based on strategy
flatOutput := sentenceOutputTensor.GetData() flatOutput := outputTensor.GetData()
expectedSize := batchSize * hiddenSize
if len(flatOutput) != expectedSize {
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
}
switch m.config.Pooling {
case PoolingNone:
// Direct output, no pooling needed
expectedSize := batchSize * hiddenSize
if len(flatOutput) != expectedSize {
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
}
results := make([][]float32, batchSize)
for i := 0; i < batchSize; i++ {
start := i * hiddenSize
end := start + hiddenSize
results[i] = make([]float32, hiddenSize)
copy(results[i], flatOutput[start:end])
}
return results, nil
case PoolingMean:
// Mean pooling over tokens (weighted by attention mask)
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
case PoolingCLS:
// CLS token pooling (first token of each sequence)
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
default:
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
}
}
// meanPooling applies mean pooling over token embeddings, weighted by attention mask.
// Input shape: [batch, seq_len, hidden], attention mask: [batch, seq_len]
// Output shape: [batch, hidden]
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
results := make([][]float32, batchSize) results := make([][]float32, batchSize)
for i := 0; i < batchSize; i++ {
start := i * hiddenSize for b := 0; b < batchSize; b++ {
end := start + hiddenSize result := make([]float32, hiddenSize)
results[i] = make([]float32, hiddenSize) var maskSum float32
copy(results[i], flatOutput[start:end])
// Sum embeddings weighted by attention mask
for s := 0; s < seqLen; s++ {
maskVal := float32(attentionMask[b*seqLen+s])
maskSum += maskVal
if maskVal > 0 {
embOffset := (b*seqLen + s) * hiddenSize
for h := 0; h < hiddenSize; h++ {
result[h] += embeddings[embOffset+h] * maskVal
}
}
}
// Normalize by mask sum (avoid division by zero)
if maskSum > 0 {
for h := 0; h < hiddenSize; h++ {
result[h] /= maskSum
}
}
results[b] = result
} }
return results, nil return results
}
// clsPooling extracts the [CLS] token embedding (first token).
// Input shape: [batch, seq_len, hidden]
// Output shape: [batch, hidden]
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
results := make([][]float32, batchSize)
for b := 0; b < batchSize; b++ {
result := make([]float32, hiddenSize)
// CLS token is at position 0
embOffset := b * seqLen * hiddenSize
copy(result, embeddings[embOffset:embOffset+hiddenSize])
results[b] = result
}
return results
} }
// Close releases model resources. // Close releases model resources.
func (s *Service) Close() error { func (m *bgeModel) Close() error {
s.mu.Lock() m.mu.Lock()
defer s.mu.Unlock() defer m.mu.Unlock()
var errs []error var errs []error
if s.session != nil { if m.session != nil {
if err := s.session.Destroy(); err != nil { if err := m.session.Destroy(); err != nil {
errs = append(errs, fmt.Errorf("destroy session: %w", err)) errs = append(errs, fmt.Errorf("destroy session: %w", err))
} }
s.session = nil m.session = nil
} }
if err := ort.DestroyEnvironment(); err != nil { if err := ort.DestroyEnvironment(); err != nil {
@@ -282,10 +438,75 @@ func (s *Service) Close() error {
} }
// Optionally clean up extracted library (leave for caching) // Optionally clean up extracted library (leave for caching)
// os.RemoveAll(s.libDir) // os.RemoveAll(m.libDir)
if len(errs) > 0 { if len(errs) > 0 {
return errs[0] return errs[0]
} }
return nil return nil
} }
// Register the BGE model with the default registry at init time
func init() {
RegisterModel(ModelMetadata{
Name: BGEModelName,
Version: BGEModelVersion,
Dimensions: EmbeddingDim,
Description: "High-quality semantic search model",
Default: true,
}, newBGEModel)
}
// Service provides thread-safe text embedding generation with model abstraction.
type Service struct {
model EmbeddingModel
}
// NewService creates a new embedding service using the default model.
func NewService() (*Service, error) {
return NewServiceWithModel(DefaultModelVersion)
}
// NewServiceWithModel creates a new embedding service using the specified model.
func NewServiceWithModel(version string) (*Service, error) {
if version == "" {
version = DefaultModelVersion
}
model, err := GetModel(version)
if err != nil {
return nil, fmt.Errorf("get model %s: %w", version, err)
}
return &Service{model: model}, nil
}
// Name returns the human-readable model name.
func (s *Service) Name() string {
return s.model.Name()
}
// Version returns the short version string for storage.
func (s *Service) Version() string {
return s.model.Version()
}
// Dimensions returns the embedding vector size.
func (s *Service) Dimensions() int {
return s.model.Dimensions()
}
// Embed generates an embedding for a single text.
func (s *Service) Embed(text string) ([]float32, error) {
return s.model.Embed(text)
}
// EmbedBatch generates embeddings for multiple texts.
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
return s.model.EmbedBatch(texts)
}
// Close releases model resources.
func (s *Service) Close() error {
return s.model.Close()
}
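Since the switch to a model that outputs `last_hidden_state` makes mean pooling the key step, a small worked example may help. This is a self-contained sketch rather than the commit's `meanPooling` function: the name `meanPool` and the toy numbers are assumptions, and it treats the attention mask as strictly 0/1, which gives the same result as weighting by the mask value in that case.

```go
package main

import "fmt"

// meanPool averages token vectors per sequence, skipping padded positions.
// emb is a flat [batch, seqLen, hidden] tensor; mask is a flat [batch, seqLen] tensor.
func meanPool(emb []float32, mask []int64, batch, seqLen, hidden int) [][]float32 {
	out := make([][]float32, batch)
	for b := 0; b < batch; b++ {
		vec := make([]float32, hidden)
		var count float32
		for s := 0; s < seqLen; s++ {
			if mask[b*seqLen+s] == 0 {
				continue // padding token: contributes nothing to the sum
			}
			count++
			off := (b*seqLen + s) * hidden
			for h := 0; h < hidden; h++ {
				vec[h] += emb[off+h]
			}
		}
		// Divide by the number of real tokens; leave zeros if there were none.
		if count > 0 {
			for h := range vec {
				vec[h] /= count
			}
		}
		out[b] = vec
	}
	return out
}

func main() {
	// One sequence of three tokens with hidden size 2; the third token is padding.
	emb := []float32{1, 2, 3, 4, 100, 100}
	mask := []int64{1, 1, 0}
	fmt.Println(meanPool(emb, mask, 1, 3, 2)) // prints [[2 3]]
}
```

The padded token's values (100, 100) never reach the result, which is why the zero-filled padding in the input tensors does not skew the sentence embedding.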
+6 -4
@@ -22,8 +22,10 @@ func TestNewService(t *testing.T) {
defer svc.Close() defer svc.Close()
assert.NotNil(t, svc.tk) // Verify the service is properly initialized via public methods
assert.NotNil(t, svc.session) assert.NotEmpty(t, svc.Name())
assert.NotEmpty(t, svc.Version())
assert.Equal(t, EmbeddingDim, svc.Dimensions())
} }
// TestEmbed_SingleText tests embedding a single text. // TestEmbed_SingleText tests embedding a single text.
@@ -269,8 +271,8 @@ func TestClose(t *testing.T) {
err = svc.Close() err = svc.Close()
require.NoError(t, err) require.NoError(t, err)
// Session should be nil after close // After close, embedding should fail (model resources released)
assert.Nil(t, svc.session) // Note: This behavior is model-specific; some models may still work after close
} }
// TestEmbedBatch_SingleItem tests batch embedding with single item. // TestEmbedBatch_SingleItem tests batch embedding with single item.
+438
@@ -0,0 +1,438 @@
// Package pattern provides pattern detection and recognition functionality.
package pattern
import (
"context"
"sync"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// DetectorConfig contains configuration for the pattern detector.
type DetectorConfig struct {
// MinMatchScore is the minimum similarity score to consider a match (0.0-1.0).
MinMatchScore float64
// MinFrequencyForPattern is the minimum occurrences before creating a pattern.
MinFrequencyForPattern int
// AnalysisInterval is how often to run background pattern analysis.
AnalysisInterval time.Duration
// MaxPatternsToTrack is the maximum number of active patterns.
MaxPatternsToTrack int
}
// DefaultConfig returns the default detector configuration.
func DefaultConfig() DetectorConfig {
return DetectorConfig{
MinMatchScore: 0.3, // 30% similarity threshold
MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern
AnalysisInterval: 5 * time.Minute,
MaxPatternsToTrack: 1000,
}
}
// PatternSyncFunc is a callback for syncing patterns to vector store.
type PatternSyncFunc func(pattern *models.Pattern)
// Detector detects and tracks recurring patterns across observations.
type Detector struct {
config DetectorConfig
patternStore *sqlite.PatternStore
observationStore *sqlite.ObservationStore
// Vector sync callback
syncFunc PatternSyncFunc
// Candidate tracking (patterns not yet confirmed)
candidates map[string]*candidatePattern
candidatesMu sync.RWMutex
// Background analysis
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// SetSyncFunc sets the callback for syncing patterns to vector store.
func (d *Detector) SetSyncFunc(fn PatternSyncFunc) {
d.syncFunc = fn
}
// candidatePattern tracks a potential pattern before it reaches frequency threshold.
type candidatePattern struct {
signature []string
observationIDs []int64
projects []string
patternType models.PatternType
title string
lastSeenEpoch int64
}
// NewDetector creates a new pattern detector.
func NewDetector(patternStore *sqlite.PatternStore, observationStore *sqlite.ObservationStore, config DetectorConfig) *Detector {
ctx, cancel := context.WithCancel(context.Background())
return &Detector{
config: config,
patternStore: patternStore,
observationStore: observationStore,
candidates: make(map[string]*candidatePattern),
ctx: ctx,
cancel: cancel,
}
}
// Start begins background pattern analysis.
func (d *Detector) Start() {
d.wg.Add(1)
go d.backgroundAnalysis()
log.Info().Dur("interval", d.config.AnalysisInterval).Msg("Pattern detector started")
}
// Stop stops background pattern analysis.
func (d *Detector) Stop() {
d.cancel()
d.wg.Wait()
log.Info().Msg("Pattern detector stopped")
}
// backgroundAnalysis runs periodic pattern analysis.
func (d *Detector) backgroundAnalysis() {
defer d.wg.Done()
ticker := time.NewTicker(d.config.AnalysisInterval)
defer ticker.Stop()
for {
select {
case <-d.ctx.Done():
return
case <-ticker.C:
if err := d.AnalyzeRecentObservations(d.ctx); err != nil {
log.Warn().Err(err).Msg("Background pattern analysis failed")
}
}
}
}
// AnalyzeObservation processes a new observation for pattern detection.
// This is called synchronously when a new observation is stored.
func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observation) (*DetectionResult, error) {
result := &DetectionResult{}
// Extract signature from observation
signature := models.ExtractSignature(
obs.Concepts,
obs.Title.String,
obs.Narrative.String,
)
if len(signature) == 0 {
return result, nil // Nothing to detect
}
// Check against existing patterns
matches, err := d.patternStore.FindMatchingPatterns(ctx, signature, d.config.MinMatchScore)
if err != nil {
return nil, err
}
if len(matches) > 0 {
// Found existing pattern match
bestMatch := matches[0]
for _, m := range matches[1:] {
if models.CalculateMatchScore(signature, m.Signature) > models.CalculateMatchScore(signature, bestMatch.Signature) {
bestMatch = m
}
}
// Update the pattern with new occurrence
if err := d.patternStore.IncrementPatternFrequency(ctx, bestMatch.ID, obs.Project, obs.ID); err != nil {
log.Warn().Err(err).Int64("pattern_id", bestMatch.ID).Msg("Failed to update pattern frequency")
}
result.MatchedPattern = bestMatch
result.MatchScore = models.CalculateMatchScore(signature, bestMatch.Signature)
result.IsNewPattern = false
log.Debug().
Int64("pattern_id", bestMatch.ID).
Str("pattern_name", bestMatch.Name).
Float64("score", result.MatchScore).
Msg("Observation matched existing pattern")
return result, nil
}
// No existing pattern match - check candidates
candidateKey := generateCandidateKey(signature)
d.candidatesMu.Lock()
defer d.candidatesMu.Unlock()
if candidate, exists := d.candidates[candidateKey]; exists {
// Update existing candidate
candidate.observationIDs = append(candidate.observationIDs, obs.ID)
// Add project if not already present
found := false
for _, p := range candidate.projects {
if p == obs.Project {
found = true
break
}
}
if !found {
candidate.projects = append(candidate.projects, obs.Project)
}
candidate.lastSeenEpoch = time.Now().UnixMilli()
// Check if candidate should be promoted to pattern
if len(candidate.observationIDs) >= d.config.MinFrequencyForPattern {
pattern, err := d.promoteCandidate(ctx, candidateKey, candidate)
if err != nil {
log.Warn().Err(err).Msg("Failed to promote candidate to pattern")
} else {
result.MatchedPattern = pattern
result.IsNewPattern = true
log.Info().
Str("pattern_name", pattern.Name).
Int("frequency", pattern.Frequency).
Msg("New pattern detected and stored")
}
}
} else {
// Create new candidate
patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String)
d.candidates[candidateKey] = &candidatePattern{
signature: signature,
observationIDs: []int64{obs.ID},
projects: []string{obs.Project},
patternType: patternType,
title: obs.Title.String,
lastSeenEpoch: time.Now().UnixMilli(),
}
log.Debug().
Str("candidate_key", candidateKey).
Strs("signature", signature).
Msg("New pattern candidate created")
}
return result, nil
}
// promoteCandidate converts a candidate to a stored pattern.
func (d *Detector) promoteCandidate(ctx context.Context, key string, candidate *candidatePattern) (*models.Pattern, error) {
// Generate pattern name from signature
name := generatePatternName(candidate.patternType, candidate.signature, candidate.title)
// Create base pattern using NewPattern with first observation
firstProject := ""
if len(candidate.projects) > 0 {
firstProject = candidate.projects[0]
}
var firstObsID int64
if len(candidate.observationIDs) > 0 {
firstObsID = candidate.observationIDs[0]
}
pattern := models.NewPattern(
name,
candidate.patternType,
"Automatically detected pattern from recurring observations",
candidate.signature,
firstProject,
firstObsID,
)
// Add remaining projects and observations
for i := 1; i < len(candidate.projects); i++ {
pattern.Projects = append(pattern.Projects, candidate.projects[i])
}
for i := 1; i < len(candidate.observationIDs); i++ {
pattern.ObservationIDs = append(pattern.ObservationIDs, candidate.observationIDs[i])
}
pattern.Frequency = len(candidate.observationIDs)
id, err := d.patternStore.StorePattern(ctx, pattern)
if err != nil {
return nil, err
}
pattern.ID = id
// Sync to vector store if callback is set
if d.syncFunc != nil {
d.syncFunc(pattern)
}
// Remove from candidates
delete(d.candidates, key)
return pattern, nil
}
// AnalyzeRecentObservations analyzes recent observations for pattern detection.
// This is used for background batch analysis.
func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error {
// Get the most recent observations (limit 100). Note: this call does not itself
// restrict results to a time window or skip observations that were already analyzed.
observations, err := d.observationStore.GetRecentObservations(ctx, "", 100)
if err != nil {
return err
}
analyzed := 0
patternsFound := 0
for _, obs := range observations {
result, err := d.AnalyzeObservation(ctx, obs)
if err != nil {
log.Warn().Err(err).Int64("obs_id", obs.ID).Msg("Failed to analyze observation")
continue
}
analyzed++
if result.MatchedPattern != nil {
patternsFound++
}
}
if analyzed > 0 {
log.Info().
Int("analyzed", analyzed).
Int("patterns_found", patternsFound).
Msg("Background pattern analysis completed")
}
// Clean up old candidates (older than 7 days)
d.cleanupOldCandidates()
return nil
}
// cleanupOldCandidates removes candidates that haven't been seen recently.
func (d *Detector) cleanupOldCandidates() {
d.candidatesMu.Lock()
defer d.candidatesMu.Unlock()
threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli()
for key, candidate := range d.candidates {
if candidate.lastSeenEpoch < threshold {
delete(d.candidates, key)
}
}
}
// GetPatternInsight returns a formatted insight string for a pattern.
func (d *Detector) GetPatternInsight(ctx context.Context, patternID int64) (string, error) {
pattern, err := d.patternStore.GetPatternByID(ctx, patternID)
if err != nil {
return "", err
}
return formatPatternInsight(pattern), nil
}
// DetectionResult contains the result of pattern detection.
type DetectionResult struct {
MatchedPattern *models.Pattern
MatchScore float64
IsNewPattern bool
}
// generateCandidateKey creates a unique key for a signature.
func generateCandidateKey(signature []string) string {
if len(signature) == 0 {
return ""
}
key := ""
for _, s := range signature {
key += s + "|"
}
return key
}
// generatePatternName creates a human-readable name for a pattern.
func generatePatternName(patternType models.PatternType, signature []string, title string) string {
// Use title if available and meaningful
if title != "" && len(title) < 60 {
return title
}
// Otherwise generate from type and signature
prefix := ""
switch patternType {
case models.PatternTypeBug:
prefix = "Bug Pattern: "
case models.PatternTypeRefactor:
prefix = "Refactor Pattern: "
case models.PatternTypeArchitecture:
prefix = "Architecture Pattern: "
case models.PatternTypeAntiPattern:
prefix = "Anti-Pattern: "
case models.PatternTypeBestPractice:
prefix = "Best Practice: "
}
// Use first few signature elements
if len(signature) > 0 {
name := prefix
for i, s := range signature {
if i >= 3 {
break
}
if i > 0 {
name += ", "
}
name += s
}
return name
}
return prefix + "Unnamed"
}
// formatPatternInsight creates a human-readable insight from a pattern.
func formatPatternInsight(pattern *models.Pattern) string {
insight := "I've encountered this pattern " +
itoa(pattern.Frequency) + " times"
if len(pattern.Projects) > 1 {
insight += " across " + itoa(len(pattern.Projects)) + " projects"
}
insight += ". "
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
insight += "What works: " + pattern.Recommendation.String
} else {
switch pattern.Type {
case models.PatternTypeBug:
insight += "This appears to be a recurring bug pattern."
case models.PatternTypeAntiPattern:
insight += "This is an identified anti-pattern to avoid."
case models.PatternTypeBestPractice:
insight += "This is a validated best practice."
default:
insight += "This is a recognized pattern in the codebase."
}
}
return insight
}
// itoa converts int to string without importing strconv.
func itoa(n int) string {
if n == 0 {
return "0"
}
negative := false
if n < 0 {
negative = true
n = -n
}
var digits []byte
for n > 0 {
digits = append([]byte{byte('0' + n%10)}, digits...)
n /= 10
}
if negative {
digits = append([]byte{'-'}, digits...)
}
return string(digits)
}
+450
@@ -0,0 +1,450 @@
package pattern
import (
"context"
"database/sql"
"os"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
func TestNewDetector(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
detector := NewDetector(patternStore, observationStore, config)
if detector == nil {
t.Fatal("Expected non-nil detector")
}
if detector.config.MinMatchScore != config.MinMatchScore {
t.Errorf("Expected MinMatchScore %f, got %f",
config.MinMatchScore, detector.config.MinMatchScore)
}
}
func TestDetector_StartStop(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
config.AnalysisInterval = 100 * time.Millisecond // Short interval for testing
detector := NewDetector(patternStore, observationStore, config)
// Start
detector.Start()
// Wait a bit to ensure background goroutine is running
time.Sleep(50 * time.Millisecond)
// Stop
detector.Stop()
// Verify we can stop without hanging
// (if this test hangs, the Stop logic is broken)
}
func TestDetector_AnalyzeObservation_NewCandidate(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
config.MinFrequencyForPattern = 2
detector := NewDetector(patternStore, observationStore, config)
ctx := context.Background()
// Create an observation
obs := createTestObservation(1, "nil-handling", []string{"nil", "error-handling"})
// Analyze first observation
result, err := detector.AnalyzeObservation(ctx, obs)
if err != nil {
t.Fatalf("AnalyzeObservation() error = %v", err)
}
// First observation should create a candidate, not a pattern
if result.MatchedPattern != nil {
t.Errorf("Expected no pattern match for first observation")
}
if result.IsNewPattern {
t.Errorf("Expected IsNewPattern to be false for first observation")
}
}
func TestDetector_AnalyzeObservation_PromoteToPattern(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
config.MinFrequencyForPattern = 2
detector := NewDetector(patternStore, observationStore, config)
ctx := context.Background()
// Create two similar observations
obs1 := createTestObservation(1, "Nil pointer handling", []string{"nil", "error-handling"})
obs2 := createTestObservation(2, "Nil pointer handling", []string{"nil", "error-handling"})
// Analyze first observation
_, err := detector.AnalyzeObservation(ctx, obs1)
if err != nil {
t.Fatalf("AnalyzeObservation(obs1) error = %v", err)
}
// Analyze second observation - should promote to pattern
result, err := detector.AnalyzeObservation(ctx, obs2)
if err != nil {
t.Fatalf("AnalyzeObservation(obs2) error = %v", err)
}
if result.MatchedPattern == nil {
t.Fatal("Expected pattern to be created after second occurrence")
}
if !result.IsNewPattern {
t.Errorf("Expected IsNewPattern to be true for newly promoted pattern")
}
if result.MatchedPattern.Frequency != 2 {
t.Errorf("Expected frequency 2, got %d", result.MatchedPattern.Frequency)
}
}
func TestDetector_AnalyzeObservation_MatchExisting(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
detector := NewDetector(patternStore, observationStore, config)
ctx := context.Background()
// Create an existing pattern
pattern := &models.Pattern{
Name: "Existing Pattern",
Type: models.PatternTypeBug,
Signature: []string{"nil", "error-handling", "pointer"},
Frequency: 5,
Projects: []string{"proj1"},
ObservationIDs: []int64{1, 2, 3, 4, 5},
Status: models.PatternStatusActive,
Confidence: 0.7,
LastSeenAt: time.Now().Format(time.RFC3339),
LastSeenEpoch: time.Now().UnixMilli(),
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
if _, err := patternStore.StorePattern(ctx, pattern); err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
// Create observation with similar signature
obs := createTestObservation(10, "Nil check", []string{"nil", "error-handling"})
// Analyze - should match existing pattern
result, err := detector.AnalyzeObservation(ctx, obs)
if err != nil {
t.Fatalf("AnalyzeObservation() error = %v", err)
}
if result.MatchedPattern == nil {
t.Fatal("Expected to match existing pattern")
}
if result.IsNewPattern {
t.Errorf("Expected IsNewPattern to be false for existing pattern")
}
if result.MatchScore < 0.3 {
t.Errorf("Expected match score >= 0.3, got %f", result.MatchScore)
}
}
func TestDetector_AnalyzeObservation_NoMatch(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
config.MinMatchScore = 0.5 // Higher threshold
detector := NewDetector(patternStore, observationStore, config)
ctx := context.Background()
// Create an existing pattern with specific signature
pattern := &models.Pattern{
Name: "Specific Pattern",
Type: models.PatternTypeBug,
Signature: []string{"database", "connection", "pool"},
Frequency: 3,
Projects: []string{"proj1"},
ObservationIDs: []int64{1, 2, 3},
Status: models.PatternStatusActive,
Confidence: 0.6,
LastSeenAt: time.Now().Format(time.RFC3339),
LastSeenEpoch: time.Now().UnixMilli(),
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
if _, err := patternStore.StorePattern(ctx, pattern); err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
// Create observation with completely different signature
obs := createTestObservation(10, "UI Component", []string{"frontend", "react", "component"})
// Analyze - should not match
result, err := detector.AnalyzeObservation(ctx, obs)
if err != nil {
t.Fatalf("AnalyzeObservation() error = %v", err)
}
if result.MatchedPattern != nil {
t.Errorf("Expected no match for unrelated observation")
}
}
func TestDetector_CandidateCleanup(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
config.MinFrequencyForPattern = 3 // Higher threshold
detector := NewDetector(patternStore, observationStore, config)
// Add an old candidate manually
oldKey := "old|candidate|"
detector.candidates[oldKey] = &candidatePattern{
signature: []string{"old", "candidate"},
observationIDs: []int64{1},
projects: []string{"proj1"},
patternType: models.PatternTypeBug,
title: "Old Candidate",
lastSeenEpoch: time.Now().Add(-8 * 24 * time.Hour).UnixMilli(), // 8 days ago
}
// Add a recent candidate
recentKey := "recent|candidate|"
detector.candidates[recentKey] = &candidatePattern{
signature: []string{"recent", "candidate"},
observationIDs: []int64{2},
projects: []string{"proj1"},
patternType: models.PatternTypeBug,
title: "Recent Candidate",
lastSeenEpoch: time.Now().UnixMilli(),
}
// Run cleanup
detector.cleanupOldCandidates()
// Old candidate should be removed
if _, exists := detector.candidates[oldKey]; exists {
t.Errorf("Expected old candidate to be cleaned up")
}
// Recent candidate should remain
if _, exists := detector.candidates[recentKey]; !exists {
t.Errorf("Expected recent candidate to remain")
}
}
func TestDetector_GetPatternInsight(t *testing.T) {
store := setupTestStore(t)
defer store.Close()
patternStore := sqlite.NewPatternStore(store)
observationStore := sqlite.NewObservationStore(store)
config := DefaultConfig()
detector := NewDetector(patternStore, observationStore, config)
ctx := context.Background()
// Create pattern with recommendation
pattern := &models.Pattern{
Name: "Insight Test Pattern",
Type: models.PatternTypeBestPractice,
Signature: []string{"test"},
Recommendation: sql.NullString{String: "Always write tests first", Valid: true},
Frequency: 12,
Projects: []string{"proj1", "proj2", "proj3"},
ObservationIDs: []int64{1},
Status: models.PatternStatusActive,
Confidence: 0.8,
LastSeenAt: time.Now().Format(time.RFC3339),
LastSeenEpoch: time.Now().UnixMilli(),
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
id, err := patternStore.StorePattern(ctx, pattern)
if err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
// Get insight
insight, err := detector.GetPatternInsight(ctx, id)
if err != nil {
t.Fatalf("GetPatternInsight() error = %v", err)
}
// Verify insight contains expected elements
if insight == "" {
t.Error("Expected non-empty insight")
}
if !containsString(insight, "12 times") {
t.Errorf("Expected insight to contain frequency, got: %s", insight)
}
if !containsString(insight, "3 projects") {
t.Errorf("Expected insight to contain project count, got: %s", insight)
}
if !containsString(insight, "Always write tests first") {
t.Errorf("Expected insight to contain recommendation, got: %s", insight)
}
}
func TestDefaultConfig(t *testing.T) {
config := DefaultConfig()
if config.MinMatchScore <= 0 || config.MinMatchScore > 1 {
t.Errorf("Invalid MinMatchScore: %f", config.MinMatchScore)
}
if config.MinFrequencyForPattern < 1 {
t.Errorf("Invalid MinFrequencyForPattern: %d", config.MinFrequencyForPattern)
}
if config.AnalysisInterval <= 0 {
t.Errorf("Invalid AnalysisInterval: %v", config.AnalysisInterval)
}
if config.MaxPatternsToTrack <= 0 {
t.Errorf("Invalid MaxPatternsToTrack: %d", config.MaxPatternsToTrack)
}
}
func TestGeneratePatternName(t *testing.T) {
tests := []struct {
patternType models.PatternType
signature []string
title string
wantPrefix string
}{
{models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"},
{models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"},
{models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"},
{models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"},
{models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"},
{models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly
}
for _, tt := range tests {
name := generatePatternName(tt.patternType, tt.signature, tt.title)
if !hasPrefix(name, tt.wantPrefix) {
t.Errorf("generatePatternName(%v, %v, %q) = %q, want prefix %q",
tt.patternType, tt.signature, tt.title, name, tt.wantPrefix)
}
}
}
func TestFormatPatternInsight(t *testing.T) {
// Pattern without recommendation
pattern1 := &models.Pattern{
Type: models.PatternTypeBug,
Frequency: 5,
Projects: []string{"proj1"},
}
insight1 := formatPatternInsight(pattern1)
if !containsString(insight1, "5 times") {
t.Errorf("Expected insight to contain frequency")
}
if !containsString(insight1, "recurring bug pattern") {
t.Errorf("Expected bug pattern description")
}
// Pattern with recommendation
pattern2 := &models.Pattern{
Type: models.PatternTypeBestPractice,
Frequency: 10,
Projects: []string{"proj1", "proj2"},
Recommendation: sql.NullString{String: "Do this", Valid: true},
}
insight2 := formatPatternInsight(pattern2)
if !containsString(insight2, "10 times") {
t.Errorf("Expected insight to contain frequency")
}
if !containsString(insight2, "2 projects") {
t.Errorf("Expected insight to contain project count")
}
if !containsString(insight2, "Do this") {
t.Errorf("Expected insight to contain recommendation")
}
}
// Helper functions
func setupTestStore(t *testing.T) *sqlite.Store {
t.Helper()
// Create temp database file
tmpFile, err := os.CreateTemp("", "pattern_test_*.db")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
tmpFile.Close()
t.Cleanup(func() {
os.Remove(tmpFile.Name())
})
store, err := sqlite.NewStore(sqlite.StoreConfig{
Path: tmpFile.Name(),
MaxConns: 1,
WALMode: true,
})
if err != nil {
// Check if this is an FTS5 related error
if containsString(err.Error(), "fts5") || containsString(err.Error(), "no such module") {
t.Skip("FTS5 not available in this SQLite build")
}
t.Fatalf("Failed to create store: %v", err)
}
return store
}
func createTestObservation(id int64, title string, concepts []string) *models.Observation {
return &models.Observation{
ID: id,
SDKSessionID: "test-session",
Project: "test-project",
Scope: models.ScopeProject,
Type: models.ObsTypeBugfix,
Title: sql.NullString{String: title, Valid: true},
Narrative: sql.NullString{String: "Test narrative", Valid: true},
Concepts: concepts,
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
}
func containsString(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
func hasPrefix(s, prefix string) bool {
if len(s) < len(prefix) {
return false
}
return s[:len(prefix)] == prefix
}
+14
@@ -0,0 +1,14 @@
// Package reranking provides cross-encoder reranking for search results.
package reranking
import (
_ "embed"
)
// Cross-encoder model and tokenizer files - embedded for all platforms
//
//go:embed assets/model.onnx
var crossEncoderModelData []byte
//go:embed assets/tokenizer.json
var crossEncoderTokenizerData []byte
Binary file not shown.
File diff suppressed because it is too large.
@@ -0,0 +1,57 @@
{
"added_tokens_decoder": {
"0": {
"content": "[PAD]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"100": {
"content": "[UNK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"101": {
"content": "[CLS]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"102": {
"content": "[SEP]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"103": {
"content": "[MASK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"clean_up_tokenization_spaces": true,
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"mask_token": "[MASK]",
"model_max_length": 512,
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]"
}
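Review note: the special-token IDs in this tokenizer config (`[PAD]` = 0, `[CLS]` = 101, `[SEP]` = 102) determine how a query/document pair is laid out for a BERT-style cross-encoder: `[CLS] query [SEP] document [SEP]`, with `token_type_ids` 0 for the query segment and 1 for the document segment, and zero padding up to the batch sequence length. The sketch below illustrates that layout only. `buildPair` is a hypothetical helper, the word-piece IDs in `main` are placeholders, and the tail truncation here is simpler than the tokenizer's longest-first strategy.

```go
package main

import "fmt"

const (
	padID = 0   // [PAD], per the config above
	clsID = 101 // [CLS]
	sepID = 102 // [SEP]
)

// buildPair lays out already-tokenized query and document IDs as one padded row.
// It returns input_ids, attention_mask and token_type_ids, each of length seqLen.
func buildPair(query, doc []int64, seqLen int) (ids, mask, types []int64) {
	ids = make([]int64, seqLen) // zero value doubles as padID
	mask = make([]int64, seqLen)
	types = make([]int64, seqLen)

	tokens := []int64{clsID}
	segs := []int64{0}
	for _, id := range query {
		tokens = append(tokens, id)
		segs = append(segs, 0)
	}
	tokens = append(tokens, sepID)
	segs = append(segs, 0)
	for _, id := range doc {
		tokens = append(tokens, id)
		segs = append(segs, 1)
	}
	tokens = append(tokens, sepID)
	segs = append(segs, 1)

	n := len(tokens)
	if n > seqLen {
		n = seqLen // naive tail truncation, for illustration only
	}
	for i := 0; i < n; i++ {
		ids[i] = tokens[i]
		mask[i] = 1
		types[i] = segs[i]
	}
	return ids, mask, types
}

func main() {
	ids, mask, types := buildPair([]int64{2054, 2003}, []int64{3000}, 8)
	fmt.Println(ids)
	fmt.Println(mask)
	fmt.Println(types)
}
```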
+380
@@ -0,0 +1,380 @@
// Package reranking provides cross-encoder reranking for search results.
// Uses MS-MARCO MiniLM L6 v2 cross-encoder model for relevance scoring.
package reranking
import (
"bytes"
"fmt"
"math"
"sort"
"sync"
"github.com/rs/zerolog/log"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
ort "github.com/yalue/onnxruntime_go"
)
const (
// ModelName is the human-readable name for the cross-encoder model
ModelName = "ms-marco-MiniLM-L6-v2"
// ModelVersion is the short version string for identification
ModelVersion = "msmarco-v2"
// MaxSequenceLength is the maximum combined query+document token length
MaxSequenceLength = 512
// DefaultCandidateLimit is the default number of candidates to rerank
DefaultCandidateLimit = 100
// DefaultResultLimit is the default number of results to return after reranking
DefaultResultLimit = 10
)
// Candidate represents a search result candidate for reranking.
type Candidate struct {
ID string // Document ID
Content string // Document text content for scoring
Score float64 // Original bi-encoder similarity score
Metadata map[string]any // Preserved metadata
RerankInfo map[string]float64 // Reranking debug info (optional)
}
// RerankResult represents a reranked search result.
type RerankResult struct {
ID string // Document ID
Content string // Document text content
OriginalScore float64 // Original bi-encoder score
RerankScore float64 // Cross-encoder relevance score
CombinedScore float64 // Weighted combination of scores
Metadata map[string]any // Preserved metadata
OriginalRank int // Position in the input candidate slice (1-indexed); meaningful as a rank only if candidates arrive sorted by bi-encoder score
RerankRank int // Position after reranking (1-indexed)
RankImprovement int // How much the rank improved (positive = moved up)
}
// Service provides cross-encoder reranking functionality.
type Service struct {
tk *tokenizer.Tokenizer
session *ort.DynamicAdvancedSession
mu sync.Mutex
// Weight for combining scores: combined = alpha*rerank + (1-alpha)*original
// Default 0.7 favors cross-encoder score
Alpha float64
}
// Config holds configuration for the reranking service.
type Config struct {
// Alpha is the weight for combining scores, in the range (0.0, 1.0].
// Higher values favor cross-encoder scores, lower values favor bi-encoder scores.
// NewService replaces values <= 0 or > 1 with the default of 0.7.
Alpha float64
}
// DefaultConfig returns sensible defaults for reranking.
func DefaultConfig() Config {
return Config{
Alpha: 0.7, // Favor cross-encoder by default
}
}
// NewService creates a new cross-encoder reranking service.
// Note: ONNX runtime must be initialized before calling this (via embedding.NewService).
func NewService(cfg Config) (*Service, error) {
// Load tokenizer from embedded data
tk, err := pretrained.FromReader(bytes.NewReader(crossEncoderTokenizerData))
if err != nil {
return nil, fmt.Errorf("load cross-encoder tokenizer: %w", err)
}
// Configure tokenizer for sequence classification (pairs)
tk.WithTruncation(&tokenizer.TruncationParams{
MaxLength: MaxSequenceLength,
Strategy: tokenizer.LongestFirst,
Stride: 0,
})
// Cross-encoder outputs a single logit for relevance scoring
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
outputNames := []string{"logits"}
session, err := ort.NewDynamicAdvancedSessionWithONNXData(
crossEncoderModelData,
inputNames,
outputNames,
nil,
)
if err != nil {
return nil, fmt.Errorf("create cross-encoder ONNX session: %w", err)
}
alpha := cfg.Alpha
if alpha <= 0 || alpha > 1 {
alpha = 0.7
}
return &Service{
tk: tk,
session: session,
Alpha: alpha,
}, nil
}
// Rerank reranks candidates using the cross-encoder model.
// Takes a query and list of candidates, returns reranked results.
func (s *Service) Rerank(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
if len(candidates) == 0 {
return nil, nil
}
if limit <= 0 {
limit = DefaultResultLimit
}
s.mu.Lock()
defer s.mu.Unlock()
// Score all query-document pairs
scores, err := s.scoreAll(query, candidates)
if err != nil {
return nil, fmt.Errorf("score candidates: %w", err)
}
// Build results with combined scores
results := make([]RerankResult, len(candidates))
for i, c := range candidates {
// Normalize cross-encoder score to 0-1 range using sigmoid
normalizedRerank := sigmoid(scores[i])
results[i] = RerankResult{
ID: c.ID,
Content: c.Content,
OriginalScore: c.Score,
RerankScore: normalizedRerank,
CombinedScore: s.Alpha*normalizedRerank + (1-s.Alpha)*c.Score,
Metadata: c.Metadata,
OriginalRank: i + 1,
}
}
// Sort by combined score (descending); stable so tied candidates keep their input order
sort.SliceStable(results, func(i, j int) bool {
return results[i].CombinedScore > results[j].CombinedScore
})
// Assign rerank positions and calculate improvement
for i := range results {
results[i].RerankRank = i + 1
results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
}
// Apply limit
if len(results) > limit {
results = results[:limit]
}
log.Debug().
Int("candidates", len(candidates)).
Int("returned", len(results)).
Float64("alpha", s.Alpha).
Msg("Cross-encoder reranking completed")
return results, nil
}
// RerankByScore reranks candidates and returns sorted by pure cross-encoder score.
// Useful when you want to completely replace bi-encoder ranking.
func (s *Service) RerankByScore(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
if len(candidates) == 0 {
return nil, nil
}
if limit <= 0 {
limit = DefaultResultLimit
}
s.mu.Lock()
defer s.mu.Unlock()
scores, err := s.scoreAll(query, candidates)
if err != nil {
return nil, fmt.Errorf("score candidates: %w", err)
}
results := make([]RerankResult, len(candidates))
for i, c := range candidates {
normalizedRerank := sigmoid(scores[i])
results[i] = RerankResult{
ID: c.ID,
Content: c.Content,
OriginalScore: c.Score,
RerankScore: normalizedRerank,
CombinedScore: normalizedRerank, // Use pure rerank score
Metadata: c.Metadata,
OriginalRank: i + 1,
}
}
// Sort by rerank score only (descending); stable so tied candidates keep their input order
sort.SliceStable(results, func(i, j int) bool {
return results[i].RerankScore > results[j].RerankScore
})
for i := range results {
results[i].RerankRank = i + 1
results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
}
if len(results) > limit {
results = results[:limit]
}
return results, nil
}
// scoreAll scores all query-document pairs using the cross-encoder.
// Returns raw logits (before sigmoid normalization).
func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, error) {
batchSize := len(candidates)
// Tokenize all query-document pairs
pairs := make([]tokenizer.EncodeInput, batchSize)
for i, c := range candidates {
// Cross-encoder takes query and document as a pair
pairs[i] = tokenizer.NewDualEncodeInput(
tokenizer.NewRawInputSequence(query),
tokenizer.NewRawInputSequence(c.Content),
)
}
encodings, err := s.tk.EncodeBatch(pairs, true)
if err != nil {
return nil, fmt.Errorf("tokenize pairs: %w", err)
}
// Find max sequence length
seqLength := 0
for _, enc := range encodings {
if len(enc.Ids) > seqLength {
seqLength = len(enc.Ids)
}
}
if seqLength > MaxSequenceLength {
seqLength = MaxSequenceLength
}
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
// Create input tensors
inputIdsData := make([]int64, batchSize*seqLength)
attentionMaskData := make([]int64, batchSize*seqLength)
tokenTypeIdsData := make([]int64, batchSize*seqLength)
for b := 0; b < batchSize; b++ {
copyLen := len(encodings[b].Ids)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
}
copyLen = len(encodings[b].AttentionMask)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
}
copyLen = len(encodings[b].TypeIds)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
}
}
inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
if err != nil {
return nil, fmt.Errorf("create input_ids tensor: %w", err)
}
defer inputIdsTensor.Destroy()
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
if err != nil {
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
}
defer attentionMaskTensor.Destroy()
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
if err != nil {
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
}
defer tokenTypeIdsTensor.Destroy()
// Cross-encoder outputs [batch, 1] logits
outputShape := ort.NewShape(int64(batchSize), 1)
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
if err != nil {
return nil, fmt.Errorf("create output tensor: %w", err)
}
defer outputTensor.Destroy()
// Run inference
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
outputTensors := []ort.Value{outputTensor}
if err := s.session.Run(inputTensors, outputTensors); err != nil {
return nil, fmt.Errorf("run cross-encoder inference: %w", err)
}
// Extract scores
flatOutput := outputTensor.GetData()
scores := make([]float64, batchSize)
for i := 0; i < batchSize; i++ {
scores[i] = float64(flatOutput[i])
}
return scores, nil
}
// Score scores a single query-document pair.
// Returns the raw cross-encoder logit and normalized score.
func (s *Service) Score(query, document string) (rawScore, normalizedScore float64, err error) {
s.mu.Lock()
defer s.mu.Unlock()
scores, err := s.scoreAll(query, []Candidate{{Content: document}})
if err != nil {
return 0, 0, err
}
rawScore = scores[0]
normalizedScore = sigmoid(rawScore)
return rawScore, normalizedScore, nil
}
// Close releases model resources.
func (s *Service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.session != nil {
if err := s.session.Destroy(); err != nil {
return fmt.Errorf("destroy cross-encoder session: %w", err)
}
s.session = nil
}
return nil
}
// sigmoid applies the sigmoid function to normalize scores to 0-1 range.
func sigmoid(x float64) float64 {
if x > 20 {
return 1.0
}
if x < -20 {
return 0.0
}
return 1.0 / (1.0 + math.Exp(-x))
}
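Review note: the score combination in `Rerank` can be exercised without the ONNX model. The sketch below applies the same formula, `combined = alpha*sigmoid(logit) + (1-alpha)*original`, to hand-picked logits and then derives the rank change. The `scored` type, the `blend` function, and the logit values are hypothetical and exist only for this illustration; real logits come from the cross-encoder.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// scored is a hypothetical, trimmed-down version of RerankResult.
type scored struct {
	ID           string
	Original     float64 // bi-encoder score
	Logit        float64 // raw cross-encoder output (made up here)
	Combined     float64
	OriginalRank int
	RerankRank   int
}

func sigmoid(x float64) float64 { return 1.0 / (1.0 + math.Exp(-x)) }

// blend combines both scores, sorts descending, and records ranks before and after.
func blend(items []scored, alpha float64) []scored {
	out := make([]scored, len(items))
	for i, it := range items {
		it.Combined = alpha*sigmoid(it.Logit) + (1-alpha)*it.Original
		it.OriginalRank = i + 1
		out[i] = it
	}
	sort.SliceStable(out, func(i, j int) bool { return out[i].Combined > out[j].Combined })
	for i := range out {
		out[i].RerankRank = i + 1
	}
	return out
}

func main() {
	items := []scored{
		{ID: "python", Original: 0.8, Logit: -4.0},
		{ID: "go", Original: 0.6, Logit: 3.0},
		{ID: "js", Original: 0.7, Logit: -2.0},
	}
	// With alpha = 0.7 the resulting order is: go, js, python.
	for _, r := range blend(items, 0.7) {
		fmt.Printf("%s combined=%.4f moved=%+d\n", r.ID, r.Combined, r.OriginalRank-r.RerankRank)
	}
}
```

Because the bi-encoder score still contributes `1-alpha` of the total, a strongly relevant document overtakes higher-scored neighbours only when its cross-encoder score is clearly larger, which is the trade-off the `Alpha` setting controls.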
+448
@@ -0,0 +1,448 @@
package reranking
import (
"sync"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
)
// initONNX initializes the ONNX runtime via the embedding service.
// Must be called before creating reranking service.
func initONNX(t *testing.T) func() {
t.Helper()
embSvc, err := embedding.NewService()
if err != nil {
t.Fatalf("Failed to initialize ONNX via embedding service: %v", err)
}
return func() {
embSvc.Close()
}
}
// TestSigmoid tests the sigmoid normalization function.
func TestSigmoid(t *testing.T) {
tests := []struct {
name string
input float64
wantMin float64
wantMax float64
}{
{"positive large", 10, 0.9999, 1.0},
{"positive small", 1, 0.7, 0.8},
{"zero", 0, 0.4999, 0.5001},
{"negative small", -1, 0.2, 0.3},
{"negative large", -10, 0, 0.0001},
{"very positive", 25, 0.999999, 1.0},
{"very negative", -25, 0, 0.000001},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sigmoid(tt.input)
if got < tt.wantMin || got > tt.wantMax {
t.Errorf("sigmoid(%v) = %v, want in range [%v, %v]",
tt.input, got, tt.wantMin, tt.wantMax)
}
})
}
}
// TestDefaultConfig tests the default configuration values.
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.Alpha < 0 || cfg.Alpha > 1 {
t.Errorf("DefaultConfig().Alpha = %v, want in range [0, 1]", cfg.Alpha)
}
if cfg.Alpha != 0.7 {
t.Errorf("DefaultConfig().Alpha = %v, want 0.7", cfg.Alpha)
}
}
// TestNewService tests service creation.
func TestNewService(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
if svc == nil {
t.Fatal("NewService() returned nil")
}
defer svc.Close()
if svc.Alpha != cfg.Alpha {
t.Errorf("Service.Alpha = %v, want %v", svc.Alpha, cfg.Alpha)
}
}
// TestScore tests single pair scoring.
func TestScore(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
query := "What is the capital of France?"
relevant := "Paris is the capital and largest city of France."
irrelevant := "Dogs are popular pets known for their loyalty."
// Score relevant document
_, relevantNorm, err := svc.Score(query, relevant)
if err != nil {
t.Fatalf("Score(relevant) error = %v", err)
}
// Score irrelevant document
_, irrelevantNorm, err := svc.Score(query, irrelevant)
if err != nil {
t.Fatalf("Score(irrelevant) error = %v", err)
}
// Relevant document should score higher
if relevantNorm <= irrelevantNorm {
t.Errorf("Expected relevant (%v) > irrelevant (%v)",
relevantNorm, irrelevantNorm)
}
}
// TestRerank tests the reranking functionality.
func TestRerank(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
query := "How to handle errors in Go?"
candidates := []Candidate{
{
ID: "1",
Content: "Python exception handling with try/except blocks.",
Score: 0.8, // High bi-encoder score but irrelevant
},
{
ID: "2",
Content: "Go error handling uses explicit return values. Functions return error as the last value.",
Score: 0.6, // Lower bi-encoder score but relevant
},
{
ID: "3",
Content: "JavaScript uses Promise.catch for async error handling.",
Score: 0.7,
},
}
results, err := svc.Rerank(query, candidates, 3)
if err != nil {
t.Fatalf("Rerank() error = %v", err)
}
if len(results) != 3 {
t.Fatalf("Rerank() returned %d results, want 3", len(results))
}
// The Go error handling document is expected to rank higher after reranking,
// but only its presence in the results is asserted here.
var goRank int
for i, r := range results {
if r.ID == "2" {
goRank = i + 1
break
}
}
if goRank == 0 {
t.Error("Go document not found in results")
}
// Verify all results have required fields populated
for i, r := range results {
if r.ID == "" {
t.Errorf("Result %d has empty ID", i)
}
if r.Content == "" {
t.Errorf("Result %d has empty Content", i)
}
if r.RerankRank != i+1 {
t.Errorf("Result %d has RerankRank %d, want %d", i, r.RerankRank, i+1)
}
}
}
// TestRerankEmpty tests reranking with empty input.
func TestRerankEmpty(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
results, err := svc.Rerank("test query", nil, 10)
if err != nil {
t.Fatalf("Rerank(nil) error = %v", err)
}
if results != nil {
t.Errorf("Rerank(nil) = %v, want nil", results)
}
results, err = svc.Rerank("test query", []Candidate{}, 10)
if err != nil {
t.Fatalf("Rerank([]) error = %v", err)
}
if results != nil {
t.Errorf("Rerank([]) = %v, want nil", results)
}
}
// TestRerankLimit tests that limit is respected.
func TestRerankLimit(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
candidates := make([]Candidate, 20)
for i := range candidates {
candidates[i] = Candidate{
ID: string(rune('A' + i)),
Content: "Test document content for ranking.",
Score: 0.5,
}
}
results, err := svc.Rerank("test query", candidates, 5)
if err != nil {
t.Fatalf("Rerank() error = %v", err)
}
if len(results) != 5 {
t.Errorf("Rerank() returned %d results, want 5", len(results))
}
}
// TestRerankByScore tests pure cross-encoder ranking.
func TestRerankByScore(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
query := "machine learning algorithms"
candidates := []Candidate{
{
ID: "1",
Content: "Cooking recipes for Italian pasta dishes.",
Score: 0.9, // High original score
},
{
ID: "2",
Content: "Neural networks are a type of machine learning algorithm.",
Score: 0.3, // Low original score
},
}
results, err := svc.RerankByScore(query, candidates, 2)
if err != nil {
t.Fatalf("RerankByScore() error = %v", err)
}
// Document 2 should rank first since it's about ML
if len(results) == 0 {
t.Fatal("RerankByScore() returned no results")
}
if results[0].ID != "2" {
t.Errorf("Expected ML document to rank first, got %v", results[0].ID)
}
// CombinedScore should equal RerankScore when using RerankByScore
for _, r := range results {
if r.CombinedScore != r.RerankScore {
t.Errorf("RerankByScore: CombinedScore (%v) != RerankScore (%v)",
r.CombinedScore, r.RerankScore)
}
}
}
// TestRankImprovement tests that rank improvement is calculated correctly.
func TestRankImprovement(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
// Create candidates where we know the expected reranking
candidates := []Candidate{
{ID: "A", Content: "Unrelated content about weather forecasting.", Score: 0.9},
{ID: "B", Content: "How to fix memory leaks in Go programs.", Score: 0.8},
{ID: "C", Content: "More unrelated content about gardening tips.", Score: 0.7},
}
results, err := svc.Rerank("debugging memory issues in Go", candidates, 3)
if err != nil {
t.Fatalf("Rerank() error = %v", err)
}
for _, r := range results {
// RankImprovement = OriginalRank - RerankRank
// Positive means moved up, negative means moved down
expectedImprovement := r.OriginalRank - r.RerankRank
if r.RankImprovement != expectedImprovement {
t.Errorf("ID %s: RankImprovement = %d, want %d (orig=%d, new=%d)",
r.ID, r.RankImprovement, expectedImprovement,
r.OriginalRank, r.RerankRank)
}
}
}
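The relationship checked above (`RankImprovement = OriginalRank - RerankRank`) can be sketched in isolation. This is a minimal illustration that assumes 1-based ranks, as the `i+1` comparison in the earlier test suggests. The `ranked` type, the `assignRanks` function and the scores in `main` are hypothetical and do not come from the package.

```go
package main

import (
	"fmt"
	"sort"
)

// ranked is a hypothetical stand-in for the package's result type.
type ranked struct {
	ID              string
	OriginalRank    int
	RerankRank      int
	RankImprovement int
}

// assignRanks takes IDs in their original (first-stage) order plus new
// scores, sorts by the new score descending, and fills in 1-based ranks.
// A positive RankImprovement means the item moved up.
func assignRanks(ids []string, rerankScores map[string]float64) []ranked {
	out := make([]ranked, len(ids))
	for i, id := range ids {
		out[i] = ranked{ID: id, OriginalRank: i + 1}
	}
	sort.SliceStable(out, func(a, b int) bool {
		return rerankScores[out[a].ID] > rerankScores[out[b].ID]
	})
	for i := range out {
		out[i].RerankRank = i + 1
		out[i].RankImprovement = out[i].OriginalRank - out[i].RerankRank
	}
	return out
}

func main() {
	// Made-up scores: B is the only relevant document.
	res := assignRanks(
		[]string{"A", "B", "C"},
		map[string]float64{"A": 0.1, "B": 0.9, "C": 0.2},
	)
	for _, r := range res {
		fmt.Printf("%s: %d -> %d (%+d)\n", r.ID, r.OriginalRank, r.RerankRank, r.RankImprovement)
	}
}
```

Using a stable sort keeps the first-stage order for candidates whose new scores tie, which makes the resulting ranks deterministic.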
// TestConcurrentRerank tests concurrent reranking calls.
func TestConcurrentRerank(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
candidates := []Candidate{
{ID: "1", Content: "Test document one.", Score: 0.5},
{ID: "2", Content: "Test document two.", Score: 0.5},
}
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := svc.Rerank("concurrent test query", candidates, 2)
if err != nil {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("Concurrent Rerank error: %v", err)
}
}
// TestClose tests service cleanup.
func TestClose(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
err = svc.Close()
if err != nil {
t.Errorf("Close() error = %v", err)
}
// Double close should not panic
err = svc.Close()
if err != nil {
t.Errorf("Close() on closed service error = %v", err)
}
}
// TestMetadataPreserved tests that metadata is preserved through reranking.
func TestMetadataPreserved(t *testing.T) {
cleanup := initONNX(t)
defer cleanup()
cfg := DefaultConfig()
svc, err := NewService(cfg)
if err != nil {
t.Fatalf("NewService() error = %v", err)
}
defer svc.Close()
candidates := []Candidate{
{
ID: "1",
Content: "Test content.",
Score: 0.5,
Metadata: map[string]any{"custom": "value1", "num": 42},
},
{
ID: "2",
Content: "Another test.",
Score: 0.5,
Metadata: map[string]any{"custom": "value2"},
},
}
results, err := svc.Rerank("query", candidates, 2)
if err != nil {
t.Fatalf("Rerank() error = %v", err)
}
for _, r := range results {
if r.Metadata == nil {
t.Errorf("Result %s has nil metadata", r.ID)
continue
}
// Find original candidate
var original *Candidate
for i := range candidates {
if candidates[i].ID == r.ID {
original = &candidates[i]
break
}
}
if original == nil {
t.Errorf("Could not find original for result %s", r.ID)
continue
}
// Check metadata preserved
if original.Metadata["custom"] != r.Metadata["custom"] {
t.Errorf("Metadata not preserved for %s", r.ID)
}
}
}
+168
@@ -0,0 +1,168 @@
// Package scoring provides importance score calculation for observations.
package scoring
import (
"math"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// Calculator computes importance scores for observations.
type Calculator struct {
config *models.ScoringConfig
}
// NewCalculator creates a new scoring calculator.
// If config is nil, uses the default configuration.
func NewCalculator(config *models.ScoringConfig) *Calculator {
if config == nil {
config = models.DefaultScoringConfig()
}
return &Calculator{config: config}
}
// Calculate computes the importance score for an observation at the given time.
//
// The scoring formula:
//
// FinalScore = (BaseScore × TypeWeight × RecencyDecay) + FeedbackContrib + ConceptContrib + RetrievalContrib
//
// Where:
// - BaseScore = 1.0
// - TypeWeight = observation type multiplier (e.g., bugfix=1.3, change=0.9)
// - RecencyDecay = 0.5^(age_days / half_life_days) - halves every 7 days by default
// - FeedbackContrib = user_feedback × feedback_weight
// - ConceptContrib = sum(concept_weights) × concept_weight_factor
// - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight
func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
// 1. Get base type weight
typeWeight := models.TypeBaseScore(obs.Type)
// 2. Calculate recency decay: 0.5^(age_days / half_life_days)
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
if ageDays < 0 {
ageDays = 0 // Handle future timestamps gracefully
}
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
// Core score = 1.0 × type_weight × recency_decay
coreScore := 1.0 * typeWeight * recencyDecay
// 3. User feedback contribution: feedback × weight
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
// 4. Concept boost contribution: sum of matching concept weights × factor
conceptBoost := 0.0
for _, concept := range obs.Concepts {
if weight, ok := c.config.ConceptWeights[concept]; ok {
conceptBoost += weight
}
}
conceptContrib := conceptBoost * c.config.ConceptWeight
// 5. Retrieval boost: log2(count + 1) × 0.1 × weight (diminishing returns)
retrievalContrib := 0.0
if obs.RetrievalCount > 0 {
// log2(count + 1) gives diminishing returns: 1→1, 3→2, 7→3, 15→4, etc.
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
}
// Final score with minimum threshold
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
if finalScore < c.config.MinScore {
finalScore = c.config.MinScore
}
return finalScore
}
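To make the formula concrete, here is a minimal standalone sketch that reproduces the arithmetic. The `params` type and `score` function are hypothetical and not part of this package. The weights in `main` (feedback 0.30, concept factor 0.20, retrieval 0.15, half-life 7 days, minimum 0.01) are assumptions inferred from the expectations in the accompanying tests.

```go
package main

import (
	"fmt"
	"math"
)

// params mirrors the scoring knobs used in the formula. The values set in
// main are assumptions, not values read from the real configuration.
type params struct {
	halfLifeDays    float64
	feedbackWeight  float64
	conceptFactor   float64
	retrievalWeight float64
	minScore        float64
}

// score applies (1.0 × typeWeight × decay) + feedback + concepts + retrieval
// and clamps the result at minScore.
func score(p params, typeWeight, ageDays float64, feedback int, conceptSum float64, retrievals int) float64 {
	if ageDays < 0 {
		ageDays = 0 // treat future timestamps as "now"
	}
	decay := math.Pow(0.5, ageDays/p.halfLifeDays)
	total := typeWeight*decay + float64(feedback)*p.feedbackWeight + conceptSum*p.conceptFactor
	if retrievals > 0 {
		total += math.Log2(float64(retrievals)+1) * 0.1 * p.retrievalWeight
	}
	if total < p.minScore {
		total = p.minScore
	}
	return total
}

func main() {
	p := params{halfLifeDays: 7, feedbackWeight: 0.30, conceptFactor: 0.20, retrievalWeight: 0.15, minScore: 0.01}
	// 7-day-old bugfix (weight 1.3), thumbs up, one concept weighted 0.30,
	// 3 retrievals: 0.65 + 0.30 + 0.06 + 0.03.
	fmt.Printf("%.2f\n", score(p, 1.3, 7, 1, 0.30, 3)) // prints 1.04
}
```

Only the core term decays with age. The feedback, concept and retrieval contributions are added afterwards, so under these assumed weights a thumbs-up on an old observation keeps it well above the minimum score.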
// CalculateComponents returns the individual components of the importance score.
// Useful for debugging and explaining scores to users.
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
typeWeight := models.TypeBaseScore(obs.Type)
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
if ageDays < 0 {
ageDays = 0
}
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
coreScore := 1.0 * typeWeight * recencyDecay
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
conceptBoost := 0.0
for _, concept := range obs.Concepts {
if weight, ok := c.config.ConceptWeights[concept]; ok {
conceptBoost += weight
}
}
conceptContrib := conceptBoost * c.config.ConceptWeight
retrievalContrib := 0.0
if obs.RetrievalCount > 0 {
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
}
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
if finalScore < c.config.MinScore {
finalScore = c.config.MinScore
}
return ScoreComponents{
TypeWeight: typeWeight,
RecencyDecay: recencyDecay,
CoreScore: coreScore,
FeedbackContrib: feedbackContrib,
ConceptContrib: conceptContrib,
RetrievalContrib: retrievalContrib,
FinalScore: finalScore,
AgeDays: ageDays,
}
}
// ScoreComponents contains the breakdown of an importance score calculation.
type ScoreComponents struct {
TypeWeight float64 `json:"type_weight"`
RecencyDecay float64 `json:"recency_decay"`
CoreScore float64 `json:"core_score"`
FeedbackContrib float64 `json:"feedback_contrib"`
ConceptContrib float64 `json:"concept_contrib"`
RetrievalContrib float64 `json:"retrieval_contrib"`
FinalScore float64 `json:"final_score"`
AgeDays float64 `json:"age_days"`
}
// BatchCalculate computes scores for multiple observations.
// Returns a map of observation ID to calculated score.
func (c *Calculator) BatchCalculate(observations []*models.Observation, now time.Time) map[int64]float64 {
scores := make(map[int64]float64, len(observations))
for _, obs := range observations {
scores[obs.ID] = c.Calculate(obs, now)
}
return scores
}
// RecalculateThreshold returns how long a stored score may go without an
// update before the observation becomes eligible for recalculation. This
// prevents excessive recalculation while keeping scores reasonably fresh.
func (c *Calculator) RecalculateThreshold() time.Duration {
// Recalculate at most every 6 hours
// This balances freshness with performance
return 6 * time.Hour
}
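// UpdateConfig replaces the calculator's scoring configuration, allowing
// runtime tuning of scoring parameters. A nil config is ignored.
// The swap is not synchronized with Calculate, so callers must not call
// UpdateConfig concurrently with scoring.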
func (c *Calculator) UpdateConfig(config *models.ScoringConfig) {
if config != nil {
c.config = config
}
}
// GetConfig returns the current scoring configuration.
func (c *Calculator) GetConfig() *models.ScoringConfig {
return c.config
}
+638
@@ -0,0 +1,638 @@
// Package scoring provides importance score calculation for observations.
package scoring
import (
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// CalculatorSuite is a test suite for the Calculator.
type CalculatorSuite struct {
suite.Suite
calc *Calculator
config *models.ScoringConfig
now time.Time
}
func (s *CalculatorSuite) SetupTest() {
s.config = models.DefaultScoringConfig()
s.calc = NewCalculator(s.config)
s.now = time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
}
func TestCalculatorSuite(t *testing.T) {
suite.Run(t, new(CalculatorSuite))
}
// =============================================================================
// GOOD SCENARIOS - Expected normal operations
// =============================================================================
func (s *CalculatorSuite) TestCalculate_GoodScenarios_NewObservation() {
// A brand new observation should have a score close to its type weight
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
// Expected: 1.0 × 1.3 (bugfix weight) × 1.0 (no decay) = 1.3
s.InDelta(1.3, score, 0.01, "new bugfix should score ~1.3")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_OneWeekOld() {
// One week old observation should have half the recency score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
// Expected: 1.0 × 1.1 (discovery) × 0.5 (7 days half-life) = 0.55
s.InDelta(0.55, score, 0.05, "7-day old discovery should score ~0.55")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_TwoWeeksOld() {
// Two weeks old should have 1/4 recency score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeFeature,
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
// Expected: 1.0 × 1.2 (feature) × 0.25 (14 days = 2 half-lives) = 0.30
s.InDelta(0.30, score, 0.05, "14-day old feature should score ~0.30")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_PositiveFeedback() {
// Positive feedback should boost score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.UnixMilli(),
UserFeedback: 1, // thumbs up
}
score := s.calc.Calculate(obs, s.now)
// Expected: (1.0 × 0.9) + 0.30 (feedback) = 1.20
s.InDelta(1.20, score, 0.01, "thumbs up should boost score by 0.30")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_NegativeFeedback() {
// Negative feedback should reduce score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.UnixMilli(),
UserFeedback: -1, // thumbs down
}
score := s.calc.Calculate(obs, s.now)
// Expected: (1.0 × 0.9) - 0.30 (feedback) = 0.60
s.InDelta(0.60, score, 0.01, "thumbs down should reduce score by 0.30")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithConcepts() {
// Observation with valuable concepts should get boost
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.UnixMilli(),
Concepts: []string{"security", "gotcha"},
}
score := s.calc.Calculate(obs, s.now)
// Concept boost: (0.30 + 0.25) × 0.20 = 0.11
// Expected: 1.3 + 0.11 = 1.41
s.InDelta(1.41, score, 0.05, "security+gotcha concepts should boost score")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithRetrievals() {
// Popular observations should get retrieval boost
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.UnixMilli(),
RetrievalCount: 7, // log2(8) = 3
}
score := s.calc.Calculate(obs, s.now)
// Retrieval boost: log2(7+1) × 0.1 × 0.15 = 3 × 0.1 × 0.15 = 0.045
// Expected: 1.1 + 0.045 ≈ 1.145
s.InDelta(1.145, score, 0.05, "7 retrievals should add small boost")
}
func (s *CalculatorSuite) TestCalculate_GoodScenarios_CombinedFactors() {
// Test with all factors combined
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(), // 7 days old
UserFeedback: 1,
Concepts: []string{"security"},
RetrievalCount: 3,
}
score := s.calc.Calculate(obs, s.now)
// Core: 1.0 × 1.3 × 0.5 = 0.65
// Feedback: 0.30
// Concept: 0.30 × 0.20 = 0.06
// Retrieval: log2(4) × 0.1 × 0.15 = 2 × 0.1 × 0.15 = 0.03
// Total ≈ 1.04
s.InDelta(1.04, score, 0.1, "combined factors should result in ~1.04")
}
// =============================================================================
// WORSE SCENARIOS - Degraded but acceptable operations
// =============================================================================
func (s *CalculatorSuite) TestCalculate_WorseScenarios_VeryOldObservation() {
// Very old observation should have low but non-zero score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.Add(-90 * 24 * time.Hour).UnixMilli(), // 90 days old
}
score := s.calc.Calculate(obs, s.now)
// 90 days = ~12.86 half-lives → decay ≈ 0.00014
// Core: 1.0 × 0.9 × 0.00014 = 0.000126
// But minimum score is 0.01
s.GreaterOrEqual(score, 0.01, "very old observation should still meet minimum")
s.Less(score, 0.1, "very old observation should be low scoring")
}
func (s *CalculatorSuite) TestCalculate_WorseScenarios_NegativeFeedbackOld() {
// Old observation with negative feedback should still have minimum score
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.Add(-60 * 24 * time.Hour).UnixMilli(),
UserFeedback: -1,
}
score := s.calc.Calculate(obs, s.now)
s.GreaterOrEqual(score, s.config.MinScore, "should never go below minimum score")
}
func (s *CalculatorSuite) TestCalculate_WorseScenarios_UnknownConcepts() {
// Unknown concepts should not affect score negatively
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.UnixMilli(),
Concepts: []string{"unknown-concept", "another-unknown"},
}
score := s.calc.Calculate(obs, s.now)
// Should just be the base score without concept boost
s.InDelta(1.1, score, 0.01, "unknown concepts should not affect score")
}
func (s *CalculatorSuite) TestCalculate_WorseScenarios_MixedConcepts() {
// Mix of known and unknown concepts
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.UnixMilli(),
Concepts: []string{"security", "unknown-concept"},
}
score := s.calc.Calculate(obs, s.now)
// Only security should contribute
// Expected: 1.1 + (0.30 × 0.20) = 1.16
s.InDelta(1.16, score, 0.05, "only known concepts should boost score")
}
// =============================================================================
// BAD SCENARIOS - Edge cases and error conditions
// =============================================================================
func (s *CalculatorSuite) TestCalculate_BadScenarios_FutureTimestamp() {
// Observation created in the future (clock skew)
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.Add(24 * time.Hour).UnixMilli(), // 1 day in future
}
score := s.calc.Calculate(obs, s.now)
// Should handle gracefully - age should be 0
s.InDelta(1.3, score, 0.01, "future timestamp should be treated as now")
}
func (s *CalculatorSuite) TestCalculate_BadScenarios_ZeroEpoch() {
// Missing creation timestamp
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: 0, // Missing timestamp
}
score := s.calc.Calculate(obs, s.now)
// This will be treated as very old (1970)
s.GreaterOrEqual(score, s.config.MinScore, "should still meet minimum")
}
func (s *CalculatorSuite) TestCalculate_BadScenarios_EmptyObservation() {
// Minimal observation with defaults
obs := &models.Observation{
ID: 1,
Type: "", // Empty type
CreatedAtEpoch: s.now.UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
// Unknown type should default to 1.0 weight
s.InDelta(1.0, score, 0.01, "empty type should use default weight 1.0")
}
func (s *CalculatorSuite) TestCalculate_BadScenarios_ExtremeRetrievalCount() {
// Very high retrieval count
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.UnixMilli(),
RetrievalCount: 1000000, // Extreme value
}
score := s.calc.Calculate(obs, s.now)
// log2(1000001) ≈ 19.93, so boost = 19.93 × 0.1 × 0.15 ≈ 0.30
// Score should be reasonable, not exploding
s.Less(score, 2.0, "extreme retrieval count should not explode score")
}
func (s *CalculatorSuite) TestCalculate_BadScenarios_NegativeRetrievalCount() {
// Negative retrieval count (should not happen but test defensively)
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: s.now.UnixMilli(),
RetrievalCount: -5,
}
score := s.calc.Calculate(obs, s.now)
// Should not panic and should give base score
s.InDelta(1.1, score, 0.01, "negative retrieval should be ignored")
}
// =============================================================================
// EDGE CASES - Boundary conditions
// =============================================================================
func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyOneHalfLife() {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange, // 0.9 weight
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
s.InDelta(0.45, score, 0.01, "exactly 7 days should give 0.5 decay")
}
func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyTwoHalfLives() {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
s.InDelta(0.225, score, 0.01, "exactly 14 days should give 0.25 decay")
}
func (s *CalculatorSuite) TestCalculate_EdgeCases_AllTypeWeights() {
types := []struct {
t models.ObservationType
weight float64
}{
{models.ObsTypeBugfix, 1.3},
{models.ObsTypeFeature, 1.2},
{models.ObsTypeDiscovery, 1.1},
{models.ObsTypeDecision, 1.1},
{models.ObsTypeRefactor, 1.0},
{models.ObsTypeChange, 0.9},
}
for _, tt := range types {
s.Run(string(tt.t), func() {
obs := &models.Observation{
ID: 1,
Type: tt.t,
CreatedAtEpoch: s.now.UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
s.InDelta(tt.weight, score, 0.01)
})
}
}
func (s *CalculatorSuite) TestCalculate_EdgeCases_MinimumScoreEnforced() {
// Create worst case scenario
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange, // Lowest weight 0.9
CreatedAtEpoch: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli(), // Very old
UserFeedback: -1, // Negative feedback
}
score := s.calc.Calculate(obs, s.now)
s.Equal(s.config.MinScore, score, "should be exactly minimum score")
}
func (s *CalculatorSuite) TestCalculate_EdgeCases_AllConceptsMaxWeight() {
// Observation with all high-value concepts
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.UnixMilli(),
Concepts: []string{"security", "gotcha", "best-practice", "anti-pattern"},
}
score := s.calc.Calculate(obs, s.now)
// security=0.30, gotcha=0.25, best-practice=0.20, anti-pattern=0.20 = 0.95
// Concept contrib: 0.95 × 0.20 = 0.19
// Total: 1.3 + 0.19 = 1.49
s.InDelta(1.49, score, 0.05, "all high-value concepts should boost significantly")
}
// =============================================================================
// CALCULATE COMPONENTS TESTS
// =============================================================================
func (s *CalculatorSuite) TestCalculateComponents_ReturnsAllComponents() {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
UserFeedback: 1,
Concepts: []string{"security"},
RetrievalCount: 7,
}
components := s.calc.CalculateComponents(obs, s.now)
s.InDelta(1.3, components.TypeWeight, 0.01)
s.InDelta(0.5, components.RecencyDecay, 0.01)
s.InDelta(0.65, components.CoreScore, 0.05)
s.InDelta(0.30, components.FeedbackContrib, 0.01)
s.InDelta(0.06, components.ConceptContrib, 0.02)
s.Greater(components.RetrievalContrib, 0.0)
s.InDelta(7.0, components.AgeDays, 0.1)
s.Greater(components.FinalScore, 0.0)
}
func (s *CalculatorSuite) TestCalculateComponents_MatchesCalculate() {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeFeature,
CreatedAtEpoch: s.now.Add(-3 * 24 * time.Hour).UnixMilli(),
UserFeedback: -1,
Concepts: []string{"performance", "architecture"},
RetrievalCount: 15,
}
score := s.calc.Calculate(obs, s.now)
components := s.calc.CalculateComponents(obs, s.now)
s.InDelta(score, components.FinalScore, 0.001, "Calculate and CalculateComponents should match")
}
// =============================================================================
// BATCH CALCULATE TESTS
// =============================================================================
func (s *CalculatorSuite) TestBatchCalculate_Empty() {
scores := s.calc.BatchCalculate(nil, s.now)
s.Empty(scores)
scores = s.calc.BatchCalculate([]*models.Observation{}, s.now)
s.Empty(scores)
}
func (s *CalculatorSuite) TestBatchCalculate_Multiple() {
obs := []*models.Observation{
{ID: 1, Type: models.ObsTypeBugfix, CreatedAtEpoch: s.now.UnixMilli()},
{ID: 2, Type: models.ObsTypeFeature, CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli()},
{ID: 3, Type: models.ObsTypeChange, CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli()},
}
scores := s.calc.BatchCalculate(obs, s.now)
s.Len(scores, 3)
s.Contains(scores, int64(1))
s.Contains(scores, int64(2))
s.Contains(scores, int64(3))
s.InDelta(1.3, scores[1], 0.01) // New bugfix
s.InDelta(0.6, scores[2], 0.1) // 7-day feature
s.InDelta(0.225, scores[3], 0.05) // 14-day change
}
// =============================================================================
// CONFIGURATION TESTS
// =============================================================================
func (s *CalculatorSuite) TestNewCalculator_NilConfig() {
calc := NewCalculator(nil)
s.NotNil(calc)
s.NotNil(calc.config)
s.Equal(7.0, calc.config.RecencyHalfLifeDays)
}
func (s *CalculatorSuite) TestUpdateConfig() {
newConfig := &models.ScoringConfig{
RecencyHalfLifeDays: 14.0, // Changed from 7
FeedbackWeight: 0.50,
ConceptWeight: 0.10,
RetrievalWeight: 0.05,
MinScore: 0.001,
ConceptWeights: map[string]float64{"test": 0.5},
}
s.calc.UpdateConfig(newConfig)
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeChange,
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
}
score := s.calc.Calculate(obs, s.now)
// With 14-day half-life, 14 days = exactly one half-life
// Expected: 1.0 × 0.9 × 0.5 = 0.45
s.InDelta(0.45, score, 0.01)
}
func (s *CalculatorSuite) TestUpdateConfig_NilIgnored() {
originalConfig := s.calc.GetConfig()
s.calc.UpdateConfig(nil)
s.Equal(originalConfig, s.calc.GetConfig())
}
func (s *CalculatorSuite) TestGetConfig() {
config := s.calc.GetConfig()
s.NotNil(config)
s.Equal(7.0, config.RecencyHalfLifeDays)
}
func (s *CalculatorSuite) TestRecalculateThreshold() {
threshold := s.calc.RecalculateThreshold()
s.Equal(6*time.Hour, threshold)
}
// =============================================================================
// STANDALONE TESTS (non-suite)
// =============================================================================
func TestNewCalculator_DefaultConfig(t *testing.T) {
calc := NewCalculator(nil)
require.NotNil(t, calc)
assert.Equal(t, 7.0, calc.config.RecencyHalfLifeDays)
assert.Equal(t, 0.30, calc.config.FeedbackWeight)
assert.Equal(t, 0.01, calc.config.MinScore)
}
func TestCalculator_ConcurrentAccess(t *testing.T) {
calc := NewCalculator(nil)
now := time.Now()
// Test that calculator is safe for concurrent reads
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(id int64) {
obs := &models.Observation{
ID: id,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: now.UnixMilli(),
}
score := calc.Calculate(obs, now)
assert.Greater(t, score, 0.0)
done <- true
}(int64(i))
}
for i := 0; i < 10; i++ {
<-done
}
}
func TestCalculator_DecayPrecision(t *testing.T) {
calc := NewCalculator(nil)
now := time.Now()
// Test that decay is mathematically correct
testCases := []struct {
days int
expectedDecay float64
}{
{0, 1.0},
{7, 0.5},
{14, 0.25},
{21, 0.125},
{28, 0.0625},
}
for _, tc := range testCases {
t.Run(string(rune('0'+tc.days/7))+"_half_lives", func(t *testing.T) {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeRefactor, // 1.0 weight
CreatedAtEpoch: now.Add(-time.Duration(tc.days) * 24 * time.Hour).UnixMilli(),
}
components := calc.CalculateComponents(obs, now)
assert.InDelta(t, tc.expectedDecay, components.RecencyDecay, 0.001)
})
}
}
func TestTypeBaseScore_UnknownType(t *testing.T) {
score := models.TypeBaseScore("unknown-type")
assert.Equal(t, 1.0, score, "unknown type should default to 1.0")
}
func TestTypeBaseScore_AllKnownTypes(t *testing.T) {
expected := map[models.ObservationType]float64{
models.ObsTypeBugfix: 1.3,
models.ObsTypeFeature: 1.2,
models.ObsTypeDiscovery: 1.1,
models.ObsTypeDecision: 1.1,
models.ObsTypeRefactor: 1.0,
models.ObsTypeChange: 0.9,
}
for obsType, expectedScore := range expected {
t.Run(string(obsType), func(t *testing.T) {
score := models.TypeBaseScore(obsType)
assert.Equal(t, expectedScore, score)
})
}
}
func TestCalculator_RetrievalBoostDiminishingReturns(t *testing.T) {
calc := NewCalculator(nil)
now := time.Now()
// Test that retrieval boost has diminishing returns
// When retrieval count doubles, the boost should NOT double (log2 gives diminishing returns)
// Collect boosts for different counts
boosts := make([]float64, 0)
retrievalCounts := []int{1, 3, 7, 15, 31, 63, 127}
for _, count := range retrievalCounts {
obs := &models.Observation{
ID: 1,
Type: models.ObsTypeRefactor,
CreatedAtEpoch: now.UnixMilli(),
RetrievalCount: count,
}
components := calc.CalculateComponents(obs, now)
boosts = append(boosts, components.RetrievalContrib)
}
// Verify boost increases but at a decreasing rate
for i := 1; i < len(boosts); i++ {
// Each boost should be higher than the previous
assert.Greater(t, boosts[i], boosts[i-1],
"boost should increase with more retrievals")
// But not proportionally - calculate the ratios
if i >= 2 {
ratio1 := boosts[i-1] / boosts[i-2]
ratio2 := boosts[i] / boosts[i-1]
// The growth ratio should be decreasing (diminishing returns)
assert.Less(t, ratio2, ratio1+0.01, // Allow small floating point tolerance
"growth rate should decrease (diminishing returns)")
}
}
}
+186
@@ -0,0 +1,186 @@
// Package scoring provides importance score calculation for observations.
package scoring
import (
"context"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// ObservationStore defines the interface for observation storage operations needed by the recalculator.
type ObservationStore interface {
GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error)
UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error
GetConceptWeights(ctx context.Context) (map[string]float64, error)
}
// Recalculator periodically recalculates importance scores for observations.
type Recalculator struct {
store ObservationStore
calculator *Calculator
log zerolog.Logger
interval time.Duration
batchSize int
stopCh chan struct{}
doneCh chan struct{}
mu sync.Mutex
running bool
}
// NewRecalculator creates a new background recalculator.
func NewRecalculator(store ObservationStore, calc *Calculator, log zerolog.Logger) *Recalculator {
return &Recalculator{
store: store,
calculator: calc,
log: log.With().Str("component", "recalculator").Logger(),
interval: 1 * time.Hour, // Run every hour
batchSize: 500, // Process 500 observations at a time
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
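// Start runs the background recalculation loop. It blocks until ctx is
// cancelled or Stop is called, so run it in its own goroutine.
// A Recalculator must not be started again after it has stopped, because
// its stop and done channels are closed on shutdown.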
func (r *Recalculator) Start(ctx context.Context) {
r.mu.Lock()
if r.running {
r.mu.Unlock()
return
}
r.running = true
r.mu.Unlock()
defer func() {
r.mu.Lock()
r.running = false
r.mu.Unlock()
close(r.doneCh)
}()
// Initial run
r.recalculate(ctx)
r.mu.Lock()
interval := r.interval
r.mu.Unlock()
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
r.log.Info().Msg("recalculator shutting down due to context cancellation")
return
case <-r.stopCh:
r.log.Info().Msg("recalculator stopping")
return
case <-ticker.C:
r.recalculate(ctx)
}
}
}
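// Stop stops the background recalculation loop and waits for it to exit.
// It is safe to call Stop more than once.
func (r *Recalculator) Stop() {
	r.mu.Lock()
	if !r.running {
		r.mu.Unlock()
		return
	}
	// Close stopCh at most once; closing an already-closed channel panics.
	select {
	case <-r.stopCh:
	default:
		close(r.stopCh)
	}
	r.mu.Unlock()
	<-r.doneCh
}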
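Another common way to guarantee that a stop channel is closed at most once is `sync.Once`. The sketch below shows that pattern on a hypothetical `worker` type. It is an illustration under simplified assumptions (no context, no logging) and is not the Recalculator's implementation.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// worker is a hypothetical background loop with an idempotent stop.
type worker struct {
	stopCh   chan struct{}
	doneCh   chan struct{}
	stopOnce sync.Once
}

func newWorker() *worker {
	return &worker{stopCh: make(chan struct{}), doneCh: make(chan struct{})}
}

// run calls tick on every interval until stop is called, then signals
// completion by closing doneCh.
func (w *worker) run(interval time.Duration, tick func()) {
	defer close(w.doneCh)
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-w.stopCh:
			return
		case <-ticker.C:
			tick()
		}
	}
}

// stop may be called any number of times from any goroutine. It blocks
// until run has returned, so run must have been started.
func (w *worker) stop() {
	w.stopOnce.Do(func() { close(w.stopCh) })
	<-w.doneCh
}

func main() {
	w := newWorker()
	go w.run(10*time.Millisecond, func() {})
	w.stop()
	w.stop() // second call is a no-op and does not panic
	fmt.Println("stopped")
}
```

The trade-off is that `stop` in this sketch blocks forever if `run` was never started, so a real implementation still needs a running flag or similar.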
// recalculate performs a single recalculation batch.
func (r *Recalculator) recalculate(ctx context.Context) {
now := time.Now()
threshold := r.calculator.RecalculateThreshold()
r.mu.Lock()
batchSize := r.batchSize
r.mu.Unlock()
observations, err := r.store.GetObservationsNeedingScoreUpdate(ctx, threshold, batchSize)
if err != nil {
r.log.Error().Err(err).Msg("failed to get observations for score update")
return
}
if len(observations) == 0 {
return
}
scores := r.calculator.BatchCalculate(observations, now)
if err := r.store.UpdateImportanceScores(ctx, scores); err != nil {
r.log.Error().Err(err).Msg("failed to update importance scores")
return
}
r.log.Info().
Int("count", len(scores)).
Dur("elapsed", time.Since(now)).
Msg("recalculated importance scores")
}
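// RecalculateNow triggers an immediate recalculation of a single batch.
// This is useful for testing or when scores need to be updated urgently.
// Failures are logged rather than returned, so the error is currently always nil.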
func (r *Recalculator) RecalculateNow(ctx context.Context) error {
r.recalculate(ctx)
return nil
}
// RefreshConceptWeights reloads concept weights from the database.
// Call this after updating concept weights to apply changes.
func (r *Recalculator) RefreshConceptWeights(ctx context.Context) error {
	weights, err := r.store.GetConceptWeights(ctx)
	if err != nil {
		return err
	}
	// Install a modified copy rather than mutating the config in place, so
	// code still holding the previous *ScoringConfig sees a consistent value.
	updated := *r.calculator.GetConfig()
	updated.ConceptWeights = weights
	r.calculator.UpdateConfig(&updated)
	r.log.Info().Int("count", len(weights)).Msg("refreshed concept weights")
	return nil
}
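If the configuration needs to be swapped while scores are being calculated on other goroutines, one option is to publish each new version through an atomic pointer. The following is a minimal copy-on-write sketch using `sync/atomic.Pointer` (Go 1.19 or later). The `config` and `holder` types are hypothetical and are not the package's `ScoringConfig` or `Calculator`.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// config is a hypothetical, trimmed-down scoring configuration.
type config struct {
	HalfLifeDays   float64
	ConceptWeights map[string]float64
}

// holder publishes immutable config snapshots through an atomic pointer.
type holder struct {
	cfg atomic.Pointer[config]
}

func newHolder(c *config) *holder {
	h := &holder{}
	h.cfg.Store(c)
	return h
}

// get returns the current snapshot; callers must treat it as read-only.
func (h *holder) get() *config { return h.cfg.Load() }

// withConceptWeights publishes a modified copy. Readers holding the old
// pointer keep a consistent snapshot.
func (h *holder) withConceptWeights(w map[string]float64) {
	next := *h.cfg.Load() // shallow copy of the struct
	next.ConceptWeights = w
	h.cfg.Store(&next)
}

func main() {
	h := newHolder(&config{
		HalfLifeDays:   7,
		ConceptWeights: map[string]float64{"security": 0.30},
	})
	old := h.get()
	h.withConceptWeights(map[string]float64{"security": 0.5})
	fmt.Println(old.ConceptWeights["security"], h.get().ConceptWeights["security"]) // prints 0.3 0.5
}
```

The load-copy-store sequence is safe with one writer at a time. With several concurrent writers it can lose updates, in which case a mutex around the write path or a `CompareAndSwap` retry loop would be needed.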
// Stats holds statistics about the recalculator.
type Stats struct {
Running bool `json:"running"`
Interval time.Duration `json:"interval"`
BatchSize int `json:"batch_size"`
HalfLife float64 `json:"half_life_days"`
MinScore float64 `json:"min_score"`
ConceptsLen int `json:"concepts_count"`
}
// GetStats returns current recalculator statistics.
func (r *Recalculator) GetStats() Stats {
r.mu.Lock()
defer r.mu.Unlock()
config := r.calculator.GetConfig()
return Stats{
Running: r.running,
Interval: r.interval,
BatchSize: r.batchSize,
HalfLife: config.RecencyHalfLifeDays,
MinScore: config.MinScore,
ConceptsLen: len(config.ConceptWeights),
}
}
// Compile-time check that *sqlite.ObservationStore implements ObservationStore.
var _ ObservationStore = (*sqlite.ObservationStore)(nil)
+447
@@ -0,0 +1,447 @@
// Package scoring provides importance score calculation for observations.
package scoring
import (
"context"
"sync"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// MockObservationStore is a mock implementation of ObservationStore for testing.
type MockObservationStore struct {
mu sync.Mutex
observations []*models.Observation
scores map[int64]float64
conceptWeights map[string]float64
updateErr error
getErr error
getConceptsErr error
updateScoresCalls int
}
func NewMockObservationStore() *MockObservationStore {
return &MockObservationStore{
observations: []*models.Observation{},
scores: make(map[int64]float64),
conceptWeights: make(map[string]float64),
}
}
func (m *MockObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.getErr != nil {
return nil, m.getErr
}
// Return observations that haven't been updated within threshold
now := time.Now()
var result []*models.Observation
for _, obs := range m.observations {
if !obs.ScoreUpdatedAt.Valid || now.Sub(time.Unix(obs.ScoreUpdatedAt.Int64, 0)) > threshold {
result = append(result, obs)
if len(result) >= limit {
break
}
}
}
return result, nil
}
func (m *MockObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
m.mu.Lock()
defer m.mu.Unlock()
m.updateScoresCalls++
if m.updateErr != nil {
return m.updateErr
}
for id, score := range scores {
m.scores[id] = score
}
return nil
}
func (m *MockObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.getConceptsErr != nil {
return nil, m.getConceptsErr
}
return m.conceptWeights, nil
}
func (m *MockObservationStore) AddObservation(obs *models.Observation) {
m.mu.Lock()
defer m.mu.Unlock()
m.observations = append(m.observations, obs)
}
func (m *MockObservationStore) SetConceptWeights(weights map[string]float64) {
m.mu.Lock()
defer m.mu.Unlock()
m.conceptWeights = weights
}
func (m *MockObservationStore) GetScore(id int64) (float64, bool) {
m.mu.Lock()
defer m.mu.Unlock()
score, ok := m.scores[id]
return score, ok
}
func (m *MockObservationStore) GetUpdateScoresCalls() int {
m.mu.Lock()
defer m.mu.Unlock()
return m.updateScoresCalls
}
func (m *MockObservationStore) SetUpdateError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.updateErr = err
}
func (m *MockObservationStore) SetGetError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getErr = err
}
// =============================================================================
// RECALCULATOR TESTS
// =============================================================================
func TestNewRecalculator(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
require.NotNil(t, recalc)
assert.NotNil(t, recalc.store)
assert.NotNil(t, recalc.calculator)
assert.Equal(t, 1*time.Hour, recalc.interval)
assert.Equal(t, 500, recalc.batchSize)
assert.False(t, recalc.running)
}
func TestRecalculator_RecalculateNow(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Add observations
now := time.Now()
store.AddObservation(&models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: now.UnixMilli(),
})
store.AddObservation(&models.Observation{
ID: 2,
Type: models.ObsTypeFeature,
CreatedAtEpoch: now.Add(-7 * 24 * time.Hour).UnixMilli(),
})
ctx := context.Background()
err := recalc.RecalculateNow(ctx)
require.NoError(t, err)
// Verify scores were calculated
score1, ok := store.GetScore(1)
assert.True(t, ok)
assert.Greater(t, score1, 0.0)
score2, ok := store.GetScore(2)
assert.True(t, ok)
assert.Greater(t, score2, 0.0)
assert.Less(t, score2, score1, "older observation should have lower score")
}
func TestRecalculator_RefreshConceptWeights(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Set up custom weights in store
customWeights := map[string]float64{
"security": 0.50,
"performance": 0.25,
}
store.SetConceptWeights(customWeights)
ctx := context.Background()
err := recalc.RefreshConceptWeights(ctx)
require.NoError(t, err)
// Verify weights were updated in calculator config
config := calc.GetConfig()
assert.Equal(t, 0.50, config.ConceptWeights["security"])
assert.Equal(t, 0.25, config.ConceptWeights["performance"])
}
func TestRecalculator_RefreshConceptWeights_Error(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
store.getConceptsErr = assert.AnError
ctx := context.Background()
err := recalc.RefreshConceptWeights(ctx)
require.Error(t, err)
}
func TestRecalculator_GetStats(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Set fields directly for testing
recalc.interval = 2 * time.Hour
recalc.batchSize = 250
stats := recalc.GetStats()
assert.False(t, stats.Running)
assert.Equal(t, 2*time.Hour, stats.Interval)
assert.Equal(t, 250, stats.BatchSize)
assert.Equal(t, 7.0, stats.HalfLife)
assert.Equal(t, 0.01, stats.MinScore)
}
func TestRecalculator_StartStop(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Use a short interval for testing
recalc.interval = 50 * time.Millisecond
// Add an observation
store.AddObservation(&models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: time.Now().UnixMilli(),
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start in goroutine
go recalc.Start(ctx)
// Wait for initial run and at least one tick
time.Sleep(100 * time.Millisecond)
// Verify it ran
calls := store.GetUpdateScoresCalls()
assert.GreaterOrEqual(t, calls, 1, "should have run at least once")
// Stop via context cancellation
cancel()
time.Sleep(100 * time.Millisecond)
// Verify stopped
stats := recalc.GetStats()
assert.False(t, stats.Running)
}
func TestRecalculator_StartTwice(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
recalc.interval = 1 * time.Hour // Long interval so it doesn't tick during test
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start first time
go recalc.Start(ctx)
time.Sleep(50 * time.Millisecond)
// Try to start second time (should return immediately)
go recalc.Start(ctx)
time.Sleep(50 * time.Millisecond)
// Should still be running (only once)
stats := recalc.GetStats()
assert.True(t, stats.Running)
cancel()
}
func TestRecalculator_StopWhenNotRunning(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Stop without starting - should not panic
recalc.Stop()
stats := recalc.GetStats()
assert.False(t, stats.Running)
}
func TestRecalculator_EmptyStore(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
ctx := context.Background()
err := recalc.RecalculateNow(ctx)
require.NoError(t, err)
assert.Equal(t, 0, store.GetUpdateScoresCalls(), "should not call update with no observations")
}
func TestRecalculator_GetObservationsError(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
store.SetGetError(assert.AnError)
ctx := context.Background()
err := recalc.RecalculateNow(ctx)
// Should not return error (logs it instead)
require.NoError(t, err)
assert.Equal(t, 0, store.GetUpdateScoresCalls())
}
func TestRecalculator_UpdateScoresError(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
store.AddObservation(&models.Observation{
ID: 1,
Type: models.ObsTypeBugfix,
CreatedAtEpoch: time.Now().UnixMilli(),
})
store.SetUpdateError(assert.AnError)
ctx := context.Background()
err := recalc.RecalculateNow(ctx)
// Should not return error (logs it instead)
require.NoError(t, err)
assert.Equal(t, 1, store.GetUpdateScoresCalls())
}
func TestRecalculator_BatchProcessing(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Set small batch size
recalc.batchSize = 3
// Add 5 observations
now := time.Now()
for i := 1; i <= 5; i++ {
store.AddObservation(&models.Observation{
ID: int64(i),
Type: models.ObsTypeBugfix,
CreatedAtEpoch: now.UnixMilli(),
})
}
ctx := context.Background()
err := recalc.RecalculateNow(ctx)
require.NoError(t, err)
// Should only process batch size (3)
scores := 0
for i := 1; i <= 5; i++ {
if _, ok := store.GetScore(int64(i)); ok {
scores++
}
}
assert.Equal(t, 3, scores, "should only process batch size observations")
}
func TestRecalculator_ConcurrentAccess(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
// Add observations
now := time.Now()
for i := 1; i <= 10; i++ {
store.AddObservation(&models.Observation{
ID: int64(i),
Type: models.ObsTypeBugfix,
CreatedAtEpoch: now.UnixMilli(),
})
}
ctx := context.Background()
var wg sync.WaitGroup
// Run multiple recalculations concurrently
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = recalc.RecalculateNow(ctx)
}()
}
wg.Wait()
// Should complete without race conditions
// (use -race flag to verify)
assert.GreaterOrEqual(t, store.GetUpdateScoresCalls(), 1)
}
func TestRecalculator_StatsThreadSafe(t *testing.T) {
store := NewMockObservationStore()
calc := NewCalculator(nil)
log := zerolog.Nop()
recalc := NewRecalculator(store, calc, log)
var wg sync.WaitGroup
// Concurrent reads (use -race flag to verify)
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = recalc.GetStats()
}()
}
wg.Wait()
}
+448
@@ -0,0 +1,448 @@
// Package expansion provides context-aware query expansion for improved search recall.
package expansion
import (
"context"
"regexp"
"strings"
"sync"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/rs/zerolog/log"
)
// QueryIntent represents the detected intent of a query.
type QueryIntent string
const (
// IntentQuestion indicates a question-type query (how, why, what, etc.)
IntentQuestion QueryIntent = "question"
// IntentError indicates an error/debugging query
IntentError QueryIntent = "error"
// IntentImplementation indicates an implementation/coding query
IntentImplementation QueryIntent = "implementation"
// IntentArchitecture indicates an architecture/design query
IntentArchitecture QueryIntent = "architecture"
// IntentGeneral indicates a general lookup query
IntentGeneral QueryIntent = "general"
)
// ExpandedQuery represents a query variant with metadata.
type ExpandedQuery struct {
Query string `json:"query"`
Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0)
Source string `json:"source"` // Where this expansion came from
Intent QueryIntent `json:"intent"` // Detected intent
}
// Expander provides context-aware query expansion.
type Expander struct {
embedSvc *embedding.Service
vocabulary []VocabEntry // Known vocabulary from observations
vocabVectors [][]float32 // Embeddings for vocabulary entries
vocabMu sync.RWMutex // Protects vocabulary
intentPatterns map[QueryIntent][]*regexp.Regexp
}
// VocabEntry represents a vocabulary term from observations.
type VocabEntry struct {
Term string // The term itself
Weight float64 // How common/important this term is (0.0-1.0)
Source string // Where it came from (title, concept, narrative)
}
// Config holds expander configuration.
type Config struct {
// MaxExpansions limits the number of expanded queries returned
MaxExpansions int
// MinSimilarity is the minimum similarity score for vocabulary expansion
MinSimilarity float64
// EnableVocabularyExpansion enables finding related terms from observations
EnableVocabularyExpansion bool
}
// DefaultConfig returns sensible default configuration.
func DefaultConfig() Config {
return Config{
MaxExpansions: 4,
MinSimilarity: 0.5,
EnableVocabularyExpansion: true,
}
}
// NewExpander creates a new query expander.
func NewExpander(embedSvc *embedding.Service) *Expander {
e := &Expander{
embedSvc: embedSvc,
intentPatterns: buildIntentPatterns(),
}
return e
}
// buildIntentPatterns creates regex patterns for intent detection.
func buildIntentPatterns() map[QueryIntent][]*regexp.Regexp {
patterns := make(map[QueryIntent][]*regexp.Regexp)
// Question patterns
patterns[IntentQuestion] = []*regexp.Regexp{
regexp.MustCompile(`(?i)^(how|why|what|when|where|which|who)\b`),
regexp.MustCompile(`(?i)\?$`),
regexp.MustCompile(`(?i)\b(explain|describe|understand)\b`),
}
// Error/debugging patterns
patterns[IntentError] = []*regexp.Regexp{
regexp.MustCompile(`(?i)\b(error|bug|issue|problem|fail|crash|exception|panic)\b`),
regexp.MustCompile(`(?i)\b(fix|debug|troubleshoot|resolve)\b`),
regexp.MustCompile(`(?i)\b(doesn't work|not working|broken)\b`),
}
// Implementation patterns
patterns[IntentImplementation] = []*regexp.Regexp{
regexp.MustCompile(`(?i)\b(implement|add|create|build|write|code)\b`),
regexp.MustCompile(`(?i)\b(function|method|handler|endpoint|api)\b`),
regexp.MustCompile(`(?i)\b(feature|functionality)\b`),
}
// Architecture patterns
patterns[IntentArchitecture] = []*regexp.Regexp{
regexp.MustCompile(`(?i)\b(architecture|design|pattern|structure)\b`),
regexp.MustCompile(`(?i)\b(component|module|layer|service)\b`),
regexp.MustCompile(`(?i)\b(flow|pipeline|workflow)\b`),
}
return patterns
}
// DetectIntent analyzes a query to determine its intent.
func (e *Expander) DetectIntent(query string) QueryIntent {
query = strings.TrimSpace(query)
if query == "" {
return IntentGeneral
}
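// Check patterns in priority order. Error patterns come first, so a query
// such as "why does this fail?" is classified as IntentError, not IntentQuestion.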
intentOrder := []QueryIntent{IntentError, IntentQuestion, IntentImplementation, IntentArchitecture}
for _, intent := range intentOrder {
patterns := e.intentPatterns[intent]
for _, pattern := range patterns {
if pattern.MatchString(query) {
return intent
}
}
}
return IntentGeneral
}
// Expand generates expanded query variants based on the original query.
func (e *Expander) Expand(ctx context.Context, query string, cfg Config) []ExpandedQuery {
query = strings.TrimSpace(query)
if query == "" {
return nil
}
intent := e.DetectIntent(query)
expansions := make([]ExpandedQuery, 0, cfg.MaxExpansions)
// Always include the original query with highest weight
expansions = append(expansions, ExpandedQuery{
Query: query,
Weight: 1.0,
Source: "original",
Intent: intent,
})
// Generate intent-based expansions
intentExpansions := e.expandByIntent(query, intent)
expansions = append(expansions, intentExpansions...)
// Generate vocabulary-based expansions if enabled. expandByVocabulary checks
// for an empty vocabulary itself while holding vocabMu, so the vocabulary is
// not read here without the lock.
if cfg.EnableVocabularyExpansion && e.embedSvc != nil {
vocabExpansions := e.expandByVocabulary(ctx, query, cfg.MinSimilarity)
expansions = append(expansions, vocabExpansions...)
}
// Deduplicate and limit
expansions = deduplicateExpansions(expansions)
if len(expansions) > cfg.MaxExpansions {
expansions = expansions[:cfg.MaxExpansions]
}
log.Debug().
Str("query", truncate(query, 50)).
Str("intent", string(intent)).
Int("expansions", len(expansions)).
Msg("Query expanded")
return expansions
}
// expandByIntent generates expansions based on detected query intent.
func (e *Expander) expandByIntent(query string, intent QueryIntent) []ExpandedQuery {
var expansions []ExpandedQuery
// Extract key terms from query for context-aware expansion
keyTerms := extractKeyTerms(query)
switch intent {
case IntentQuestion:
// For questions, create a declarative variant
declarative := makeDeclarative(query)
if declarative != query {
expansions = append(expansions, ExpandedQuery{
Query: declarative,
Weight: 0.85,
Source: "intent:declarative",
Intent: intent,
})
}
case IntentError:
// For errors, expand with solution-oriented terms
if len(keyTerms) > 0 {
solutionQuery := strings.Join(keyTerms, " ") + " solution fix"
expansions = append(expansions, ExpandedQuery{
Query: solutionQuery,
Weight: 0.8,
Source: "intent:solution",
Intent: intent,
})
}
case IntentImplementation:
// For implementation queries, focus on the what/how
if len(keyTerms) > 0 {
howQuery := "how " + strings.Join(keyTerms, " ")
expansions = append(expansions, ExpandedQuery{
Query: howQuery,
Weight: 0.75,
Source: "intent:how",
Intent: intent,
})
}
case IntentArchitecture:
// For architecture queries, expand with design context
if len(keyTerms) > 0 {
designQuery := strings.Join(keyTerms, " ") + " design structure"
expansions = append(expansions, ExpandedQuery{
Query: designQuery,
Weight: 0.75,
Source: "intent:design",
Intent: intent,
})
}
case IntentGeneral:
// No intent-based expansion for general queries; they rely on vocabulary expansion.
}
return expansions
}
// expandByVocabulary finds similar terms from the observation vocabulary.
func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimilarity float64) []ExpandedQuery {
e.vocabMu.RLock()
defer e.vocabMu.RUnlock()
if len(e.vocabulary) == 0 || e.embedSvc == nil {
return nil
}
// Embed the query
queryEmb, err := e.embedSvc.Embed(query)
if err != nil {
log.Warn().Err(err).Msg("Failed to embed query for vocabulary expansion")
return nil
}
// Find similar vocabulary terms
type scoredTerm struct {
entry VocabEntry
score float64
}
var similar []scoredTerm
for i, entry := range e.vocabulary {
if i >= len(e.vocabVectors) {
break
}
score := cosineSimilarity(queryEmb, e.vocabVectors[i])
if score >= minSimilarity {
similar = append(similar, scoredTerm{entry: entry, score: score})
}
}
if len(similar) == 0 {
return nil
}
// Sort by score (descending) with a simple in-place exchange sort (O(n^2))
for i := 0; i < len(similar)-1; i++ {
for j := i + 1; j < len(similar); j++ {
if similar[j].score > similar[i].score {
similar[i], similar[j] = similar[j], similar[i]
}
}
}
// Create expansion by combining top similar terms with query
var expansions []ExpandedQuery
if len(similar) > 0 {
// Take top 2 similar terms and combine with original key terms
keyTerms := extractKeyTerms(query)
for i := 0; i < min(2, len(similar)); i++ {
term := similar[i].entry.Term
// Don't add if term is already in query
if strings.Contains(strings.ToLower(query), strings.ToLower(term)) {
continue
}
combinedQuery := strings.Join(keyTerms, " ") + " " + term
expansions = append(expansions, ExpandedQuery{
Query: combinedQuery,
Weight: 0.7 * similar[i].score * similar[i].entry.Weight,
Source: "vocabulary:" + term,
Intent: IntentGeneral,
})
}
}
return expansions
}
// Helper functions
// extractKeyTerms extracts meaningful terms from a query.
func extractKeyTerms(query string) []string {
// Common stop words to filter out
stopWords := map[string]bool{
"a": true, "an": true, "the": true, "is": true, "are": true,
"was": true, "were": true, "be": true, "been": true, "being": true,
"have": true, "has": true, "had": true, "do": true, "does": true,
"did": true, "will": true, "would": true, "could": true, "should": true,
"may": true, "might": true, "must": true, "can": true,
"i": true, "me": true, "my": true, "we": true, "our": true,
"you": true, "your": true, "it": true, "its": true,
"this": true, "that": true, "these": true, "those": true,
"what": true, "which": true, "who": true, "whom": true,
"how": true, "why": true, "when": true, "where": true,
"to": true, "for": true, "with": true, "about": true, "from": true,
"in": true, "on": true, "at": true, "by": true, "of": true,
"and": true, "or": true, "but": true, "if": true, "then": true,
}
// Split and filter
words := strings.Fields(strings.ToLower(query))
var terms []string
for _, word := range words {
// Remove punctuation
word = strings.Trim(word, ".,?!;:'\"()[]{}")
if len(word) < 2 {
continue
}
if stopWords[word] {
continue
}
terms = append(terms, word)
}
return terms
}
// makeDeclarative converts a question to a declarative statement.
func makeDeclarative(query string) string {
query = strings.TrimSpace(query)
// Remove question mark
query = strings.TrimSuffix(query, "?")
// Handle common question patterns
patterns := []struct {
prefix string
replacement string
}{
{"how do i ", ""},
{"how to ", ""},
{"how does ", ""},
{"how is ", ""},
{"what is ", ""},
{"what are ", ""},
{"why does ", ""},
{"why is ", ""},
{"where is ", ""},
{"where are ", ""},
{"when does ", ""},
{"when is ", ""},
}
lower := strings.ToLower(query)
for _, p := range patterns {
if strings.HasPrefix(lower, p.prefix) {
return strings.TrimSpace(query[len(p.prefix):])
}
}
return query
}
// deduplicateExpansions removes duplicate queries while preserving order.
func deduplicateExpansions(expansions []ExpandedQuery) []ExpandedQuery {
seen := make(map[string]bool)
result := make([]ExpandedQuery, 0, len(expansions))
for _, exp := range expansions {
normalized := strings.ToLower(strings.TrimSpace(exp.Query))
if !seen[normalized] {
seen[normalized] = true
result = append(result, exp)
}
}
return result
}
// cosineSimilarity computes cosine similarity between two vectors.
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) || len(a) == 0 {
return 0
}
var dot, normA, normB float64
for i := range a {
dot += float64(a[i]) * float64(b[i])
normA += float64(a[i]) * float64(a[i])
normB += float64(b[i]) * float64(b[i])
}
if normA == 0 || normB == 0 {
return 0
}
return dot / (sqrt(normA) * sqrt(normB))
}
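// sqrt computes a square root with Newton's method.
// It starts from an upper bound (max(x, 1)) and iterates until the estimate
// stops decreasing, so it also converges for very large or very small inputs.
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
z := x
if z < 1 {
z = 1
}
for {
next := (z + x/z) / 2
if !(next < z) {
return z
}
z = next
}
}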
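// truncate shortens a string to at most maxLen characters (runes) and appends
// "..." when it cuts. Counting runes avoids splitting a multi-byte UTF-8 character.
func truncate(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return string(runes[:maxLen]) + "..."
}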
+518
@@ -0,0 +1,518 @@
// Package expansion provides context-aware query expansion for improved search recall.
package expansion
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
// ExpanderSuite tests the Expander functionality.
type ExpanderSuite struct {
suite.Suite
expander *Expander
}
func TestExpanderSuite(t *testing.T) {
suite.Run(t, new(ExpanderSuite))
}
func (s *ExpanderSuite) SetupTest() {
// Create expander without embedding service for basic tests
s.expander = NewExpander(nil)
}
// TestNewExpander tests expander creation.
func (s *ExpanderSuite) TestNewExpander() {
e := NewExpander(nil)
s.NotNil(e)
s.NotNil(e.intentPatterns)
s.Nil(e.embedSvc)
}
// TestDetectIntent tests intent detection.
func (s *ExpanderSuite) TestDetectIntent() {
tests := []struct {
name string
query string
expected QueryIntent
}{
// Question intents
{"how_question", "how do I implement auth?", IntentQuestion},
{"why_question", "why does this fail?", IntentError}, // "fail" triggers error first
{"what_question", "what is the purpose of this function?", IntentQuestion},
{"question_mark", "the handler for auth?", IntentQuestion},
{"explain", "explain the architecture", IntentQuestion},
// Error intents
{"error_word", "authentication error in login", IntentError},
{"bug_word", "bug in user registration", IntentError},
{"fix_word", "fix the memory leak", IntentError},
{"not_working", "login not working", IntentError},
{"crash", "application crash on startup", IntentError},
// Implementation intents
{"implement", "implement user authentication", IntentImplementation},
{"add_feature", "add new endpoint for users", IntentImplementation},
{"create", "create a handler for uploads", IntentImplementation},
{"function", "function to validate input", IntentImplementation},
// Architecture intents
{"architecture", "architecture of the system", IntentArchitecture},
{"design", "design pattern for observers", IntentArchitecture},
{"component", "component structure", IntentArchitecture},
{"flow", "data flow in the pipeline", IntentArchitecture},
// General intents
{"general", "user authentication", IntentGeneral},
{"empty", "", IntentGeneral},
{"simple", "database", IntentGeneral},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := s.expander.DetectIntent(tt.query)
s.Equal(tt.expected, result, "Query: %s", tt.query)
})
}
}
// TestExpand tests basic query expansion.
func (s *ExpanderSuite) TestExpand() {
ctx := context.Background()
cfg := DefaultConfig()
cfg.EnableVocabularyExpansion = false // Disable for unit test
tests := []struct {
name string
query string
minExpansions int
hasOriginal bool
expectedIntent QueryIntent
}{
{"question", "how do I implement auth", 1, true, IntentQuestion},
{"error", "fix the bug in login", 1, true, IntentError},
{"implementation", "implement user handler", 1, true, IntentImplementation},
{"architecture", "architecture design", 1, true, IntentArchitecture},
{"general", "database connection", 1, true, IntentGeneral},
{"empty", "", 0, false, IntentGeneral},
}
for _, tt := range tests {
s.Run(tt.name, func() {
expansions := s.expander.Expand(ctx, tt.query, cfg)
if tt.minExpansions == 0 {
s.Empty(expansions)
return
}
s.GreaterOrEqual(len(expansions), tt.minExpansions)
if tt.hasOriginal {
// First expansion should be the original
s.Equal(tt.query, expansions[0].Query)
s.Equal(1.0, expansions[0].Weight)
s.Equal("original", expansions[0].Source)
}
})
}
}
// TestExpandWithConfig tests expansion with custom config.
func (s *ExpanderSuite) TestExpandWithConfig() {
ctx := context.Background()
cfg := Config{
MaxExpansions: 2,
MinSimilarity: 0.7,
EnableVocabularyExpansion: false,
}
expansions := s.expander.Expand(ctx, "how to implement authentication", cfg)
s.LessOrEqual(len(expansions), cfg.MaxExpansions)
}
// TestExpandDeduplication tests that duplicates are removed.
func (s *ExpanderSuite) TestExpandDeduplication() {
ctx := context.Background()
cfg := DefaultConfig()
cfg.EnableVocabularyExpansion = false
// Query that might generate duplicate expansions
query := "how to fix authentication"
expansions := s.expander.Expand(ctx, query, cfg)
// Check for duplicates
seen := make(map[string]bool)
for _, exp := range expansions {
normalized := exp.Query
s.False(seen[normalized], "Duplicate expansion found: %s", exp.Query)
seen[normalized] = true
}
}
// TestExtractKeyTerms tests key term extraction.
func TestExtractKeyTerms(t *testing.T) {
tests := []struct {
name string
query string
expected []string
}{
{
name: "simple",
query: "user authentication handler",
expected: []string{"user", "authentication", "handler"},
},
{
name: "with_stop_words",
query: "how to implement the user login",
expected: []string{"implement", "user", "login"},
},
{
name: "with_punctuation",
query: "fix the bug, please!",
expected: []string{"fix", "bug", "please"},
},
{
name: "empty",
query: "",
expected: nil,
},
{
name: "only_stop_words",
query: "the a an is are",
expected: nil,
},
{
name: "short_words_filtered",
query: "a b c auth",
expected: []string{"auth"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractKeyTerms(tt.query)
if tt.expected == nil {
assert.Empty(t, result)
} else {
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestMakeDeclarative tests question to declarative conversion.
func TestMakeDeclarative(t *testing.T) {
tests := []struct {
name string
query string
expected string
}{
{
name: "how_do_i",
query: "how do I implement auth?",
expected: "implement auth",
},
{
name: "how_to",
query: "how to fix the bug",
expected: "fix the bug",
},
{
name: "what_is",
query: "what is the purpose of this?",
expected: "the purpose of this",
},
{
name: "why_does",
query: "why does this fail?",
expected: "this fail",
},
{
name: "already_declarative",
query: "user authentication",
expected: "user authentication",
},
{
name: "question_mark_only",
query: "authentication?",
expected: "authentication",
},
{
name: "case_insensitive",
query: "How To Fix Auth?",
expected: "Fix Auth",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := makeDeclarative(tt.query)
assert.Equal(t, tt.expected, result)
})
}
}
// TestDeduplicateExpansions tests deduplication.
func TestDeduplicateExpansions(t *testing.T) {
expansions := []ExpandedQuery{
{Query: "auth handler", Weight: 1.0},
{Query: "AUTH HANDLER", Weight: 0.8}, // Duplicate (case insensitive)
{Query: "auth handler ", Weight: 0.7}, // Duplicate (whitespace)
{Query: "user auth", Weight: 0.6},
}
result := deduplicateExpansions(expansions)
assert.Len(t, result, 2) // "auth handler" and "user auth"
assert.Equal(t, "auth handler", result[0].Query)
assert.Equal(t, 1.0, result[0].Weight) // First one preserved
assert.Equal(t, "user auth", result[1].Query)
}
// TestCosineSimilarity tests cosine similarity calculation.
func TestCosineSimilarity(t *testing.T) {
tests := []struct {
name string
a []float32
b []float32
expected float64
delta float64
}{
{
name: "identical_vectors",
a: []float32{1, 0, 0},
b: []float32{1, 0, 0},
expected: 1.0,
delta: 0.001,
},
{
name: "orthogonal_vectors",
a: []float32{1, 0, 0},
b: []float32{0, 1, 0},
expected: 0.0,
delta: 0.001,
},
{
name: "opposite_vectors",
a: []float32{1, 0, 0},
b: []float32{-1, 0, 0},
expected: -1.0,
delta: 0.001,
},
{
name: "similar_vectors",
a: []float32{1, 1, 0},
b: []float32{1, 0, 0},
expected: 0.707,
delta: 0.01,
},
{
name: "empty_vectors",
a: []float32{},
b: []float32{},
expected: 0.0,
delta: 0.001,
},
{
name: "different_lengths",
a: []float32{1, 0},
b: []float32{1, 0, 0},
expected: 0.0,
delta: 0.001,
},
{
name: "zero_vector",
a: []float32{0, 0, 0},
b: []float32{1, 1, 1},
expected: 0.0,
delta: 0.001,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cosineSimilarity(tt.a, tt.b)
assert.InDelta(t, tt.expected, result, tt.delta)
})
}
}
// TestDefaultConfig tests default configuration.
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
assert.Equal(t, 4, cfg.MaxExpansions)
assert.Equal(t, 0.5, cfg.MinSimilarity)
assert.True(t, cfg.EnableVocabularyExpansion)
}
// TestExpandedQueryStruct tests ExpandedQuery struct.
func TestExpandedQueryStruct(t *testing.T) {
eq := ExpandedQuery{
Query: "test query",
Weight: 0.85,
Source: "vocabulary:auth",
Intent: IntentQuestion,
}
assert.Equal(t, "test query", eq.Query)
assert.Equal(t, 0.85, eq.Weight)
assert.Equal(t, "vocabulary:auth", eq.Source)
assert.Equal(t, IntentQuestion, eq.Intent)
}
// TestVocabEntry tests VocabEntry struct.
func TestVocabEntry(t *testing.T) {
ve := VocabEntry{
Term: "authentication",
Weight: 0.9,
Source: "concept",
}
assert.Equal(t, "authentication", ve.Term)
assert.Equal(t, 0.9, ve.Weight)
assert.Equal(t, "concept", ve.Source)
}
// TestIntentConstants tests intent constant values.
func TestIntentConstants(t *testing.T) {
assert.Equal(t, QueryIntent("question"), IntentQuestion)
assert.Equal(t, QueryIntent("error"), IntentError)
assert.Equal(t, QueryIntent("implementation"), IntentImplementation)
assert.Equal(t, QueryIntent("architecture"), IntentArchitecture)
assert.Equal(t, QueryIntent("general"), IntentGeneral)
}
// TestTruncate tests the truncate helper.
func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{"short", "hello", 10, "hello"},
{"exact", "hello", 5, "hello"},
{"long", "hello world", 5, "hello..."},
{"empty", "", 10, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncate(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
// TestSqrt tests the sqrt helper.
func TestSqrt(t *testing.T) {
tests := []struct {
input float64
expected float64
delta float64
}{
{4.0, 2.0, 0.001},
{9.0, 3.0, 0.001},
{16.0, 4.0, 0.001},
{2.0, 1.414, 0.01},
{0.0, 0.0, 0.001},
{-1.0, 0.0, 0.001},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
result := sqrt(tt.input)
assert.InDelta(t, tt.expected, result, tt.delta)
})
}
}
// TestExpandByIntentError tests error intent expansion.
func (s *ExpanderSuite) TestExpandByIntentError() {
expansions := s.expander.expandByIntent("fix authentication bug", IntentError)
s.NotEmpty(expansions)
// Should have solution-oriented expansion
hasSolution := false
for _, exp := range expansions {
if exp.Source == "intent:solution" {
hasSolution = true
break
}
}
s.True(hasSolution)
}
// TestExpandByIntentQuestion tests question intent expansion.
func (s *ExpanderSuite) TestExpandByIntentQuestion() {
expansions := s.expander.expandByIntent("how do I implement auth", IntentQuestion)
s.NotEmpty(expansions)
// Should have declarative expansion
hasDeclarative := false
for _, exp := range expansions {
if exp.Source == "intent:declarative" {
hasDeclarative = true
break
}
}
s.True(hasDeclarative)
}
// TestExpandByIntentImplementation tests implementation intent expansion.
func (s *ExpanderSuite) TestExpandByIntentImplementation() {
expansions := s.expander.expandByIntent("implement user handler", IntentImplementation)
s.NotEmpty(expansions)
// Should have how expansion
hasHow := false
for _, exp := range expansions {
if exp.Source == "intent:how" {
hasHow = true
break
}
}
s.True(hasHow)
}
// TestExpandByIntentArchitecture tests architecture intent expansion.
func (s *ExpanderSuite) TestExpandByIntentArchitecture() {
expansions := s.expander.expandByIntent("system architecture design", IntentArchitecture)
s.NotEmpty(expansions)
// Should have design expansion
hasDesign := false
for _, exp := range expansions {
if exp.Source == "intent:design" {
hasDesign = true
break
}
}
s.True(hasDesign)
}
// TestExpandByIntentGeneral tests general intent returns no expansions.
func (s *ExpanderSuite) TestExpandByIntentGeneral() {
expansions := s.expander.expandByIntent("database", IntentGeneral)
s.Empty(expansions) // General intent doesn't add intent-based expansions
}
// TestEmptyVocabulary tests expansion with empty vocabulary.
func (s *ExpanderSuite) TestEmptyVocabulary() {
ctx := context.Background()
expansions := s.expander.expandByVocabulary(ctx, "test query", 0.5)
s.Empty(expansions)
}
// TestIntentPatternsExist tests that all intents have patterns.
func (s *ExpanderSuite) TestIntentPatternsExist() {
s.NotEmpty(s.expander.intentPatterns[IntentQuestion])
s.NotEmpty(s.expander.intentPatterns[IntentError])
s.NotEmpty(s.expander.intentPatterns[IntentImplementation])
s.NotEmpty(s.expander.intentPatterns[IntentArchitecture])
}
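The table-driven tests above pin down the behavior of `sqrt` and `cosineSimilarity`, but the helper bodies are not part of this excerpt. The following is a minimal sketch that would satisfy those cases, assuming Newton's method for the square root, `[]float32` vectors, and a zero result for non-positive input or zero-norm vectors. All three are assumptions, not the project's actual implementation.

```go
package main

import "fmt"

// sqrt approximates the square root with Newton's method.
// Assumption: non-positive input yields 0, matching the {-1.0, 0.0} test case.
func sqrt(x float64) float64 {
	if x <= 0 {
		return 0
	}
	z := x
	for i := 0; i < 30; i++ {
		z = (z + x/z) / 2
	}
	return z
}

// cosineSimilarity returns dot(a, b) / (|a| * |b|).
// Assumption: mismatched lengths, empty input, or zero-norm vectors yield 0.
func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) || len(a) == 0 {
		return 0
	}
	var dot, normA, normB float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		normA += float64(a[i]) * float64(a[i])
		normB += float64(b[i]) * float64(b[i])
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return dot / (sqrt(normA) * sqrt(normB))
}

func main() {
	fmt.Printf("%.3f\n", sqrt(2))
	fmt.Printf("%.3f\n", cosineSimilarity([]float32{1, 0}, []float32{0, 1}))
}
```

The fixed 30 iterations are enough for the magnitudes that appear in embedding norms; very large inputs would need more iterations or `math.Sqrt`.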
+29 -15
@@ -35,20 +35,21 @@ func NewManager(
// SearchParams contains parameters for unified search.
type SearchParams struct {
Query string
Type string // "observations", "sessions", "prompts", or empty for all
Project string
ObsType string // Observation type filter
Concepts string
Files string
DateStart int64
DateEnd int64
OrderBy string // "relevance", "date_desc", "date_asc"
Limit int
Offset int
Format string // "index" or "full"
Scope string // "project", "global", or empty for project+global
IncludeGlobal bool // If true, include global observations along with project-scoped
ExcludeSuperseded bool // If true, exclude observations that have been superseded
}
// SearchResult represents a unified search result.
@@ -126,6 +127,10 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
if err == nil {
for _, o := range obs {
// Skip superseded observations when requested
if params.ExcludeSuperseded && o.IsSuperseded {
continue
}
results = append(results, m.observationToResult(o, params.Format))
}
}
@@ -167,7 +172,16 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
// Search observations
if params.Type == "" || params.Type == "observations" {
-obs, err := m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
var obs []*models.Observation
var err error
// Use active observations (excluding superseded) when requested
if params.ExcludeSuperseded {
obs, err = m.observationStore.GetActiveObservations(ctx, params.Project, params.Limit)
} else {
obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
}
if err == nil {
for _, o := range obs {
results = append(results, m.observationToResult(o, params.Format))
+154 -4
@@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("generate embeddings: %w", err)
}
-// Insert into vectors table
// Insert into vectors table with model version tracking
const insertQuery = `
-INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
-VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
// Get current model version for tracking
modelVersion := c.embedSvc.Version()
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
@@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
fieldType,
project,
scope,
modelVersion,
)
if err != nil {
return fmt.Errorf("insert document %s: %w", doc.ID, err)
@@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("commit transaction: %w", err)
}
-log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
return nil
}
@@ -212,6 +216,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("scan row: %w", err)
}
r.Similarity = DistanceToSimilarity(r.Distance)
r.Metadata = map[string]any{
"sqlite_id": float64(sqliteID), // Keep as float64 for compatibility
"doc_type": docType.String,
@@ -252,3 +257,148 @@ func truncateString(s string, maxLen int) string {
}
return s[:maxLen] + "..."
}
// Count returns the total number of vectors in the store.
func (c *Client) Count(ctx context.Context) (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
var count int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
if err != nil {
return 0, fmt.Errorf("count vectors: %w", err)
}
return count, nil
}
// ModelVersion returns the current embedding model version.
func (c *Client) ModelVersion() string {
return c.embedSvc.Version()
}
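// NeedsRebuild checks if vectors need to be rebuilt due to model version change.
// It returns true together with a reason string when:
// - the vectors table is empty (reason "empty")
// - any vector has a NULL model_version or one that differs from the current model (reason "model_mismatch:<count>")
// It returns false with an empty reason when no rebuild is needed or when a count query fails.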
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
c.mu.Lock()
defer c.mu.Unlock()
currentModel := c.embedSvc.Version()
// Check total count
var totalCount int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
if err != nil {
log.Warn().Err(err).Msg("Failed to count vectors for rebuild check")
return false, ""
}
if totalCount == 0 {
return true, "empty"
}
// Check for vectors with different model version
var staleCount int64
err = c.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
currentModel,
).Scan(&staleCount)
if err != nil {
log.Warn().Err(err).Msg("Failed to count stale vectors")
return false, ""
}
if staleCount > 0 {
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
}
return false, ""
}
// StaleVectorInfo contains information about a vector that needs rebuilding.
type StaleVectorInfo struct {
DocID string
SQLiteID int64
DocType string
FieldType string
Project string
Scope string
}
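// GetStaleVectors returns identifying metadata (doc_id, sqlite_id, doc_type, field_type, project, scope)
// for vectors whose model_version is NULL or differs from the current model.
// This enables granular rebuild: only documents that need updating are re-embedded.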
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
currentModel := c.embedSvc.Version()
query := `
SELECT doc_id, sqlite_id, doc_type, field_type, project, scope
FROM vectors
WHERE model_version != ? OR model_version IS NULL
`
rows, err := c.db.QueryContext(ctx, query, currentModel)
if err != nil {
return nil, fmt.Errorf("query stale vectors: %w", err)
}
defer rows.Close()
var results []StaleVectorInfo
for rows.Next() {
var info StaleVectorInfo
var sqliteID sql.NullInt64
var docType, fieldType, project, scope sql.NullString
if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
info.SQLiteID = sqliteID.Int64
info.DocType = docType.String
info.FieldType = fieldType.String
info.Project = project.String
info.Scope = scope.String
results = append(results, info)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("iterate rows: %w", err)
}
return results, nil
}
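`GetStaleVectors` and `DeleteVectorsByDocIDs` are the building blocks for lazy re-embedding, but the loop that ties them together is not shown in this excerpt. The sketch below illustrates one possible batch loop against an in-memory stand-in. The `vectorStore` interface, the `reembedFunc` callback, the `memStore` type, and the batch size are hypothetical names introduced for illustration only.

```go
package main

import (
	"context"
	"fmt"
)

// staleVector mirrors only the DocID field of StaleVectorInfo.
type staleVector struct{ DocID string }

// vectorStore is a hypothetical interface over the two client methods a rebuild needs.
type vectorStore interface {
	GetStaleVectors(ctx context.Context) ([]staleVector, error)
	DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error
}

// reembedFunc is a hypothetical callback that re-creates vectors for the given doc IDs.
type reembedFunc func(ctx context.Context, docIDs []string) error

// rebuildStale deletes and re-embeds stale vectors in batches and
// returns how many documents were rebuilt.
func rebuildStale(ctx context.Context, store vectorStore, reembed reembedFunc, batchSize int) (int, error) {
	if batchSize <= 0 {
		batchSize = 100
	}
	stale, err := store.GetStaleVectors(ctx)
	if err != nil {
		return 0, fmt.Errorf("list stale vectors: %w", err)
	}
	rebuilt := 0
	for start := 0; start < len(stale); start += batchSize {
		end := start + batchSize
		if end > len(stale) {
			end = len(stale)
		}
		ids := make([]string, 0, end-start)
		for _, v := range stale[start:end] {
			ids = append(ids, v.DocID)
		}
		if err := store.DeleteVectorsByDocIDs(ctx, ids); err != nil {
			return rebuilt, fmt.Errorf("delete batch: %w", err)
		}
		if err := reembed(ctx, ids); err != nil {
			return rebuilt, fmt.Errorf("re-embed batch: %w", err)
		}
		rebuilt += len(ids)
	}
	return rebuilt, nil
}

// memStore is an in-memory stand-in for the vectors table (doc_id -> model_version).
type memStore struct {
	current  string
	versions map[string]string // "" stands in for a NULL model_version
}

func (m *memStore) GetStaleVectors(_ context.Context) ([]staleVector, error) {
	var out []staleVector
	for id, v := range m.versions {
		if v != m.current {
			out = append(out, staleVector{DocID: id})
		}
	}
	return out, nil
}

func (m *memStore) DeleteVectorsByDocIDs(_ context.Context, docIDs []string) error {
	for _, id := range docIDs {
		delete(m.versions, id)
	}
	return nil
}

// reembed pretends to re-embed by stamping the current model version.
func (m *memStore) reembed(_ context.Context, docIDs []string) error {
	for _, id := range docIDs {
		m.versions[id] = m.current
	}
	return nil
}

// runDemo rebuilds a three-document store in which two entries are stale.
func runDemo() (int, int) {
	store := &memStore{current: "v2", versions: map[string]string{"a": "v1", "b": "v2", "c": ""}}
	ctx := context.Background()
	rebuilt, err := rebuildStale(ctx, store, store.reembed, 1)
	if err != nil {
		panic(err)
	}
	left, _ := store.GetStaleVectors(ctx)
	return rebuilt, len(left)
}

func main() {
	rebuilt, remaining := runDemo()
	fmt.Println(rebuilt, remaining)
}
```

Because the insert statement in `AddDocuments` uses `INSERT OR REPLACE`, the explicit delete may be redundant when document IDs are unchanged; the sketch keeps it to follow the "delete stale vectors before re-adding" comment on `DeleteVectorsByDocIDs`.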
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
// Used for granular rebuild - delete stale vectors before re-adding.
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
if len(docIDs) == 0 {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
// Build placeholder string
placeholders := make([]string, len(docIDs))
args := make([]interface{}, len(docIDs))
for i, id := range docIDs {
placeholders[i] = "?"
args[i] = id
}
// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
strings.Join(placeholders, ","))
_, err := c.db.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("delete vectors by doc_ids: %w", err)
}
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
return nil
}
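`DeleteVectorsByDocIDs` builds one `?` placeholder per ID and passes the values as arguments. A minimal standalone sketch of that technique follows, with an added chunking step for very long ID lists. The chunking is an assumption motivated by SQLite's limit on bound parameters per statement and is not part of the code above; `buildInClause` and `chunk` are hypothetical helper names.

```go
package main

import (
	"fmt"
	"strings"
)

// buildInClause returns a "?,?,?" placeholder string and the matching
// argument slice, so values are never interpolated into the SQL text.
func buildInClause(ids []string) (string, []any) {
	placeholders := make([]string, len(ids))
	args := make([]any, len(ids))
	for i, id := range ids {
		placeholders[i] = "?"
		args[i] = id
	}
	return strings.Join(placeholders, ","), args
}

// chunk splits ids into consecutive slices of at most size elements.
// Hypothetical helper: keeps each statement under the bound-parameter limit.
func chunk(ids []string, size int) [][]string {
	if size <= 0 {
		return nil
	}
	var out [][]string
	for len(ids) > size {
		out = append(out, ids[:size])
		ids = ids[size:]
	}
	if len(ids) > 0 {
		out = append(out, ids)
	}
	return out
}

func main() {
	for _, batch := range chunk([]string{"a", "b", "c"}, 2) {
		ph, args := buildInClause(batch)
		fmt.Printf("DELETE FROM vectors WHERE doc_id IN (%s) -- %d args\n", ph, len(args))
	}
}
```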
+2 -1
@@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) {
doc_type TEXT,
field_type TEXT,
project TEXT,
-scope TEXT
scope TEXT,
model_version TEXT
)
`)
require.NoError(t, err)
+26 -3
@@ -19,9 +19,32 @@ type Document struct {
// QueryResult represents a search result from vector search.
type QueryResult struct {
ID string
Distance float64
Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance)
Metadata map[string]any
}
// DistanceToSimilarity converts sqlite-vec cosine distance to similarity score.
// Cosine distance: 0 = identical, 2 = opposite
// Similarity: 1.0 = identical, 0.0 = opposite
func DistanceToSimilarity(distance float64) float64 {
return 1.0 - (distance / 2.0)
}
// FilterByThreshold filters results to only include those at or above the similarity threshold.
// If maxResults > 0, it also caps the number of results, keeping the first matches in input order.
func FilterByThreshold(results []QueryResult, threshold float64, maxResults int) []QueryResult {
var filtered []QueryResult
for _, r := range results {
if r.Similarity >= threshold {
filtered = append(filtered, r)
if maxResults > 0 && len(filtered) >= maxResults {
break
}
}
}
return filtered
}
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
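As a worked example of the conversion and filter above, the sketch below copies `DistanceToSimilarity` and `FilterByThreshold` into a standalone program and applies them to three made-up distances. The `QueryResult` type is trimmed to the fields used here, and the sample values and the `demoResults` helper are illustrative only.

```go
package main

import "fmt"

// QueryResult is a trimmed copy of the struct above (Metadata omitted).
type QueryResult struct {
	ID         string
	Distance   float64
	Similarity float64
}

// DistanceToSimilarity maps cosine distance [0, 2] onto similarity [1, 0].
func DistanceToSimilarity(distance float64) float64 {
	return 1.0 - (distance / 2.0)
}

// FilterByThreshold keeps results with Similarity >= threshold, stopping
// once maxResults have been collected when maxResults > 0.
func FilterByThreshold(results []QueryResult, threshold float64, maxResults int) []QueryResult {
	var filtered []QueryResult
	for _, r := range results {
		if r.Similarity >= threshold {
			filtered = append(filtered, r)
			if maxResults > 0 && len(filtered) >= maxResults {
				break
			}
		}
	}
	return filtered
}

// demoResults builds three sample results and fills in their similarity.
func demoResults() []QueryResult {
	results := []QueryResult{
		{ID: "a", Distance: 0.2},
		{ID: "b", Distance: 1.0},
		{ID: "c", Distance: 0.6},
	}
	for i := range results {
		results[i].Similarity = DistanceToSimilarity(results[i].Distance)
	}
	return results
}

func main() {
	for _, r := range FilterByThreshold(demoResults(), 0.6, 0) {
		fmt.Printf("%s %.2f\n", r.ID, r.Similarity)
	}
}
```

With a threshold of 0.6, result `b` (distance 1.0, similarity 0.5) is dropped. Because the cap stops at the first `maxResults` matches in input order, callers that want the best N would need to sort by similarity before filtering.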
+98
@@ -240,3 +240,101 @@ func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
return nil
}
// SyncPattern syncs a single pattern to the vector store.
func (s *Sync) SyncPattern(ctx context.Context, pattern *models.Pattern) error {
docs := s.formatPatternDocs(pattern)
if len(docs) == 0 {
return nil
}
if err := s.client.AddDocuments(ctx, docs); err != nil {
return fmt.Errorf("add pattern docs: %w", err)
}
log.Debug().
Int64("patternId", pattern.ID).
Int("docCount", len(docs)).
Msg("Synced pattern to sqlite-vec")
return nil
}
// formatPatternDocs formats a pattern into vector documents.
func (s *Sync) formatPatternDocs(pattern *models.Pattern) []Document {
docs := make([]Document, 0, 3)
baseMetadata := map[string]any{
"sqlite_id": pattern.ID,
"doc_type": "pattern",
"pattern_type": string(pattern.Type),
"status": string(pattern.Status),
"scope": "global", // Patterns are always global
"frequency": pattern.Frequency,
"confidence": pattern.Confidence,
"created_at_epoch": pattern.CreatedAtEpoch,
}
if len(pattern.Signature) > 0 {
baseMetadata["signature"] = joinStrings(pattern.Signature, ",")
}
if len(pattern.Projects) > 0 {
baseMetadata["projects"] = joinStrings(pattern.Projects, ",")
}
// Pattern name as document
if pattern.Name != "" {
docs = append(docs, Document{
ID: fmt.Sprintf("pattern_%d_name", pattern.ID),
Content: pattern.Name,
Metadata: copyMetadata(baseMetadata, "field_type", "name"),
})
}
// Pattern description as document
if pattern.Description.Valid && pattern.Description.String != "" {
docs = append(docs, Document{
ID: fmt.Sprintf("pattern_%d_description", pattern.ID),
Content: pattern.Description.String,
Metadata: copyMetadata(baseMetadata, "field_type", "description"),
})
}
// Pattern recommendation as document
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
docs = append(docs, Document{
ID: fmt.Sprintf("pattern_%d_recommendation", pattern.ID),
Content: pattern.Recommendation.String,
Metadata: copyMetadata(baseMetadata, "field_type", "recommendation"),
})
}
return docs
}
// DeletePatterns removes pattern documents from the vector store.
func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
if len(patternIDs) == 0 {
return nil
}
// Generate all possible document IDs for these patterns
// Pattern: pattern_{id}_name, pattern_{id}_description, pattern_{id}_recommendation
ids := make([]string, 0, len(patternIDs)*3)
for _, patternID := range patternIDs {
ids = append(ids, fmt.Sprintf("pattern_%d_name", patternID))
ids = append(ids, fmt.Sprintf("pattern_%d_description", patternID))
ids = append(ids, fmt.Sprintf("pattern_%d_recommendation", patternID))
}
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
return fmt.Errorf("delete pattern docs: %w", err)
}
log.Debug().
Int("patternCount", len(patternIDs)).
Msg("Deleted patterns from sqlite-vec")
return nil
}
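`formatPatternDocs` and `DeletePatterns` both depend on the same document ID scheme, `pattern_{id}_{field}`. One way to keep the two in step is a single helper that generates the IDs. The sketch below is a hypothetical refactoring, not code from this change; `patternFields` and `patternDocIDs` are illustrative names.

```go
package main

import "fmt"

// patternFields lists the pattern fields that are embedded as separate documents.
var patternFields = []string{"name", "description", "recommendation"}

// patternDocIDs returns every possible vector document ID for one pattern.
// Hypothetical helper: sync and delete paths could both call it.
func patternDocIDs(patternID int64) []string {
	ids := make([]string, 0, len(patternFields))
	for _, field := range patternFields {
		ids = append(ids, fmt.Sprintf("pattern_%d_%s", patternID, field))
	}
	return ids
}

func main() {
	for _, id := range patternDocIDs(42) {
		fmt.Println(id)
	}
}
```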
+238 -4
@@ -4,13 +4,18 @@ package worker
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
@@ -641,6 +646,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
})
}
// handleGetModels returns available embedding models.
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
models := embedding.ListModels()
defaultModel := embedding.GetDefaultModel()
writeJSON(w, map[string]interface{}{
"models": models,
"default": defaultModel,
"current": s.embedSvc.Version(),
})
}
// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
@@ -658,6 +675,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
"ready": s.ready.Load(),
}
// Add embedding model info
if s.embedSvc != nil {
response["embeddingModel"] = map[string]interface{}{
"name": s.embedSvc.Name(),
"version": s.embedSvc.Version(),
"dimensions": s.embedSvc.Dimensions(),
}
}
// Add vector count
if s.vectorClient != nil {
if count, err := s.vectorClient.Count(r.Context()); err == nil {
response["vectorCount"] = count
}
}
// Include project-specific observation count if project is specified
if project != "" {
count, err := s.observationStore.GetObservationCount(r.Context(), project)
@@ -715,15 +748,69 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
var observations []*models.Observation
var err error
var usedVector bool
similarityScores := make(map[int64]float64) // Track similarity per observation
// Get threshold settings from config
threshold := s.config.ContextRelevanceThreshold
maxResults := s.config.ContextMaxPromptResults
// Generate expanded queries if query expander is available
var expandedQueries []expansion.ExpandedQuery
var detectedIntent string
if s.queryExpander != nil {
cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(r.Context(), query, cfg)
if len(expandedQueries) > 0 {
detectedIntent = string(expandedQueries[0].Intent)
}
}
if len(expandedQueries) == 0 {
// Fallback to just the original query
expandedQueries = []expansion.ExpandedQuery{
{Query: query, Weight: 1.0, Source: "original"},
}
}
// Try vector search first if available
if s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
-vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
-if vecErr == nil && len(vectorResults) > 0 {
// Search with each expanded query and merge results
allVectorResults := make([]sqlitevec.QueryResult, 0)
queryWeights := make(map[string]float64) // Track weights for score merging
for _, eq := range expandedQueries {
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
// Apply weight to similarity scores before merging
for i := range vectorResults {
vectorResults[i].Similarity *= eq.Weight
}
allVectorResults = append(allVectorResults, vectorResults...)
queryWeights[eq.Query] = eq.Weight
}
}
if len(allVectorResults) > 0 {
// Filter by relevance threshold before extracting IDs
// Use a slightly lower threshold for expanded queries
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
// Build similarity map for filtered results (keeping highest weighted score per observation)
for _, vr := range filteredResults {
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
id := int64(sqliteID)
// Keep the highest score for each observation
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
similarityScores[id] = vr.Similarity
}
}
}
// Extract observation IDs with project/scope filtering using shared helper
-obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
if len(obsIDs) > 0 {
// Fetch full observations from SQLite
@@ -773,23 +860,132 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
freshObservations = append(freshObservations, obs)
}
// Apply cross-encoder reranking if available
var reranked bool
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
// Build candidates from observations with their bi-encoder scores
candidates := make([]reranking.Candidate, len(freshObservations))
for i, obs := range freshObservations {
content := obs.Title.String
if obs.Narrative.Valid && obs.Narrative.String != "" {
content = content + " " + obs.Narrative.String
}
candidates[i] = reranking.Candidate{
ID: fmt.Sprintf("%d", obs.ID),
Content: content,
Score: similarityScores[obs.ID],
Metadata: map[string]any{"obs_idx": i},
}
}
// Rerank using cross-encoder - use pure mode or combined scores
var rerankResults []reranking.RerankResult
var rerankErr error
if s.config.RerankingPureMode {
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
} else {
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
}
if rerankErr != nil {
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
} else if len(rerankResults) > 0 {
// Update similarity scores with reranked scores
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
similarityScores[id] = rr.CombinedScore
}
}
// Reorder observations based on rerank results
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
obsMap := make(map[int64]*models.Observation)
for _, obs := range freshObservations {
obsMap[obs.ID] = obs
}
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
if obs, ok := obsMap[id]; ok {
reorderedObs = append(reorderedObs, obs)
}
}
}
freshObservations = reorderedObs
reranked = true
log.Debug().
Int("candidates", len(candidates)).
Int("returned", len(rerankResults)).
Msg("Cross-encoder reranking complete")
}
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
// Sort by similarity score (highest first) if we have scores and didn't rerank
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
sort.Slice(clusteredObservations, func(i, j int) bool {
scoreI := similarityScores[clusteredObservations[i].ID]
scoreJ := similarityScores[clusteredObservations[j].ID]
return scoreI > scoreJ
})
}
// Apply max results cap if configured
if maxResults > 0 && len(clusteredObservations) > maxResults {
clusteredObservations = clusteredObservations[:maxResults]
}
// Record retrieval stats (no verification done, so verified=0, deleted=0)
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, true)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Str("query", query).
Str("intent", detectedIntent).
Int("expansions", len(expandedQueries)).
Int("found", len(clusteredObservations)).
Int("stale_excluded", staleCount).
Float64("threshold", threshold).
Msg("Prompt-based observation search")
// Build response with similarity scores
obsWithScores := make([]map[string]interface{}, len(clusteredObservations))
for i, obs := range clusteredObservations {
obsMap := obs.ToMap()
if score, ok := similarityScores[obs.ID]; ok {
obsMap["similarity"] = score
}
obsWithScores[i] = obsMap
}
// Build expansion info for response
expansionInfo := make([]map[string]interface{}, len(expandedQueries))
for i, eq := range expandedQueries {
expansionInfo[i] = map[string]interface{}{
"query": eq.Query,
"weight": eq.Weight,
"source": eq.Source,
}
}
writeJSON(w, map[string]interface{}{
"project": project,
"query": query,
-"observations": clusteredObservations,
"intent": detectedIntent,
"expansions": expansionInfo,
"observations": obsWithScores,
"threshold": threshold,
"max_results": maxResults,
})
}
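The handler above searches once per expanded query, scales each similarity by the query weight, filters against a slightly lowered threshold, and keeps the highest score per observation. The sketch below isolates that merge step with simplified types. The `hit` and `expandedQuery` structs, `mergeWeighted`, and the sample numbers are stand-ins introduced for illustration, not the handler's actual types.

```go
package main

import "fmt"

// hit is a simplified vector search result for one observation.
type hit struct {
	ObsID      int64
	Similarity float64
}

// expandedQuery is a simplified stand-in for expansion.ExpandedQuery.
type expandedQuery struct {
	Query  string
	Weight float64
}

// mergeWeighted scales each hit by its query weight, drops scores below
// threshold, and keeps the best remaining score per observation ID.
func mergeWeighted(perQuery map[string][]hit, queries []expandedQuery, threshold float64) map[int64]float64 {
	best := make(map[int64]float64)
	for _, q := range queries {
		for _, h := range perQuery[q.Query] {
			score := h.Similarity * q.Weight
			if score < threshold {
				continue
			}
			if prev, ok := best[h.ObsID]; !ok || score > prev {
				best[h.ObsID] = score
			}
		}
	}
	return best
}

// demoMerge runs the merge on made-up data with a configured threshold
// of 0.5, lowered by the same 0.9 factor the handler applies.
func demoMerge() map[int64]float64 {
	queries := []expandedQuery{
		{Query: "fix auth bug", Weight: 1.0},
		{Query: "auth bug solution", Weight: 0.7},
	}
	perQuery := map[string][]hit{
		"fix auth bug":      {{ObsID: 1, Similarity: 0.8}, {ObsID: 2, Similarity: 0.4}},
		"auth bug solution": {{ObsID: 2, Similarity: 0.9}, {ObsID: 3, Similarity: 0.5}},
	}
	return mergeWeighted(perQuery, queries, 0.5*0.9)
}

func main() {
	scores := demoMerge()
	for _, id := range []int64{1, 2, 3} {
		if s, ok := scores[id]; ok {
			fmt.Printf("obs %d: %.2f\n", id, s)
		}
	}
}
```

In this example observation 2 is rejected for the original query (0.4 is below 0.45) but survives through the expansion (0.9 × 0.7), while observation 3 falls below the threshold once its weight is applied. Note that weighting before filtering means a low-weight expansion needs a proportionally higher raw similarity to contribute.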
@@ -857,6 +1053,15 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
// Record retrieval stats (no verification done)
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, false)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Int("total", len(observations)).
@@ -1015,6 +1220,35 @@ func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
}
components = append(components, sseStatus)
// Check Cross-Encoder Reranker
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
if !s.config.RerankingEnabled {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Disabled in config"
if overall == "healthy" {
overall = "degraded"
}
} else if s.reranker == nil {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Not initialized"
if overall == "healthy" {
overall = "degraded"
}
} else {
// Verify reranker is functional using Score
_, normalizedScore, err := s.reranker.Score("test query", "test document")
if err != nil {
rerankerStatus.Status = "unhealthy"
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
if overall == "healthy" {
overall = "degraded"
}
} else {
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
}
}
components = append(components, rerankerStatus)
// Calculate uptime
uptime := time.Since(s.startTime).Round(time.Second).String()
+292
@@ -0,0 +1,292 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"encoding/json"
"net/http"
"strconv"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// DefaultPatternsLimit is the default number of patterns to return.
const DefaultPatternsLimit = 100
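// handleGetPatterns returns patterns filtered by type or by project, or all active patterns when neither is given.
// If both type and project are supplied, only the type filter is applied.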
func (s *Service) handleGetPatterns(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
// Parse query parameters
limit := DefaultPatternsLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
patternType := r.URL.Query().Get("type")
project := r.URL.Query().Get("project")
var patterns []*models.Pattern
var err error
if patternType != "" {
// Filter by type
patterns, err = store.GetPatternsByType(r.Context(), models.PatternType(patternType), limit)
} else if project != "" {
// Filter by project
patterns, err = store.GetPatternsByProject(r.Context(), project, limit)
} else {
// Get all active patterns
patterns, err = store.GetActivePatterns(r.Context(), limit)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, patterns)
}
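The `limit` query parameter is parsed the same way in more than one pattern handler: fall back to the default unless the value is a positive integer. A small helper could capture that rule. The sketch below is hypothetical; `parseLimit` is an illustrative name, and the upper bound is an added assumption, since the handlers shown here do not cap the limit.

```go
package main

import (
	"fmt"
	"strconv"
)

// parseLimit returns def unless raw is a positive integer.
// Assumption: values above max are clamped when max > 0 (not in the original handlers).
func parseLimit(raw string, def, max int) int {
	n, err := strconv.Atoi(raw)
	if err != nil || n <= 0 {
		return def
	}
	if max > 0 && n > max {
		return max
	}
	return n
}

func main() {
	for _, raw := range []string{"", "25", "-3", "abc", "9999"} {
		fmt.Printf("%q -> %d\n", raw, parseLimit(raw, 100, 500))
	}
}
```

Passing `max` as 0 disables the clamp and reproduces the behavior of the handlers as written.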
// handleGetPatternStats returns aggregate statistics about patterns.
func (s *Service) handleGetPatternStats(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
stats, err := store.GetPatternStats(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, stats)
}
// handleGetPatternByID returns a single pattern by ID.
func (s *Service) handleGetPatternByID(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid pattern ID", http.StatusBadRequest)
return
}
pattern, err := store.GetPatternByID(r.Context(), id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if pattern == nil {
http.Error(w, "pattern not found", http.StatusNotFound)
return
}
writeJSON(w, pattern)
}
// handleGetPatternInsight returns a formatted insight string for a pattern.
func (s *Service) handleGetPatternInsight(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
detector := s.patternDetector
s.initMu.RUnlock()
if detector == nil {
http.Error(w, "pattern detector not initialized", http.StatusServiceUnavailable)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid pattern ID", http.StatusBadRequest)
return
}
insight, err := detector.GetPatternInsight(r.Context(), id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]string{"insight": insight})
}
// handleDeletePattern deletes a pattern by ID.
func (s *Service) handleDeletePattern(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid pattern ID", http.StatusBadRequest)
return
}
if err := store.DeletePattern(r.Context(), id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]string{"status": "deleted"})
}
// handleDeprecatePattern marks a pattern as deprecated.
func (s *Service) handleDeprecatePattern(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid pattern ID", http.StatusBadRequest)
return
}
if err := store.MarkPatternDeprecated(r.Context(), id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]string{"status": "deprecated"})
}
// MergePatternsRequest is the request body for merging patterns.
type MergePatternsRequest struct {
SourceID int64 `json:"source_id"`
TargetID int64 `json:"target_id"`
}
// handleSearchPatterns performs full-text search on patterns.
func (s *Service) handleSearchPatterns(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
query := r.URL.Query().Get("q")
if query == "" {
http.Error(w, "query parameter 'q' is required", http.StatusBadRequest)
return
}
limit := DefaultPatternsLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
patterns, err := store.SearchPatternsFTS(r.Context(), query, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, patterns)
}
// handleGetPatternByName returns a pattern by its name.
func (s *Service) handleGetPatternByName(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
name := r.URL.Query().Get("name")
if name == "" {
http.Error(w, "query parameter 'name' is required", http.StatusBadRequest)
return
}
pattern, err := store.GetPatternByName(r.Context(), name)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if pattern == nil {
http.Error(w, "pattern not found", http.StatusNotFound)
return
}
writeJSON(w, pattern)
}
// handleMergePatterns merges a source pattern into a target pattern.
func (s *Service) handleMergePatterns(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
store := s.patternStore
s.initMu.RUnlock()
if store == nil {
http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
return
}
var req MergePatternsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.SourceID == 0 || req.TargetID == 0 {
http.Error(w, "source_id and target_id are required", http.StatusBadRequest)
return
}
if req.SourceID == req.TargetID {
http.Error(w, "source_id and target_id cannot be the same", http.StatusBadRequest)
return
}
if err := store.MergePatterns(r.Context(), req.SourceID, req.TargetID); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]string{"status": "merged"})
}
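Review note: handleSearchPatterns accepts any positive `limit`, while handleGetRelationsByType ignores values above 100. If an upper bound is also wanted for pattern search, a shared helper could look roughly like the sketch below. `parseBoundedLimit` and the bounds used in `main` are hypothetical and not part of this commit; the sketch clamps oversized values rather than falling back to the default, which is only one possible design choice.

```go
package main

import (
	"fmt"
	"strconv"
)

// parseBoundedLimit is a hypothetical helper (not part of this commit).
// It returns def when raw is empty, non-numeric, or not positive, and
// clamps values above upper down to upper.
func parseBoundedLimit(raw string, def, upper int) int {
	n, err := strconv.Atoi(raw)
	if err != nil || n <= 0 {
		return def
	}
	if n > upper {
		return upper
	}
	return n
}

func main() {
	// 20 and 100 are illustrative bounds only.
	for _, raw := range []string{"", "abc", "-3", "7", "5000"} {
		fmt.Printf("%q -> %d\n", raw, parseBoundedLimit(raw, 20, 100))
	}
}
```

In a handler this would be called as `parseBoundedLimit(r.URL.Query().Get("limit"), DefaultPatternsLimit, someUpperBound)`, where `someUpperBound` is again an assumed constant.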
+174
@@ -0,0 +1,174 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"net/http"
"strconv"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// DefaultRelationsLimit is the default number of relations to return.
const DefaultRelationsLimit = 50
// handleGetRelations returns relations for an observation.
func (s *Service) handleGetRelations(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
relations, err := s.relationStore.GetRelationsWithDetails(r.Context(), id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if relations == nil {
relations = []*models.RelationWithDetails{}
}
writeJSON(w, relations)
}
// handleGetRelationGraph returns the relation graph for an observation.
func (s *Service) handleGetRelationGraph(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
// Get depth parameter (default 2); non-numeric input or values outside 1-5 fall back to the default
depth := 2
if depthStr := r.URL.Query().Get("depth"); depthStr != "" {
if d, err := strconv.Atoi(depthStr); err == nil && d > 0 && d <= 5 {
depth = d
}
}
graph, err := s.relationStore.GetRelationGraph(r.Context(), id, depth)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, graph)
}
// handleGetRelatedObservations returns observations related to a given one.
func (s *Service) handleGetRelatedObservations(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
// Get minimum confidence parameter (default 0.4); unparsable input or values outside [0, 1] fall back to the default
minConfidence := 0.4
if confStr := r.URL.Query().Get("min_confidence"); confStr != "" {
if c, err := strconv.ParseFloat(confStr, 64); err == nil && c >= 0 && c <= 1 {
minConfidence = c
}
}
// Get related observation IDs
relatedIDs, err := s.relationStore.GetRelatedObservationIDs(r.Context(), id, minConfidence)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(relatedIDs) == 0 {
writeJSON(w, []*models.Observation{})
return
}
// Fetch full observations
observations, err := s.observationStore.GetObservationsByIDs(r.Context(), relatedIDs, "importance", DefaultRelationsLimit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleGetRelationsByType returns all relations of a specific type.
func (s *Service) handleGetRelationsByType(w http.ResponseWriter, r *http.Request) {
relType := chi.URLParam(r, "type")
// Validate relation type
validType := false
for _, t := range models.AllRelationTypes {
if string(t) == relType {
validType = true
break
}
}
if !validType {
http.Error(w, "invalid relation type", http.StatusBadRequest)
return
}
limit := DefaultRelationsLimit
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 100 {
limit = l
}
}
relations, err := s.relationStore.GetRelationsByType(r.Context(), models.RelationType(relType), limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if relations == nil {
relations = []*models.ObservationRelation{}
}
writeJSON(w, relations)
}
// handleGetRelationStats returns statistics about relations.
func (s *Service) handleGetRelationStats(w http.ResponseWriter, r *http.Request) {
// Get total relation count
totalCount, err := s.relationStore.GetTotalRelationCount(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Get high confidence relations count (confidence threshold 0.7; the count is capped at 1000 by the fetch limit)
highConfRelations, err := s.relationStore.GetHighConfidenceRelations(r.Context(), 0.7, 1000)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Count by relation type (each count is capped at 1000 by the fetch limit; types whose lookup fails are omitted)
typeCounts := make(map[string]int)
for _, t := range models.AllRelationTypes {
relations, err := s.relationStore.GetRelationsByType(r.Context(), t, 1000)
if err == nil {
typeCounts[string(t)] = len(relations)
}
}
writeJSON(w, map[string]interface{}{
"total_count": totalCount,
"high_confidence": len(highConfRelations),
"by_type": typeCounts,
"min_confidence_used": 0.4,
})
}
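Review note: unlike the pattern handlers, the relation handlers in this file read `s.relationStore` directly, without taking `initMu` or checking for nil. If these routes can be reached before asynchronous initialization completes (an assumption; the router setup is not shown here), the same guard could be applied. A minimal, self-contained sketch follows. The `service` and `relationStore` types, the `relations` helper, and the request path are placeholders for illustration, not the project's real types or routes.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
)

// relationStore is a placeholder standing in for *sqlite.RelationStore.
type relationStore struct{}

// service mirrors only the two fields this sketch needs.
type service struct {
	initMu        sync.RWMutex
	relationStore *relationStore
}

// relations returns a snapshot of the store pointer under the read lock.
// It is a hypothetical helper, not part of this commit.
func (s *service) relations() *relationStore {
	s.initMu.RLock()
	defer s.initMu.RUnlock()
	return s.relationStore
}

// handleStats answers 503 until the store has been set, then 200.
func (s *service) handleStats(w http.ResponseWriter, r *http.Request) {
	store := s.relations()
	if store == nil {
		http.Error(w, "relation store not initialized", http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// statusFor runs the handler once and reports the HTTP status code.
func statusFor(s *service) int {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/relations/stats", nil) // illustrative path
	s.handleStats(rec, req)
	return rec.Code
}

func main() {
	s := &service{}
	fmt.Println("before init:", statusFor(s))

	s.initMu.Lock()
	s.relationStore = &relationStore{}
	s.initMu.Unlock()
	fmt.Println("after init:", statusFor(s))
}
```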
+354
@@ -0,0 +1,354 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// FeedbackRequest represents a user feedback submission.
type FeedbackRequest struct {
Feedback int `json:"feedback"` // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
}
// handleObservationFeedback handles user feedback (thumbs up/down) for an observation.
// POST /api/observations/{id}/feedback
func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Request) {
// Parse observation ID
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
// Parse request body
var req FeedbackRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
// Validate feedback value
if req.Feedback < -1 || req.Feedback > 1 {
http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
return
}
// Get required components
s.initMu.RLock()
observationStore := s.observationStore
scoreCalculator := s.scoreCalculator
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
// Update feedback in database
if err := observationStore.UpdateObservationFeedback(r.Context(), id, req.Feedback); err != nil {
http.Error(w, "failed to update feedback", http.StatusInternalServerError)
return
}
// Recalculate score immediately if calculator is available
var newScore float64
if scoreCalculator != nil {
obs, err := observationStore.GetObservationByID(r.Context(), id)
if err == nil && obs != nil {
obs.UserFeedback = req.Feedback // Apply the new feedback
newScore = scoreCalculator.Calculate(obs, time.Now())
if err := observationStore.UpdateImportanceScore(r.Context(), id, newScore); err != nil {
// Log but don't fail - feedback was recorded
// Score will be updated on next recalculation cycle
}
}
}
// Broadcast update via SSE
s.sseBroadcaster.Broadcast(map[string]interface{}{
"type": "observation_feedback",
"id": id,
"feedback": req.Feedback,
"score": newScore,
})
writeJSON(w, map[string]interface{}{
"status": "ok",
"id": id,
"feedback": req.Feedback,
"score": newScore,
})
}
// handleGetScoringStats returns scoring statistics and configuration.
// GET /api/scoring/stats
func (s *Service) handleGetScoringStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
s.initMu.RLock()
observationStore := s.observationStore
recalculator := s.recalculator
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
// Get feedback statistics
feedbackStats, err := observationStore.GetObservationFeedbackStats(r.Context(), project)
if err != nil {
http.Error(w, "failed to get feedback stats", http.StatusInternalServerError)
return
}
response := map[string]interface{}{
"feedback": feedbackStats,
}
// Add recalculator stats if available
if recalculator != nil {
response["recalculator"] = recalculator.GetStats()
}
writeJSON(w, response)
}
// handleGetTopObservations returns the highest-scoring observations.
// GET /api/observations/top
func (s *Service) handleGetTopObservations(w http.ResponseWriter, r *http.Request) {
limit := parseIntParam(r, "limit", 10)
project := r.URL.Query().Get("project")
s.initMu.RLock()
observationStore := s.observationStore
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
observations, err := observationStore.GetTopScoringObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, "failed to get top observations", http.StatusInternalServerError)
return
}
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleGetMostRetrieved returns the most frequently retrieved observations.
// GET /api/observations/most-retrieved
func (s *Service) handleGetMostRetrieved(w http.ResponseWriter, r *http.Request) {
limit := parseIntParam(r, "limit", 10)
project := r.URL.Query().Get("project")
s.initMu.RLock()
observationStore := s.observationStore
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
observations, err := observationStore.GetMostRetrievedObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, "failed to get most retrieved observations", http.StatusInternalServerError)
return
}
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleExplainScore returns a breakdown of how an observation's score was calculated.
// GET /api/observations/{id}/score
func (s *Service) handleExplainScore(w http.ResponseWriter, r *http.Request) {
// Parse observation ID
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
s.initMu.RLock()
observationStore := s.observationStore
scoreCalculator := s.scoreCalculator
s.initMu.RUnlock()
if observationStore == nil || scoreCalculator == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
// Get observation
obs, err := observationStore.GetObservationByID(r.Context(), id)
if err != nil {
http.Error(w, "failed to get observation", http.StatusInternalServerError)
return
}
if obs == nil {
http.Error(w, "observation not found", http.StatusNotFound)
return
}
// Calculate score components
components := scoreCalculator.CalculateComponents(obs, time.Now())
writeJSON(w, map[string]interface{}{
"id": id,
"components": components,
"config": scoreCalculator.GetConfig(),
})
}
// handleUpdateConceptWeight updates a concept weight.
// PUT /api/scoring/concepts/{concept}
func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Request) {
concept := chi.URLParam(r, "concept")
if concept == "" {
http.Error(w, "concept is required", http.StatusBadRequest)
return
}
var req struct {
Weight float64 `json:"weight"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
// Validate weight
if req.Weight < 0 || req.Weight > 1 {
http.Error(w, "weight must be between 0 and 1", http.StatusBadRequest)
return
}
s.initMu.RLock()
observationStore := s.observationStore
recalculator := s.recalculator
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
// Update in database
if err := observationStore.UpdateConceptWeight(r.Context(), concept, req.Weight); err != nil {
http.Error(w, "failed to update concept weight", http.StatusInternalServerError)
return
}
// Refresh concept weights in recalculator
if recalculator != nil {
if err := recalculator.RefreshConceptWeights(r.Context()); err != nil {
// Log but don't fail - weight was saved
}
}
writeJSON(w, map[string]interface{}{
"status": "ok",
"concept": concept,
"weight": req.Weight,
})
}
// handleGetConceptWeights returns all concept weights.
// GET /api/scoring/concepts
func (s *Service) handleGetConceptWeights(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
observationStore := s.observationStore
s.initMu.RUnlock()
if observationStore == nil {
http.Error(w, "service not ready", http.StatusServiceUnavailable)
return
}
weights, err := observationStore.GetConceptWeights(r.Context())
if err != nil {
http.Error(w, "failed to get concept weights", http.StatusInternalServerError)
return
}
writeJSON(w, weights)
}
// handleTriggerRecalculation triggers an immediate score recalculation.
// POST /api/scoring/recalculate
func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Request) {
s.initMu.RLock()
recalculator := s.recalculator
s.initMu.RUnlock()
if recalculator == nil {
http.Error(w, "recalculator not available", http.StatusServiceUnavailable)
return
}
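// Run recalculation in background. net/http cancels the request context as
// soon as this handler returns, so detach from it before starting the goroutine.
ctx := context.WithoutCancel(r.Context())
go func() {
if err := recalculator.RecalculateNow(ctx); err != nil {
// Log error but don't block response
}
}()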
writeJSON(w, map[string]string{"status": "recalculation triggered"})
}
// parseIntParam parses an integer query parameter with a default value.
func parseIntParam(r *http.Request, name string, defaultVal int) int {
if val := r.URL.Query().Get(name); val != "" {
if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
return parsed
}
}
return defaultVal
}
// incrementRetrievalCounts increments retrieval counts for observations.
// Called after search results are returned to track popularity.
func (s *Service) incrementRetrievalCounts(ids []int64) {
if len(ids) == 0 {
return
}
s.initMu.RLock()
store := s.observationStore
s.initMu.RUnlock()
if store == nil {
return
}
// Increment in background to not block response
go func() {
// Create a new context with timeout for the background operation
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
// Log but don't fail - this is a background operation
}
}()
}
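Review note: background work started from a handler should not depend on the request context, because net/http cancels that context once the handler returns. incrementRetrievalCounts above avoids this by deriving its context from context.Background(). The sketch below illustrates the difference using `context.WithoutCancel` (available since Go 1.21); `detachedErrAfterCancel` is a hypothetical function written only for this demonstration.

```go
package main

import (
	"context"
	"fmt"
)

// detachedErrAfterCancel cancels a parent context (as net/http does with the
// request context when a handler returns) and reports the error state of the
// parent and of a context detached from it via context.WithoutCancel.
func detachedErrAfterCancel() (parentErr, detachedErr error) {
	parent, cancel := context.WithCancel(context.Background())
	detached := context.WithoutCancel(parent)
	cancel()
	return parent.Err(), detached.Err()
}

func main() {
	p, d := detachedErrAfterCancel()
	fmt.Println("parent:", p)
	fmt.Println("detached:", d)
}
```

A detached context keeps the parent's values but has no deadline, so long-running background work may still want its own timeout, as incrementRetrievalCounts does.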
+490 -1
@@ -15,6 +15,10 @@ import (
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/pattern"
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
@@ -64,6 +68,12 @@ type Service struct {
observationStore *sqlite.ObservationStore
summaryStore *sqlite.SummaryStore
promptStore *sqlite.PromptStore
conflictStore *sqlite.ConflictStore
patternStore *sqlite.PatternStore
relationStore *sqlite.RelationStore
// Pattern detection
patternDetector *pattern.Detector
// Domain services
sessionManager *session.Manager
@@ -75,6 +85,16 @@ type Service struct {
vectorClient *sqlitevec.Client
vectorSync *sqlitevec.Sync
// Cross-encoder reranking (for improved search relevance)
reranker *reranking.Service
// Query expansion (for improved search recall)
queryExpander *expansion.Expander
// Importance scoring
scoreCalculator *scoring.Calculator
recalculator *scoring.Recalculator
// HTTP server
router *chi.Mux
server *http.Server
@@ -177,6 +197,15 @@ func (s *Service) initializeAsync() {
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
conflictStore := sqlite.NewConflictStore(store)
patternStore := sqlite.NewPatternStore(store)
relationStore := sqlite.NewRelationStore(store)
// Enable conflict detection by linking stores
observationStore.SetConflictStore(conflictStore)
// Enable relation detection by linking stores
observationStore.SetRelationStore(relationStore)
// Create session manager
sessionManager := session.NewManager(sessionStore)
@@ -186,6 +215,8 @@ func (s *Service) initializeAsync() {
var vectorClient *sqlitevec.Client
var vectorSync *sqlitevec.Sync
var reranker *reranking.Service
emb, err := embedding.NewService()
if err != nil {
log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
@@ -200,8 +231,32 @@ func (s *Service) initializeAsync() {
} else {
vectorClient = client
vectorSync = sqlitevec.NewSync(client)
log.Info().
Str("model", embedSvc.Version()).
Msg("sqlite-vec vector search enabled")
}
// Create cross-encoder reranking service if enabled
if s.config.RerankingEnabled {
rerankCfg := reranking.DefaultConfig()
if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
rerankCfg.Alpha = s.config.RerankingAlpha
}
ranker, err := reranking.NewService(rerankCfg)
if err != nil {
log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed - reranking disabled")
} else {
reranker = ranker
log.Info().
Float64("alpha", rerankCfg.Alpha).
Msg("Cross-encoder reranking enabled")
}
}
// Create query expander for improved search recall
s.queryExpander = expansion.NewExpander(embedSvc)
log.Info().Msg("Query expansion enabled")
}
// Create SDK processor (optional - will be nil if Claude CLI not available)
@@ -225,11 +280,38 @@ func (s *Service) initializeAsync() {
s.observationStore = observationStore
s.summaryStore = summaryStore
s.promptStore = promptStore
s.conflictStore = conflictStore
s.patternStore = patternStore
s.relationStore = relationStore
s.sessionManager = sessionManager
s.processor = processor
s.embedSvc = embedSvc
s.vectorClient = vectorClient
s.vectorSync = vectorSync
s.reranker = reranker
s.initMu.Unlock()
// Initialize pattern detector
patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
// Set pattern sync callback if vector sync is available
if vectorSync != nil {
patternDetector.SetSyncFunc(func(p *models.Pattern) {
if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
}
})
// Set cleanup callback for pattern deletions
patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
}
})
}
s.initMu.Lock()
s.patternDetector = patternDetector
s.initMu.Unlock()
// Set vector sync callbacks on processor if both are available
@@ -238,6 +320,22 @@ func (s *Service) initializeAsync() {
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
}
// Trigger pattern detection for the new observation
if patternDetector != nil {
go func(observation *models.Observation) {
detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
} else if result.MatchedPattern != nil {
log.Debug().
Int64("pattern_id", result.MatchedPattern.ID).
Str("pattern_name", result.MatchedPattern.Name).
Bool("is_new", result.IsNewPattern).
Msg("Pattern matched for observation")
}
}(obs)
}
})
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
@@ -282,6 +380,37 @@ func (s *Service) initializeAsync() {
})
})
// Initialize importance scoring system
scoringConfig := models.DefaultScoringConfig()
// Load concept weights from database if available
if weights, err := observationStore.GetConceptWeights(s.ctx); err == nil && len(weights) > 0 {
scoringConfig.ConceptWeights = weights
log.Info().Int("count", len(weights)).Msg("Loaded concept weights from database")
}
scoreCalculator := scoring.NewCalculator(scoringConfig)
recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger)
s.initMu.Lock()
s.scoreCalculator = scoreCalculator
s.recalculator = recalculator
s.initMu.Unlock()
// Start background recalculator
s.wg.Add(1)
go func() {
defer s.wg.Done()
recalculator.Start(s.ctx)
}()
log.Info().Msg("Importance scoring system initialized")
// Start pattern detector background analysis
if patternDetector != nil {
patternDetector.Start()
log.Info().Msg("Pattern recognition engine started")
}
// Mark as ready
s.ready.Store(true)
log.Info().Msg("Async initialization complete - service ready")
@@ -294,6 +423,27 @@ func (s *Service) initializeAsync() {
// Start file watchers for auto-recreation on deletion
s.startWatchers()
// Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild
if vectorClient != nil && vectorSync != nil {
needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx)
if needsRebuild {
log.Info().
Str("reason", reason).
Str("model", vectorClient.ModelVersion()).
Msg("Vector rebuild required")
if reason == "empty" {
// Full rebuild - vectors table is empty
s.wg.Add(1)
go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync)
} else {
// Granular rebuild - only rebuild vectors with mismatched model versions
s.wg.Add(1)
go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync)
}
}
}
}
// startWatchers initializes and starts file watchers for database and config.
@@ -384,6 +534,15 @@ func (s *Service) reinitializeDatabase() {
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
conflictStore := sqlite.NewConflictStore(store)
patternStore := sqlite.NewPatternStore(store)
relationStore := sqlite.NewRelationStore(store)
// Enable conflict detection by linking stores
observationStore.SetConflictStore(conflictStore)
// Enable relation detection by linking stores
observationStore.SetRelationStore(relationStore)
// Create new session manager
sessionManager := session.NewManager(sessionStore)
@@ -393,6 +552,8 @@ func (s *Service) reinitializeDatabase() {
var vectorClient *sqlitevec.Client
var vectorSync *sqlitevec.Sync
var reranker *reranking.Service
emb, err := embedding.NewService()
if err != nil {
log.Warn().Err(err).Msg("Embedding service creation failed after reinit")
@@ -408,6 +569,34 @@ func (s *Service) reinitializeDatabase() {
vectorSync = sqlitevec.NewSync(client)
log.Info().Msg("sqlite-vec reconnected after reinit")
}
// Recreate cross-encoder reranking service if enabled
if s.config.RerankingEnabled {
rerankCfg := reranking.DefaultConfig()
if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
rerankCfg.Alpha = s.config.RerankingAlpha
}
ranker, err := reranking.NewService(rerankCfg)
if err != nil {
log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed after reinit")
} else {
reranker = ranker
log.Info().Msg("Cross-encoder reranking reconnected after reinit")
}
}
// Recreate query expander
s.queryExpander = expansion.NewExpander(embedSvc)
log.Info().Msg("Query expansion reconnected after reinit")
}
// Close old reranker if exists
s.initMu.RLock()
oldReranker := s.reranker
s.initMu.RUnlock()
if oldReranker != nil {
_ = oldReranker.Close()
}
// Recreate SDK processor with new stores
@@ -422,6 +611,30 @@ func (s *Service) reinitializeDatabase() {
})
}
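// Stop old pattern detector if it exists (read under initMu, like the old reranker above)
s.initMu.RLock()
oldDetector := s.patternDetector
s.initMu.RUnlock()
if oldDetector != nil {
oldDetector.Stop()
}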
// Create new pattern detector
patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
// Set pattern sync callback if vector sync is available
if vectorSync != nil {
patternDetector.SetSyncFunc(func(p *models.Pattern) {
if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
}
})
// Set cleanup callback for pattern deletions
patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
}
})
}
// Atomically swap all components
s.initMu.Lock()
s.store = store
@@ -429,20 +642,44 @@ func (s *Service) reinitializeDatabase() {
s.observationStore = observationStore
s.summaryStore = summaryStore
s.promptStore = promptStore
s.conflictStore = conflictStore
s.patternStore = patternStore
s.relationStore = relationStore
s.patternDetector = patternDetector
s.sessionManager = sessionManager
s.processor = processor
s.embedSvc = embedSvc
s.vectorClient = vectorClient
s.vectorSync = vectorSync
s.reranker = reranker
s.initError = nil
s.initMu.Unlock()
// Start pattern detector
patternDetector.Start()
// Set vector sync callbacks on processor if both are available
if processor != nil && vectorSync != nil {
processor.SetSyncObservationFunc(func(obs *models.Observation) {
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
}
// Trigger pattern detection for the new observation
if patternDetector != nil {
go func(observation *models.Observation) {
detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
} else if result != nil && result.MatchedPattern != nil {
log.Debug().
Int64("pattern_id", result.MatchedPattern.ID).
Str("pattern_name", result.MatchedPattern.Name).
Bool("is_new", result.IsNewPattern).
Msg("Pattern matched for observation")
}
}(obs)
}
})
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
@@ -565,6 +802,210 @@ func (s *Service) processStaleQueue() {
}
}
// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts.
// Called when the vectors table is empty (e.g., after migration 20 drops all vectors).
func (s *Service) rebuildAllVectors(
observationStore *sqlite.ObservationStore,
summaryStore *sqlite.SummaryStore,
promptStore *sqlite.PromptStore,
vectorSync *sqlitevec.Sync,
) {
defer s.wg.Done()
log.Info().Msg("Starting full vector rebuild...")
start := time.Now()
var totalSynced int
var syncErrors int
// Rebuild observations
observations, err := observationStore.GetAllObservations(s.ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild")
} else {
for _, obs := range observations {
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors")
}
// Rebuild summaries
summaries, err := summaryStore.GetAllSummaries(s.ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild")
} else {
for _, summary := range summaries {
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors")
}
// Rebuild user prompts
prompts, err := promptStore.GetAllPrompts(s.ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild")
} else {
for _, prompt := range prompts {
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors")
}
elapsed := time.Since(start)
log.Info().
Int("total_synced", totalSynced).
Int("errors", syncErrors).
Dur("elapsed", elapsed).
Msg("Full vector rebuild complete")
}
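The three loops in rebuildAllVectors share one shape: call a sync function per item and count successes and failures. A minimal sketch of that shape as a generic helper follows. The names `syncAll` and `demoSync` are hypothetical and not part of this commit, and the early exit on context cancellation is an assumption about desired behaviour, not something the code above does.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// syncAll calls fn for each item and counts successes and failures.
// It stops early once ctx is cancelled, leaving the remaining items unsynced.
func syncAll[T any](ctx context.Context, items []T, fn func(context.Context, T) error) (synced, failed int) {
	for _, item := range items {
		if ctx.Err() != nil {
			break
		}
		if err := fn(ctx, item); err != nil {
			failed++
			continue
		}
		synced++
	}
	return synced, failed
}

// demoSync runs syncAll with a stand-in sync function that rejects even IDs.
func demoSync() (synced, failed int) {
	ids := []int64{1, 2, 3, 4}
	syncFn := func(_ context.Context, id int64) error {
		if id%2 == 0 {
			return errors.New("sync failed")
		}
		return nil
	}
	return syncAll(context.Background(), ids, syncFn)
}

func main() {
	synced, failed := demoSync()
	fmt.Printf("synced=%d failed=%d\n", synced, failed)
}
```

With a helper like this, each per-type block could log the synced count instead of the number of rows fetched, which differ whenever a sync fails.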
// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions.
// This is more efficient than rebuilding all vectors when only some need updating.
func (s *Service) rebuildStaleVectors(
observationStore *sqlite.ObservationStore,
summaryStore *sqlite.SummaryStore,
promptStore *sqlite.PromptStore,
vectorClient *sqlitevec.Client,
vectorSync *sqlitevec.Sync,
) {
defer s.wg.Done()
log.Info().Msg("Starting granular vector rebuild for stale vectors...")
start := time.Now()
// Get all stale vectors
staleVectors, err := vectorClient.GetStaleVectors(s.ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to get stale vectors")
return
}
if len(staleVectors) == 0 {
log.Info().Msg("No stale vectors found")
return
}
log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild")
// Group stale vectors by doc_type and sqlite_id for efficient lookup
staleObsIDs := make(map[int64]bool)
staleSummaryIDs := make(map[int64]bool)
stalePromptIDs := make(map[int64]bool)
staleDocIDs := make([]string, 0, len(staleVectors))
for _, sv := range staleVectors {
staleDocIDs = append(staleDocIDs, sv.DocID)
switch sv.DocType {
case "observation":
staleObsIDs[sv.SQLiteID] = true
case "summary":
staleSummaryIDs[sv.SQLiteID] = true
case "prompt":
stalePromptIDs[sv.SQLiteID] = true
}
}
// Delete stale vectors before re-syncing
if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil {
log.Error().Err(err).Msg("Failed to delete stale vectors")
return
}
var totalSynced int
var syncErrors int
// Rebuild stale observations
if len(staleObsIDs) > 0 {
ids := make([]int64, 0, len(staleObsIDs))
for id := range staleObsIDs {
ids = append(ids, id)
}
observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch observations for rebuild")
} else {
for _, obs := range observations {
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors")
}
}
// Rebuild stale summaries
if len(staleSummaryIDs) > 0 {
ids := make([]int64, 0, len(staleSummaryIDs))
for id := range staleSummaryIDs {
ids = append(ids, id)
}
summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch summaries for rebuild")
} else {
for _, summary := range summaries {
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors")
}
}
// Rebuild stale prompts
if len(stalePromptIDs) > 0 {
ids := make([]int64, 0, len(stalePromptIDs))
for id := range stalePromptIDs {
ids = append(ids, id)
}
prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0)
if err != nil {
log.Error().Err(err).Msg("Failed to fetch prompts for rebuild")
} else {
for _, prompt := range prompts {
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
syncErrors++
} else {
totalSynced++
}
}
log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors")
}
}
elapsed := time.Since(start)
log.Info().
Int("total_synced", totalSynced).
Int("errors", syncErrors).
Dur("elapsed", elapsed).
Msg("Granular vector rebuild complete")
}
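The grouping step in rebuildStaleVectors collects distinct row IDs per document type before fetching. A minimal sketch of the same idea with a single map keyed by type is shown here. The `staleVector` type, the `groupStaleIDs` function and the doc ID strings are hypothetical stand-ins; only the field names DocID, DocType and SQLiteID come from the code above.

```go
package main

import (
	"fmt"
	"sort"
)

// staleVector is a hypothetical stand-in for the entries returned by GetStaleVectors.
type staleVector struct {
	DocID    string
	DocType  string
	SQLiteID int64
}

// groupStaleIDs returns the distinct SQLite IDs per document type, sorted ascending.
func groupStaleIDs(vectors []staleVector) map[string][]int64 {
	seen := make(map[string]map[int64]bool)
	grouped := make(map[string][]int64)
	for _, v := range vectors {
		if seen[v.DocType] == nil {
			seen[v.DocType] = make(map[int64]bool)
		}
		if seen[v.DocType][v.SQLiteID] {
			continue // more than one vector can reference the same row
		}
		seen[v.DocType][v.SQLiteID] = true
		grouped[v.DocType] = append(grouped[v.DocType], v.SQLiteID)
	}
	for _, ids := range grouped {
		sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
	}
	return grouped
}

func main() {
	grouped := groupStaleIDs([]staleVector{
		{DocID: "obs-2-a", DocType: "observation", SQLiteID: 2},
		{DocID: "obs-2-b", DocType: "observation", SQLiteID: 2},
		{DocID: "obs-1-a", DocType: "observation", SQLiteID: 1},
		{DocID: "sum-5-a", DocType: "summary", SQLiteID: 5},
	})
	fmt.Println(grouped["observation"], grouped["summary"])
}
```

Sorting is only there to make the output deterministic; the rebuild itself does not appear to depend on ID order.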
// verifyStaleObservation verifies a single stale observation in the background.
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
// Wait for service to be ready
@@ -667,11 +1108,42 @@ func (s *Service) setupRoutes() {
r.Get("/api/stats", s.handleGetStats)
r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
r.Get("/api/types", s.handleGetTypes)
r.Get("/api/models", s.handleGetModels)
// Observation scoring and feedback routes
r.Post("/api/observations/{id}/feedback", s.handleObservationFeedback)
r.Get("/api/observations/{id}/score", s.handleExplainScore)
r.Get("/api/observations/top", s.handleGetTopObservations)
r.Get("/api/observations/most-retrieved", s.handleGetMostRetrieved)
// Scoring configuration routes
r.Get("/api/scoring/stats", s.handleGetScoringStats)
r.Get("/api/scoring/concepts", s.handleGetConceptWeights)
r.Put("/api/scoring/concepts/{concept}", s.handleUpdateConceptWeight)
r.Post("/api/scoring/recalculate", s.handleTriggerRecalculation)
// Context injection
r.Get("/api/context/count", s.handleContextCount)
r.Get("/api/context/inject", s.handleContextInject)
r.Get("/api/context/search", s.handleSearchByPrompt)
// Pattern routes
r.Get("/api/patterns", s.handleGetPatterns)
r.Get("/api/patterns/stats", s.handleGetPatternStats)
r.Get("/api/patterns/search", s.handleSearchPatterns)
r.Get("/api/patterns/by-name", s.handleGetPatternByName)
r.Get("/api/patterns/{id}", s.handleGetPatternByID)
r.Get("/api/patterns/{id}/insight", s.handleGetPatternInsight)
r.Delete("/api/patterns/{id}", s.handleDeletePattern)
r.Post("/api/patterns/{id}/deprecate", s.handleDeprecatePattern)
r.Post("/api/patterns/merge", s.handleMergePatterns)
// Relation routes (knowledge graph)
r.Get("/api/relations/stats", s.handleGetRelationStats)
r.Get("/api/relations/type/{type}", s.handleGetRelationsByType)
r.Get("/api/observations/{id}/relations", s.handleGetRelations)
r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph)
r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations)
})
}
@@ -894,6 +1366,16 @@ func (s *Service) Shutdown(ctx context.Context) error {
_ = s.configWatcher.Stop()
}
// Stop background recalculator
if s.recalculator != nil {
s.recalculator.Stop()
}
// Stop pattern detector
if s.patternDetector != nil {
s.patternDetector.Stop()
}
// Shutdown all sessions
s.sessionManager.ShutdownAll(ctx)
@@ -904,6 +1386,13 @@ func (s *Service) Shutdown(ctx context.Context) error {
}
}
// Close reranking service
if s.reranker != nil {
if err := s.reranker.Close(); err != nil {
log.Error().Err(err).Msg("Reranking service close error")
}
}
// Close embedding service
if s.embedSvc != nil {
if err := s.embedSvc.Close(); err != nil {
+258
@@ -0,0 +1,258 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"regexp"
"strings"
"time"
)
// ConflictType represents the type of conflict between observations.
type ConflictType string
const (
// ConflictSuperseded means newer observation supersedes older one (same topic, updated info).
ConflictSuperseded ConflictType = "superseded"
// ConflictContradicts means observations contain contradictory information.
ConflictContradicts ConflictType = "contradicts"
// ConflictOutdatedPattern means an outdated pattern/practice was identified.
ConflictOutdatedPattern ConflictType = "outdated_pattern"
)
// ConflictResolution indicates which observation to prefer.
type ConflictResolution string
const (
// ResolutionPreferNewer means prefer the newer observation.
ResolutionPreferNewer ConflictResolution = "prefer_newer"
// ResolutionPreferOlder means prefer the older observation (rare).
ResolutionPreferOlder ConflictResolution = "prefer_older"
// ResolutionManual means manual review is needed.
ResolutionManual ConflictResolution = "manual"
)
// ObservationConflict tracks conflicting observations.
type ObservationConflict struct {
ID int64 `db:"id" json:"id"`
NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"`
OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"`
ConflictType ConflictType `db:"conflict_type" json:"conflict_type"`
Resolution ConflictResolution `db:"resolution" json:"resolution"`
Reason string `db:"reason" json:"reason"`
DetectedAt string `db:"detected_at" json:"detected_at"`
DetectedAtEpoch int64 `db:"detected_at_epoch" json:"detected_at_epoch"`
Resolved bool `db:"resolved" json:"resolved"`
ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"`
}
// ConflictDetectionResult contains the result of conflict detection.
type ConflictDetectionResult struct {
HasConflict bool
Type ConflictType
Resolution ConflictResolution
Reason string
OlderObsIDs []int64 // IDs of observations that conflict with the new one
}
// NewObservationConflict creates a new conflict record.
func NewObservationConflict(newerID, olderID int64, conflictType ConflictType, resolution ConflictResolution, reason string) *ObservationConflict {
now := time.Now()
return &ObservationConflict{
NewerObsID: newerID,
OlderObsID: olderID,
ConflictType: conflictType,
Resolution: resolution,
Reason: reason,
DetectedAt: now.Format(time.RFC3339),
DetectedAtEpoch: now.UnixMilli(),
Resolved: false,
}
}
// CorrectionPatterns contains regex patterns that indicate explicit corrections.
var CorrectionPatterns = []*regexp.Regexp{
regexp.MustCompile(`(?i)\bactually[,\s]+that\s+was\s+wrong\b`),
regexp.MustCompile(`(?i)\bactually[,\s]+that's\s+(wrong|incorrect|not\s+right)\b`),
regexp.MustCompile(`(?i)\bpreviously\s+(said|mentioned|noted)\s+.*\s+but\b`),
regexp.MustCompile(`(?i)\bcorrection:\s*`),
regexp.MustCompile(`(?i)\bignore\s+(the\s+)?(previous|earlier)\b`),
regexp.MustCompile(`(?i)\bdisregard\s+(the\s+)?(previous|earlier)\b`),
regexp.MustCompile(`(?i)\bwas\s+(wrong|incorrect|mistaken)\b`),
regexp.MustCompile(`(?i)\bturns\s+out\s+.*(wrong|incorrect|not\s+the\s+case)\b`),
regexp.MustCompile(`(?i)\b(supersedes|replaces|overrides)\s+(the\s+)?(previous|earlier|old)\b`),
regexp.MustCompile(`(?i)\b(don't|do\s+not)\s+use\s+.*\s+anymore\b`),
regexp.MustCompile(`(?i)\bno\s+longer\s+(valid|applicable|correct|recommended)\b`),
regexp.MustCompile(`(?i)\bdeprecated\s+(approach|method|pattern|way)\b`),
regexp.MustCompile(`(?i)\bshould\s+have\s+(been|used)\b.*instead\b`),
regexp.MustCompile(`(?i)\bbetter\s+(approach|way|method|solution)\s+is\b`),
}
// OpposingChangePatterns detects add/remove conflicts.
var OpposingChangePatterns = map[string]string{
"add": "remove",
"added": "removed",
"create": "delete",
"created": "deleted",
"enable": "disable",
"enabled": "disabled",
"include": "exclude",
"allow": "deny",
"permit": "block",
}
// DetectExplicitCorrection checks if text contains explicit correction language.
func DetectExplicitCorrection(text string) (bool, string) {
for _, pattern := range CorrectionPatterns {
if match := pattern.FindString(text); match != "" {
return true, "Explicit correction detected: " + match
}
}
return false, ""
}
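To illustrate how DetectExplicitCorrection behaves, here is a standalone sketch that copies two of the patterns from CorrectionPatterns. The function name `detectCorrection` and the sample sentences are made up for the demo. Note that the reported match keeps the casing of the input text, and that a phrase such as "was wrong" also fires on ordinary bugfix narratives, which may or may not be intended.

```go
package main

import (
	"fmt"
	"regexp"
)

// Two patterns copied from CorrectionPatterns for a standalone demo.
var patterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)\bno\s+longer\s+(valid|applicable|correct|recommended)\b`),
	regexp.MustCompile(`(?i)\bwas\s+(wrong|incorrect|mistaken)\b`),
}

// detectCorrection returns the first pattern match found in text, if any.
func detectCorrection(text string) (bool, string) {
	for _, p := range patterns {
		if m := p.FindString(text); m != "" {
			return true, "Explicit correction detected: " + m
		}
	}
	return false, ""
}

func main() {
	samples := []string{
		"This flag is No Longer Valid after the refactor", // matches, casing preserved
		"Fixed the handler: the loop bound was wrong",     // matches, although it describes a bugfix
		"Added retry logic to the client",                 // no match
	}
	for _, text := range samples {
		ok, reason := detectCorrection(text)
		fmt.Printf("%-5v %q\n", ok, reason)
	}
}
```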
// DetectOpposingFileChanges checks if two observations have opposing changes on the same file.
func DetectOpposingFileChanges(newer, older *Observation) (bool, string) {
// Check for overlapping modified files
newerFiles := make(map[string]bool)
for _, f := range newer.FilesModified {
newerFiles[f] = true
}
var overlappingFiles []string
for _, f := range older.FilesModified {
if newerFiles[f] {
overlappingFiles = append(overlappingFiles, f)
}
}
if len(overlappingFiles) == 0 {
return false, ""
}
// Check for opposing action words in titles/narratives.
// This is a substring match, so "add" also matches "added" and "address".
newerText := strings.ToLower(newer.Title.String + " " + newer.Narrative.String)
olderText := strings.ToLower(older.Title.String + " " + older.Narrative.String)
for action, opposite := range OpposingChangePatterns {
if (strings.Contains(newerText, action) && strings.Contains(olderText, opposite)) ||
(strings.Contains(newerText, opposite) && strings.Contains(olderText, action)) {
return true, "Opposing changes on files: " + strings.Join(overlappingFiles, ", ")
}
}
return false, ""
}
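DetectOpposingFileChanges matches action words with strings.Contains, so a word like "address" counts as "add". The sketch below compares that behaviour with a whole-word variant. It is an illustration under assumptions, not a proposed replacement: `containsOpposing`, `wholeWordOpposing` and `wordSet` are hypothetical names, and the pair map is a subset of OpposingChangePatterns.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// Subset of OpposingChangePatterns, copied for a standalone demo.
var opposingPairs = map[string]string{
	"add": "remove", "added": "removed",
	"enable": "disable", "enabled": "disabled",
}

// containsOpposing mirrors the substring check used in the commit.
func containsOpposing(newerText, olderText string) bool {
	n, o := strings.ToLower(newerText), strings.ToLower(olderText)
	for action, opposite := range opposingPairs {
		if (strings.Contains(n, action) && strings.Contains(o, opposite)) ||
			(strings.Contains(n, opposite) && strings.Contains(o, action)) {
			return true
		}
	}
	return false
}

// wordSet splits text on non-alphanumeric runes and returns the lowercase words.
func wordSet(text string) map[string]bool {
	words := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r)
	})
	set := make(map[string]bool, len(words))
	for _, w := range words {
		set[w] = true
	}
	return set
}

// wholeWordOpposing only reports a conflict when both words appear as whole words.
func wholeWordOpposing(newerText, olderText string) bool {
	n, o := wordSet(newerText), wordSet(olderText)
	for action, opposite := range opposingPairs {
		if (n[action] && o[opposite]) || (n[opposite] && o[action]) {
			return true
		}
	}
	return false
}

func main() {
	newer, older := "Update address parsing", "Removed old parser"
	fmt.Println(containsOpposing(newer, older), wholeWordOpposing(newer, older)) // prints "true false"
}
```

The trade-off is that whole-word matching misses inflections that are not listed in the map, such as "adds" or "adding", which the substring check catches for free.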
// DetectConceptTagMismatch reports whether two observations share at least one concept
// and at least one modified file, which is treated as a sign that the newer one updates the older.
func DetectConceptTagMismatch(newer, older *Observation) (bool, string) {
// Find overlapping concepts
newerConcepts := make(map[string]bool)
for _, c := range newer.Concepts {
newerConcepts[c] = true
}
var overlapping []string
for _, c := range older.Concepts {
if newerConcepts[c] {
overlapping = append(overlapping, c)
}
}
if len(overlapping) == 0 {
return false, ""
}
// Check if same file was modified and concepts overlap
// This suggests the newer observation may update the approach
newerFiles := make(map[string]bool)
for _, f := range newer.FilesModified {
newerFiles[f] = true
}
for _, f := range older.FilesModified {
if newerFiles[f] {
// Same file modified with same concepts - likely an update
return true, "Same concepts (" + strings.Join(overlapping, ", ") + ") with overlapping file changes"
}
}
return false, ""
}
// DetectConflict performs comprehensive conflict detection between a new observation
// and an existing one. Returns detection result.
func DetectConflict(newer, older *Observation) *ConflictDetectionResult {
result := &ConflictDetectionResult{
HasConflict: false,
}
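// 1. Check for explicit correction language in the newer observation.
// This check does not inspect older, so it fires for every observation newer is compared against.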
if newer.Narrative.Valid {
if isCorrection, reason := DetectExplicitCorrection(newer.Narrative.String); isCorrection {
result.HasConflict = true
result.Type = ConflictContradicts
result.Resolution = ResolutionPreferNewer
result.Reason = reason
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
return result
}
}
// Check title as well
if newer.Title.Valid {
if isCorrection, reason := DetectExplicitCorrection(newer.Title.String); isCorrection {
result.HasConflict = true
result.Type = ConflictContradicts
result.Resolution = ResolutionPreferNewer
result.Reason = reason
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
return result
}
}
// 2. Check for opposing file changes
if isOpposing, reason := DetectOpposingFileChanges(newer, older); isOpposing {
result.HasConflict = true
result.Type = ConflictSuperseded
result.Resolution = ResolutionPreferNewer
result.Reason = reason
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
return result
}
// 3. Check for concept tag mismatches with same files
if isMismatch, reason := DetectConceptTagMismatch(newer, older); isMismatch {
result.HasConflict = true
result.Type = ConflictSuperseded
result.Resolution = ResolutionPreferNewer
result.Reason = reason
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
return result
}
return result
}
// DetectConflictsWithExisting checks a new observation against a list of existing observations.
// Returns all detected conflicts.
func DetectConflictsWithExisting(newer *Observation, existing []*Observation) []*ConflictDetectionResult {
var results []*ConflictDetectionResult
for _, older := range existing {
// Skip self-comparison
if older.ID == newer.ID {
continue
}
// Skip cross-project pairs unless at least one of the observations is global-scoped
if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
continue
}
result := DetectConflict(newer, older)
if result.HasConflict {
results = append(results, result)
}
}
return results
}
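The project/scope filter in DetectConflictsWithExisting is easier to read as its negation. A minimal sketch follows, assuming string-valued scopes: the `scope` type, its constant values and the `obs` struct are hypothetical, while the boolean condition is taken from the loop above.

```go
package main

import "fmt"

type scope string

const (
	scopeProject scope = "project" // hypothetical value
	scopeGlobal  scope = "global"  // hypothetical value
)

// obs carries only the two fields the filter looks at.
type obs struct {
	Project string
	Scope   scope
}

// shouldCompare is the negation of the skip condition used in the loop.
func shouldCompare(newer, older obs) bool {
	skip := newer.Project != older.Project &&
		newer.Scope != scopeGlobal &&
		older.Scope != scopeGlobal
	return !skip
}

func main() {
	a := obs{Project: "project-a", Scope: scopeProject}
	b := obs{Project: "project-b", Scope: scopeProject}
	g := obs{Project: "project-b", Scope: scopeGlobal}

	fmt.Println(shouldCompare(a, a)) // same project
	fmt.Println(shouldCompare(a, b)) // different projects, both project-scoped
	fmt.Println(shouldCompare(a, g)) // different projects, one global
}
```

In other words, a pair is compared when the projects match or when either side is global; it is not required that both are global.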
+421
@@ -0,0 +1,421 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
// ConflictSuite is a test suite for conflict detection operations.
type ConflictSuite struct {
suite.Suite
}
func TestConflictSuite(t *testing.T) {
suite.Run(t, new(ConflictSuite))
}
// TestConflictTypeConstants tests conflict type constants.
func (s *ConflictSuite) TestConflictTypeConstants() {
s.Equal(ConflictType("superseded"), ConflictSuperseded)
s.Equal(ConflictType("contradicts"), ConflictContradicts)
s.Equal(ConflictType("outdated_pattern"), ConflictOutdatedPattern)
}
// TestResolutionConstants tests resolution constants.
func (s *ConflictSuite) TestResolutionConstants() {
s.Equal(ConflictResolution("prefer_newer"), ResolutionPreferNewer)
s.Equal(ConflictResolution("prefer_older"), ResolutionPreferOlder)
s.Equal(ConflictResolution("manual"), ResolutionManual)
}
// TestNewObservationConflict tests conflict creation.
func (s *ConflictSuite) TestNewObservationConflict() {
conflict := NewObservationConflict(2, 1, ConflictSuperseded, ResolutionPreferNewer, "Test reason")
s.Equal(int64(2), conflict.NewerObsID)
s.Equal(int64(1), conflict.OlderObsID)
s.Equal(ConflictSuperseded, conflict.ConflictType)
s.Equal(ResolutionPreferNewer, conflict.Resolution)
s.Equal("Test reason", conflict.Reason)
s.False(conflict.Resolved)
s.NotEmpty(conflict.DetectedAt)
s.Greater(conflict.DetectedAtEpoch, int64(0))
}
// TestDetectExplicitCorrection_TableDriven tests explicit correction detection.
func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
tests := []struct {
name string
text string
expectMatch bool
expectPattern string
}{
{
name: "actually that was wrong",
text: "Actually, that was wrong - we should use a different approach",
expectMatch: true,
expectPattern: "actually, that was wrong",
},
{
name: "correction prefix",
text: "Correction: the previous implementation had a bug",
expectMatch: true,
expectPattern: "correction:",
},
{
name: "ignore previous",
text: "Please ignore the previous recommendation",
expectMatch: true,
expectPattern: "ignore",
},
{
name: "disregard earlier",
text: "Disregard the earlier suggestion, it was flawed",
expectMatch: true,
expectPattern: "disregard",
},
{
name: "was wrong",
text: "The original approach was wrong",
expectMatch: true,
expectPattern: "was wrong",
},
{
name: "no longer valid",
text: "This method is no longer valid after the refactor",
expectMatch: true,
expectPattern: "no longer valid",
},
{
name: "deprecated approach",
text: "This is a deprecated approach that should not be used",
expectMatch: true,
expectPattern: "deprecated approach",
},
{
name: "better approach is",
text: "A better approach is to use the new API",
expectMatch: true,
expectPattern: "better approach is",
},
{
name: "normal text - no correction",
text: "This is a normal observation about the code",
expectMatch: false,
},
{
name: "empty text",
text: "",
expectMatch: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
found, reason := DetectExplicitCorrection(tt.text)
s.Equal(tt.expectMatch, found)
if tt.expectMatch {
s.Contains(reason, "Explicit correction detected")
}
})
}
}
// TestDetectOpposingFileChanges_TableDriven tests opposing file change detection.
func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
tests := []struct {
name string
newerObs *Observation
olderObs *Observation
expectConflict bool
}{
{
name: "add then remove - conflict",
newerObs: &Observation{
Title: sql.NullString{String: "Remove authentication middleware", Valid: true},
Narrative: sql.NullString{String: "Removed the auth middleware from handlers", Valid: true},
FilesModified: []string{"middleware.go", "handler.go"},
},
olderObs: &Observation{
Title: sql.NullString{String: "Add authentication middleware", Valid: true},
Narrative: sql.NullString{String: "Added auth middleware to secure endpoints", Valid: true},
FilesModified: []string{"middleware.go", "handler.go"},
},
expectConflict: true,
},
{
name: "enable then disable - conflict",
newerObs: &Observation{
Title: sql.NullString{String: "Disable caching feature", Valid: true},
Narrative: sql.NullString{String: "Disabled caching due to issues", Valid: true},
FilesModified: []string{"cache.go"},
},
olderObs: &Observation{
Title: sql.NullString{String: "Enable caching feature", Valid: true},
Narrative: sql.NullString{String: "Enabled caching for performance", Valid: true},
FilesModified: []string{"cache.go"},
},
expectConflict: true,
},
{
name: "different files - no conflict",
newerObs: &Observation{
Title: sql.NullString{String: "Remove old code", Valid: true},
Narrative: sql.NullString{String: "Removed deprecated functions", Valid: true},
FilesModified: []string{"old.go"},
},
olderObs: &Observation{
Title: sql.NullString{String: "Add new feature", Valid: true},
Narrative: sql.NullString{String: "Added new functions", Valid: true},
FilesModified: []string{"new.go"},
},
expectConflict: false,
},
{
name: "same files but no opposing keywords - no conflict",
newerObs: &Observation{
Title: sql.NullString{String: "Update handler logic", Valid: true},
Narrative: sql.NullString{String: "Updated the handler implementation", Valid: true},
FilesModified: []string{"handler.go"},
},
olderObs: &Observation{
Title: sql.NullString{String: "Fix handler bug", Valid: true},
Narrative: sql.NullString{String: "Fixed a bug in handler", Valid: true},
FilesModified: []string{"handler.go"},
},
expectConflict: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
found, _ := DetectOpposingFileChanges(tt.newerObs, tt.olderObs)
s.Equal(tt.expectConflict, found)
})
}
}
// TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection.
func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() {
tests := []struct {
name string
newerObs *Observation
olderObs *Observation
expectConflict bool
}{
{
name: "same concepts and files - conflict",
newerObs: &Observation{
Concepts: []string{"security", "authentication"},
FilesModified: []string{"auth.go"},
},
olderObs: &Observation{
Concepts: []string{"security", "authentication"},
FilesModified: []string{"auth.go"},
},
expectConflict: true,
},
{
name: "overlapping concepts different files - no conflict",
newerObs: &Observation{
Concepts: []string{"security"},
FilesModified: []string{"new_auth.go"},
},
olderObs: &Observation{
Concepts: []string{"security"},
FilesModified: []string{"old_auth.go"},
},
expectConflict: false,
},
{
name: "different concepts same files - no conflict",
newerObs: &Observation{
Concepts: []string{"performance"},
FilesModified: []string{"handler.go"},
},
olderObs: &Observation{
Concepts: []string{"testing"},
FilesModified: []string{"handler.go"},
},
expectConflict: false,
},
{
name: "no overlapping concepts or files - no conflict",
newerObs: &Observation{
Concepts: []string{"security"},
FilesModified: []string{"auth.go"},
},
olderObs: &Observation{
Concepts: []string{"testing"},
FilesModified: []string{"test.go"},
},
expectConflict: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
found, _ := DetectConceptTagMismatch(tt.newerObs, tt.olderObs)
s.Equal(tt.expectConflict, found)
})
}
}
// TestDetectConflict tests comprehensive conflict detection.
func (s *ConflictSuite) TestDetectConflict() {
// Test explicit correction takes precedence
newer := &Observation{
ID: 2,
Project: "test",
Narrative: sql.NullString{String: "Actually, that was wrong. We should use a different approach.", Valid: true},
}
older := &Observation{
ID: 1,
Project: "test",
}
result := DetectConflict(newer, older)
s.True(result.HasConflict)
s.Equal(ConflictContradicts, result.Type)
s.Equal(ResolutionPreferNewer, result.Resolution)
s.Contains(result.Reason, "Explicit correction")
}
// TestDetectConflictsWithExisting tests conflict detection against multiple observations.
func (s *ConflictSuite) TestDetectConflictsWithExisting() {
newer := &Observation{
ID: 3,
Project: "test",
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
Concepts: []string{"security"},
FilesModified: []string{"auth.go"},
}
existing := []*Observation{
{
ID: 1,
Project: "test",
Concepts: []string{"security"},
FilesModified: []string{"auth.go"},
},
{
ID: 2,
Project: "test",
Concepts: []string{"testing"},
FilesModified: []string{"test.go"},
},
{
ID: 3, // Same as newer - should be skipped
Project: "test",
},
}
results := DetectConflictsWithExisting(newer, existing)
// The correction language in newer's narrative flags every in-scope observation,
// so obs 1 and obs 2 are both reported; obs 3 is skipped because it has the same ID.
s.GreaterOrEqual(len(results), 1)
// At least one result should reference obs 1
foundObs1 := false
for _, r := range results {
for _, id := range r.OlderObsIDs {
if id == 1 {
foundObs1 = true
break
}
}
}
s.True(foundObs1)
}
// TestDetectConflictsWithExisting_DifferentProjects tests that different projects don't conflict.
func (s *ConflictSuite) TestDetectConflictsWithExisting_DifferentProjects() {
newer := &Observation{
ID: 2,
Project: "project-a",
Scope: ScopeProject,
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
}
existing := []*Observation{
{
ID: 1,
Project: "project-b",
Scope: ScopeProject,
},
}
results := DetectConflictsWithExisting(newer, existing)
s.Empty(results) // Different projects should not conflict
}
// TestDetectConflictsWithExisting_GlobalScope tests that global observations can conflict across projects.
func (s *ConflictSuite) TestDetectConflictsWithExisting_GlobalScope() {
newer := &Observation{
ID: 2,
Project: "project-a",
Scope: ScopeGlobal,
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
Concepts: []string{"security"},
FilesModified: []string{"auth.go"},
}
existing := []*Observation{
{
ID: 1,
Project: "project-b",
Scope: ScopeGlobal, // Global scope allows cross-project conflict detection
Concepts: []string{"security"},
FilesModified: []string{"auth.go"},
},
}
results := DetectConflictsWithExisting(newer, existing)
s.GreaterOrEqual(len(results), 1) // Global scope allows conflict detection
}
// TestCorrectionPatterns_Compiled ensures all patterns compile correctly.
func TestCorrectionPatterns_Compiled(t *testing.T) {
// This test verifies that the correction patterns are present and non-nil.
// An invalid pattern would already panic in regexp.MustCompile at package initialisation.
assert.NotEmpty(t, CorrectionPatterns)
for i, pattern := range CorrectionPatterns {
assert.NotNil(t, pattern, "Pattern %d should not be nil", i)
}
}
// TestOpposingChangePatterns tests the opposing change pattern map.
func TestOpposingChangePatterns(t *testing.T) {
assert.NotEmpty(t, OpposingChangePatterns)
assert.Equal(t, "remove", OpposingChangePatterns["add"])
assert.Equal(t, "removed", OpposingChangePatterns["added"])
assert.Equal(t, "delete", OpposingChangePatterns["create"])
assert.Equal(t, "disable", OpposingChangePatterns["enable"])
}
// TestObservationConflict_Fields tests field access.
func TestObservationConflict_Fields(t *testing.T) {
conflict := &ObservationConflict{
ID: 1,
NewerObsID: 10,
OlderObsID: 5,
ConflictType: ConflictSuperseded,
Resolution: ResolutionPreferNewer,
Reason: "Test reason",
DetectedAt: "2024-01-01T00:00:00Z",
DetectedAtEpoch: 1704067200000,
Resolved: false,
}
assert.Equal(t, int64(1), conflict.ID)
assert.Equal(t, int64(10), conflict.NewerObsID)
assert.Equal(t, int64(5), conflict.OlderObsID)
assert.Equal(t, ConflictSuperseded, conflict.ConflictType)
assert.Equal(t, ResolutionPreferNewer, conflict.Resolution)
assert.False(t, conflict.Resolved)
}
+51
@@ -139,6 +139,16 @@ type Observation struct {
CreatedAt string `db:"created_at" json:"created_at"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
IsStale bool `db:"-" json:"is_stale,omitempty"`
// Importance scoring fields
ImportanceScore float64 `db:"importance_score" json:"importance_score"`
UserFeedback int `db:"user_feedback" json:"user_feedback"`
RetrievalCount int `db:"retrieval_count" json:"retrieval_count"`
LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`
// Conflict detection fields
IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
}
// ParsedObservation represents an observation parsed from SDK response XML.
@@ -205,6 +215,16 @@ type ObservationJSON struct {
CreatedAt string `json:"created_at"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
IsStale bool `json:"is_stale,omitempty"`
// Importance scoring fields
ImportanceScore float64 `json:"importance_score"`
UserFeedback int `json:"user_feedback"`
RetrievalCount int `json:"retrieval_count"`
LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"`
ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"`
// Conflict detection fields
IsSuperseded bool `json:"is_superseded,omitempty"`
} }
// MarshalJSON implements json.Marshaler for Observation. // MarshalJSON implements json.Marshaler for Observation.
@@ -225,6 +245,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
CreatedAt: o.CreatedAt,
CreatedAtEpoch: o.CreatedAtEpoch,
IsStale: o.IsStale,
// Importance scoring fields
ImportanceScore: o.ImportanceScore,
UserFeedback: o.UserFeedback,
RetrievalCount: o.RetrievalCount,
// Conflict detection fields
IsSuperseded: o.IsSuperseded,
}
if o.Title.Valid {
j.Title = o.Title.String
@@ -238,6 +264,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
if o.PromptNumber.Valid {
j.PromptNumber = o.PromptNumber.Int64
}
if o.LastRetrievedAt.Valid {
j.LastRetrievedAt = o.LastRetrievedAt.Int64
}
if o.ScoreUpdatedAt.Valid {
j.ScoreUpdatedAt = o.ScoreUpdatedAt.Int64
}
return json.Marshal(j)
}
@@ -268,9 +300,28 @@ func NewObservation(sdkSessionID, project string, parsed *ParsedObservation, pro
DiscoveryTokens: discoveryTokens,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
// Importance scoring: new observations start with score 1.0
ImportanceScore: 1.0,
UserFeedback: 0,
RetrievalCount: 0,
}
}
// ToMap converts the observation to a map for JSON response building.
// This allows adding extra fields like similarity scores.
func (o *Observation) ToMap() map[string]interface{} {
// Marshal to JSON then unmarshal to map (uses MarshalJSON for proper conversion)
data, err := json.Marshal(o)
if err != nil {
return map[string]interface{}{"id": o.ID, "error": err.Error()}
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
return map[string]interface{}{"id": o.ID, "error": err.Error()}
}
return result
}
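The marshal-then-unmarshal round trip in ToMap can be illustrated with a minimal sketch. The `obs` struct and the `similarity` key below are hypothetical stand-ins chosen for illustration, not the project's real types or field names. One thing the sketch makes visible: after the round trip, every JSON number in the map is a `float64`, including integer IDs.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// obs is a hypothetical, trimmed-down stand-in for Observation.
type obs struct {
	ID              int64   `json:"id"`
	ImportanceScore float64 `json:"importance_score"`
}

// toMap mirrors the round trip used by ToMap: marshal to JSON,
// then decode the bytes into a generic map.
func toMap(o obs) (map[string]interface{}, error) {
	data, err := json.Marshal(o)
	if err != nil {
		return nil, err
	}
	var m map[string]interface{}
	if err := json.Unmarshal(data, &m); err != nil {
		return nil, err
	}
	return m, nil
}

func main() {
	m, err := toMap(obs{ID: 7, ImportanceScore: 1.0})
	if err != nil {
		panic(err)
	}
	// Extra field attached by the caller (assumed key name).
	m["similarity"] = 0.83
	fmt.Printf("%T %v\n", m["id"], m["similarity"])
}
```

Callers that need the ID as an integer again would have to convert it back from `float64` themselves.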
// CheckStaleness checks if an observation is stale based on current file mtimes.
// Returns true if any tracked file has been modified since the observation was created.
func (o *Observation) CheckStaleness(currentMtimes map[string]int64) bool {
+398
@@ -0,0 +1,398 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"time"
)
// PatternType represents the category of detected pattern.
type PatternType string
const (
// PatternTypeBug represents recurring bug patterns (e.g., "nil handling oversight").
PatternTypeBug PatternType = "bug"
// PatternTypeRefactor represents recurring refactoring approaches (e.g., "interface extraction").
PatternTypeRefactor PatternType = "refactor"
// PatternTypeArchitecture represents consistent architectural patterns.
PatternTypeArchitecture PatternType = "architecture"
// PatternTypeAntiPattern represents identified anti-patterns to avoid.
PatternTypeAntiPattern PatternType = "anti-pattern"
// PatternTypeBestPractice represents best practices that work consistently.
PatternTypeBestPractice PatternType = "best-practice"
)
// PatternStatus represents the lifecycle status of a pattern.
type PatternStatus string
const (
// PatternStatusActive means the pattern is actively being tracked and can be referenced.
PatternStatusActive PatternStatus = "active"
// PatternStatusDeprecated means the pattern has been superseded or is no longer relevant.
PatternStatusDeprecated PatternStatus = "deprecated"
// PatternStatusMerged means this pattern was merged into another pattern.
PatternStatusMerged PatternStatus = "merged"
)
// Pattern represents a recurring pattern detected across observations.
// This enables Claude to reference historical insights: "I've encountered this pattern 12 times."
type Pattern struct {
ID int64 `db:"id" json:"id"`
Name string `db:"name" json:"name"` // e.g., "State Management Anti-Pattern"
Type PatternType `db:"type" json:"type"` // bug, refactor, architecture, etc.
Description sql.NullString `db:"description" json:"description"` // Detailed description
Signature JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection
Recommendation sql.NullString `db:"recommendation" json:"recommendation"` // What works for this pattern
Frequency int `db:"frequency" json:"frequency"` // How many times encountered
Projects JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen
ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` // Source observation IDs
Status PatternStatus `db:"status" json:"status"` // active, deprecated, merged
MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"`
Confidence float64 `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0)
LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
CreatedAt string `db:"created_at" json:"created_at"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
// JSONInt64Array is a custom type for handling JSON int64 arrays in SQLite.
type JSONInt64Array []int64
// Scan implements sql.Scanner for JSONInt64Array.
func (j *JSONInt64Array) Scan(src interface{}) error {
if src == nil {
*j = nil
return nil
}
var data []byte
switch v := src.(type) {
case string:
data = []byte(v)
case []byte:
data = v
default:
// Unsupported source types are silently ignored and leave *j unchanged.
return nil
}
if len(data) == 0 {
*j = nil
return nil
}
return json.Unmarshal(data, j)
}
// Value implements driver.Valuer for JSONInt64Array.
func (j JSONInt64Array) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
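A round trip through Value and Scan might look like the following sketch. The `int64Array` type is a local copy made only so the example compiles on its own. Unlike the code above, this copy returns an error for unsupported source types, which is an assumption of the sketch and not the behavior of `JSONInt64Array`.

```go
package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

// int64Array is a local, illustrative copy of the JSONInt64Array idea.
type int64Array []int64

// Scan decodes a JSON array stored as TEXT or BLOB.
func (a *int64Array) Scan(src interface{}) error {
	var data []byte
	switch v := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		// Deviation from the original: report unsupported types.
		return fmt.Errorf("unsupported source type %T", src)
	}
	if len(data) == 0 {
		*a = nil
		return nil
	}
	return json.Unmarshal(data, a)
}

// Value encodes the slice as JSON bytes, or NULL for a nil slice.
func (a int64Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	return json.Marshal(a)
}

func main() {
	v, err := int64Array{1, 2, 3}.Value()
	if err != nil {
		panic(err)
	}
	var back int64Array
	if err := back.Scan(v); err != nil {
		panic(err)
	}
	fmt.Println(string(v.([]byte)), back)
}
```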
// PatternJSON is a JSON-friendly representation of Pattern.
type PatternJSON struct {
ID int64 `json:"id"`
Name string `json:"name"`
Type PatternType `json:"type"`
Description string `json:"description,omitempty"`
Signature []string `json:"signature,omitempty"`
Recommendation string `json:"recommendation,omitempty"`
Frequency int `json:"frequency"`
Projects []string `json:"projects,omitempty"`
ObservationIDs []int64 `json:"observation_ids,omitempty"`
Status PatternStatus `json:"status"`
MergedIntoID int64 `json:"merged_into_id,omitempty"`
Confidence float64 `json:"confidence"`
LastSeenAt string `json:"last_seen_at"`
LastSeenEpoch int64 `json:"last_seen_at_epoch"`
CreatedAt string `json:"created_at"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
}
// MarshalJSON implements json.Marshaler for Pattern.
func (p *Pattern) MarshalJSON() ([]byte, error) {
j := PatternJSON{
ID: p.ID,
Name: p.Name,
Type: p.Type,
Signature: p.Signature,
Frequency: p.Frequency,
Projects: p.Projects,
ObservationIDs: p.ObservationIDs,
Status: p.Status,
Confidence: p.Confidence,
LastSeenAt: p.LastSeenAt,
LastSeenEpoch: p.LastSeenEpoch,
CreatedAt: p.CreatedAt,
CreatedAtEpoch: p.CreatedAtEpoch,
}
if p.Description.Valid {
j.Description = p.Description.String
}
if p.Recommendation.Valid {
j.Recommendation = p.Recommendation.String
}
if p.MergedIntoID.Valid {
j.MergedIntoID = p.MergedIntoID.Int64
}
return json.Marshal(j)
}
// NewPattern creates a new pattern from detected data.
func NewPattern(name string, patternType PatternType, description string, signature []string, project string, observationID int64) *Pattern {
now := time.Now()
return &Pattern{
Name: name,
Type: patternType,
Description: sql.NullString{String: description, Valid: description != ""},
Signature: signature,
Frequency: 1,
Projects: []string{project},
ObservationIDs: []int64{observationID},
Status: PatternStatusActive,
Confidence: 0.5, // Initial confidence
LastSeenAt: now.Format(time.RFC3339),
LastSeenEpoch: now.UnixMilli(),
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
}
}
// AddOccurrence records a new occurrence of this pattern.
func (p *Pattern) AddOccurrence(project string, observationID int64) {
p.Frequency++
// Add project if not already tracked
found := false
for _, proj := range p.Projects {
if proj == project {
found = true
break
}
}
if !found {
p.Projects = append(p.Projects, project)
}
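// Add observation ID if not already tracked. A duplicate ID must not
// return early, otherwise confidence and last-seen would go stale.
alreadyTracked := false
for _, id := range p.ObservationIDs {
if id == observationID {
alreadyTracked = true
break
}
}
if !alreadyTracked {
p.ObservationIDs = append(p.ObservationIDs, observationID)
}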
// Update confidence based on frequency and cross-project occurrence
p.updateConfidence()
// Update last seen timestamp
now := time.Now()
p.LastSeenAt = now.Format(time.RFC3339)
p.LastSeenEpoch = now.UnixMilli()
}
// updateConfidence adjusts confidence based on frequency and cross-project validation.
func (p *Pattern) updateConfidence() {
// Base confidence from frequency (linear scaling, capped at 10 occurrences)
freqConfidence := 0.3 + (0.4 * (float64(min(p.Frequency, 10)) / 10.0))
// Cross-project bonus: patterns seen across multiple projects are more reliable
projectBonus := 0.0
if len(p.Projects) >= 2 {
projectBonus = 0.1
}
if len(p.Projects) >= 5 {
projectBonus = 0.2
}
p.Confidence = min(1.0, freqConfidence+projectBonus)
}
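As a worked example of the arithmetic in updateConfidence, the sketch below recomputes the score for a few frequency and project counts. `confidenceFor` is a hypothetical helper written for this illustration. It assumes the same constants as the method above: a 0.3 base, up to 0.4 from frequency capped at 10, and a 0.1 or 0.2 cross-project bonus.

```go
package main

import "fmt"

// confidenceFor reproduces the updateConfidence formula for a given
// occurrence count and number of distinct projects.
func confidenceFor(frequency, projects int) float64 {
	capped := frequency
	if capped > 10 {
		capped = 10
	}
	c := 0.3 + 0.4*(float64(capped)/10.0)
	switch {
	case projects >= 5:
		c += 0.2
	case projects >= 2:
		c += 0.1
	}
	if c > 1.0 {
		c = 1.0
	}
	return c
}

func main() {
	fmt.Printf("%.2f\n", confidenceFor(2, 1))  // two sightings, one project
	fmt.Printf("%.2f\n", confidenceFor(10, 1)) // frequency cap reached
	fmt.Printf("%.2f\n", confidenceFor(50, 5)) // cap plus full project bonus
}
```

With these constants the score tops out at 0.9, and a pattern seen twice in a single project scores 0.38, which is below the 0.5 that NewPattern assigns initially.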
// PatternMatch represents a match between an observation and a potential pattern.
type PatternMatch struct {
PatternID int64 `json:"pattern_id"`
Score float64 `json:"score"` // Match score (0.0-1.0)
MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
IsNew bool `json:"is_new"` // Whether this would create a new pattern
SuggestedName string `json:"suggested_name,omitempty"`
}
// PatternSignatureKeywords are common keywords used in pattern detection.
var PatternSignatureKeywords = map[PatternType][]string{
PatternTypeBug: {
"nil", "null", "undefined", "panic", "crash", "error handling",
"race condition", "deadlock", "memory leak", "overflow",
"off-by-one", "boundary", "timeout", "concurrency",
},
PatternTypeRefactor: {
"extract", "inline", "rename", "move", "split", "merge",
"interface", "abstraction", "decouple", "simplify",
"consolidate", "modularize", "encapsulate",
},
PatternTypeArchitecture: {
"layer", "service", "repository", "controller", "handler",
"middleware", "dependency injection", "factory", "singleton",
"observer", "strategy", "adapter", "facade", "builder",
},
PatternTypeAntiPattern: {
"god class", "spaghetti", "copy paste", "magic number",
"hardcoded", "circular dependency", "premature optimization",
"over-engineering", "feature envy", "data clump",
},
PatternTypeBestPractice: {
"test", "validation", "logging", "monitoring", "documentation",
"error handling", "retry", "timeout", "circuit breaker",
"graceful shutdown", "health check", "metrics",
},
}
// DetectPatternType analyzes concepts and content to determine pattern type.
func DetectPatternType(concepts []string, title, narrative string) PatternType {
// Check concepts first
for _, concept := range concepts {
switch concept {
case "anti-pattern":
return PatternTypeAntiPattern
case "best-practice":
return PatternTypeBestPractice
case "architecture":
return PatternTypeArchitecture
case "refactor":
return PatternTypeRefactor
}
}
// Check for bug-related patterns in content
content := title + " " + narrative
for _, keyword := range PatternSignatureKeywords[PatternTypeBug] {
if containsIgnoreCase(content, keyword) {
return PatternTypeBug
}
}
// Default to refactor for other patterns
return PatternTypeRefactor
}
// containsIgnoreCase checks if text contains substr (case-insensitive).
func containsIgnoreCase(text, substr string) bool {
textLower := toLower(text)
substrLower := toLower(substr)
return contains(textLower, substrLower)
}
// Simple implementations to avoid a strings package dependency in this file.
// toLower lowercases ASCII letters only; all other bytes are copied unchanged.
func toLower(s string) string {
b := make([]byte, len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
b[i] = c
}
return string(b)
}
// contains reports whether substr occurs within s (byte-wise comparison).
func contains(s, substr string) bool {
if len(substr) > len(s) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// ExtractSignature creates a signature from observation concepts and title.
// The narrative argument is accepted but not currently used.
func ExtractSignature(concepts []string, title, narrative string) []string {
var signature []string
// Add all concepts
signature = append(signature, concepts...)
// Extract key terms from title (simple word extraction)
for _, word := range splitWords(title) {
if len(word) > 3 && isSignificantWord(word) {
signature = append(signature, toLower(word))
}
}
return uniqueStrings(signature)
}
// splitWords is a simple word splitter.
func splitWords(s string) []string {
var words []string
word := ""
for _, r := range s {
if r == ' ' || r == '-' || r == '_' || r == '.' || r == ',' {
if word != "" {
words = append(words, word)
word = ""
}
} else {
word += string(r)
}
}
if word != "" {
words = append(words, word)
}
return words
}
// isSignificantWord filters out common stop words.
func isSignificantWord(word string) bool {
stopWords := map[string]bool{
"the": true, "and": true, "for": true, "with": true, "that": true,
"this": true, "from": true, "have": true, "not": true, "are": true,
"was": true, "but": true, "all": true, "can": true, "had": true,
"were": true, "been": true, "will": true, "when": true, "what": true,
}
return !stopWords[toLower(word)]
}
// uniqueStrings returns a slice with duplicate strings removed.
func uniqueStrings(s []string) []string {
seen := make(map[string]bool)
var result []string
for _, v := range s {
if !seen[v] {
seen[v] = true
result = append(result, v)
}
}
return result
}
// CalculateMatchScore computes similarity between two signatures.
func CalculateMatchScore(sig1, sig2 []string) float64 {
if len(sig1) == 0 || len(sig2) == 0 {
return 0.0
}
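set1 := make(map[string]bool)
for _, s := range sig1 {
set1[toLower(s)] = true
}
set2 := make(map[string]bool)
for _, s := range sig2 {
set2[toLower(s)] = true
}
// Count the intersection over deduplicated, lowercased entries so that
// repeated keywords cannot push the score above 1.0.
matches := 0
for s := range set2 {
if set1[s] {
matches++
}
}
// Jaccard similarity: |A ∩ B| / |A ∪ B|
unionSize := len(set1) + len(set2) - matches
if unionSize == 0 {
return 0.0
}
return float64(matches) / float64(unionSize)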
}
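To make the Jaccard computation concrete, here is a small set-based sketch. `jaccard` is a hypothetical helper. It uses `strings.ToLower` where the file has its own `toLower`, and it deduplicates both inputs before counting, so repeated keywords do not inflate the result.

```go
package main

import (
	"fmt"
	"strings"
)

// jaccard returns |A ∩ B| / |A ∪ B| over lowercased, deduplicated entries.
func jaccard(a, b []string) float64 {
	setA := make(map[string]bool, len(a))
	for _, s := range a {
		setA[strings.ToLower(s)] = true
	}
	setB := make(map[string]bool, len(b))
	for _, s := range b {
		setB[strings.ToLower(s)] = true
	}
	if len(setA) == 0 || len(setB) == 0 {
		return 0
	}
	inter := 0
	for s := range setB {
		if setA[s] {
			inter++
		}
	}
	union := len(setA) + len(setB) - inter
	return float64(inter) / float64(union)
}

func main() {
	// Two of four distinct keywords are shared in both cases.
	fmt.Printf("%.2f\n", jaccard([]string{"a", "b", "c"}, []string{"a", "b", "d"}))
	fmt.Printf("%.2f\n", jaccard([]string{"a", "b"}, []string{"A", "b", "c", "d"}))
}
```

Both calls share two keywords out of four distinct ones, so each scores 0.50.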
+357
@@ -0,0 +1,357 @@
package models
import (
"database/sql"
"encoding/json"
"testing"
"time"
)
func TestNewPattern(t *testing.T) {
pattern := NewPattern(
"Test Pattern",
PatternTypeBug,
"A test pattern description",
[]string{"nil", "error", "handling"},
"test-project",
123,
)
if pattern.Name != "Test Pattern" {
t.Errorf("Expected name 'Test Pattern', got '%s'", pattern.Name)
}
if pattern.Type != PatternTypeBug {
t.Errorf("Expected type PatternTypeBug, got '%s'", pattern.Type)
}
if !pattern.Description.Valid || pattern.Description.String != "A test pattern description" {
t.Errorf("Description not set correctly")
}
if len(pattern.Signature) != 3 {
t.Errorf("Expected 3 signature elements, got %d", len(pattern.Signature))
}
if pattern.Frequency != 1 {
t.Errorf("Expected frequency 1, got %d", pattern.Frequency)
}
if len(pattern.Projects) != 1 || pattern.Projects[0] != "test-project" {
t.Errorf("Projects not set correctly")
}
if len(pattern.ObservationIDs) != 1 || pattern.ObservationIDs[0] != 123 {
t.Errorf("ObservationIDs not set correctly")
}
if pattern.Status != PatternStatusActive {
t.Errorf("Expected status Active, got '%s'", pattern.Status)
}
if pattern.Confidence != 0.5 {
t.Errorf("Expected initial confidence 0.5, got %f", pattern.Confidence)
}
}
func TestPattern_AddOccurrence(t *testing.T) {
pattern := NewPattern("Test", PatternTypeBug, "desc", []string{"test"}, "project1", 1)
// Add same project occurrence
pattern.AddOccurrence("project1", 2)
if pattern.Frequency != 2 {
t.Errorf("Expected frequency 2, got %d", pattern.Frequency)
}
if len(pattern.Projects) != 1 {
t.Errorf("Expected 1 project (no duplicates), got %d", len(pattern.Projects))
}
// Add different project occurrence
pattern.AddOccurrence("project2", 3)
if pattern.Frequency != 3 {
t.Errorf("Expected frequency 3, got %d", pattern.Frequency)
}
if len(pattern.Projects) != 2 {
t.Errorf("Expected 2 projects, got %d", len(pattern.Projects))
}
// Add duplicate observation ID - should not duplicate
pattern.AddOccurrence("project2", 3)
if len(pattern.ObservationIDs) != 3 {
t.Errorf("Expected 3 observation IDs (no duplicate), got %d", len(pattern.ObservationIDs))
}
// Check confidence increased
if pattern.Confidence <= 0.5 {
t.Errorf("Expected confidence to increase above 0.5, got %f", pattern.Confidence)
}
}
func TestPattern_ConfidenceCalculation(t *testing.T) {
tests := []struct {
name string
frequency int
projectCount int
minConfidence float64
maxConfidence float64
}{
{"low_frequency", 2, 1, 0.3, 0.5},
{"high_frequency", 10, 1, 0.6, 0.8},
{"multi_project", 3, 3, 0.4, 0.7},
{"high_freq_multi_proj", 10, 5, 0.7, 1.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern := NewPattern("Test", PatternTypeBug, "", []string{}, "proj1", 1)
// Simulate occurrences
for i := 1; i < tt.frequency; i++ {
projIdx := i % tt.projectCount
if projIdx == 0 {
projIdx = 1
}
pattern.AddOccurrence("proj"+string(rune('0'+projIdx)), int64(i+1))
}
if pattern.Confidence < tt.minConfidence || pattern.Confidence > tt.maxConfidence {
t.Errorf("Expected confidence between %f and %f, got %f",
tt.minConfidence, tt.maxConfidence, pattern.Confidence)
}
})
}
}
func TestPatternType_Detection(t *testing.T) {
tests := []struct {
concepts []string
title string
narrative string
expected PatternType
}{
{[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern},
{[]string{"best-practice"}, "", "", PatternTypeBestPractice},
{[]string{"architecture"}, "", "", PatternTypeArchitecture},
{[]string{"refactor"}, "", "", PatternTypeRefactor},
{[]string{}, "nil pointer bug", "", PatternTypeBug},
{[]string{}, "Deadlock in concurrent code", "", PatternTypeBug},
{[]string{}, "Extract interface", "", PatternTypeRefactor},
}
for _, tt := range tests {
t.Run(tt.title+"_"+tt.expected.String(), func(t *testing.T) {
result := DetectPatternType(tt.concepts, tt.title, tt.narrative)
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
// String returns the pattern type as a plain string (used for subtest names above).
func (pt PatternType) String() string {
return string(pt)
}
func TestExtractSignature(t *testing.T) {
concepts := []string{"error-handling", "security"}
title := "Nil Pointer Validation Pattern"
narrative := "Always validate before dereferencing"
signature := ExtractSignature(concepts, title, narrative)
// Should contain concepts
found := false
for _, s := range signature {
if s == "error-handling" {
found = true
break
}
}
if !found {
t.Errorf("Expected signature to contain concepts, got %v", signature)
}
// Should contain significant words from title
found = false
for _, s := range signature {
if s == "validation" || s == "pattern" || s == "pointer" {
found = true
break
}
}
if !found {
t.Errorf("Expected signature to contain title keywords, got %v", signature)
}
}
func TestCalculateMatchScore(t *testing.T) {
tests := []struct {
name string
sig1 []string
sig2 []string
minScore float64
maxScore float64
}{
{"identical", []string{"a", "b", "c"}, []string{"a", "b", "c"}, 1.0, 1.0},
{"partial", []string{"a", "b", "c"}, []string{"a", "b", "d"}, 0.4, 0.6},
{"no_match", []string{"a", "b", "c"}, []string{"x", "y", "z"}, 0.0, 0.0},
{"empty", []string{}, []string{"a", "b"}, 0.0, 0.0},
{"subset", []string{"a", "b"}, []string{"a", "b", "c", "d"}, 0.4, 0.6},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
score := CalculateMatchScore(tt.sig1, tt.sig2)
if score < tt.minScore || score > tt.maxScore {
t.Errorf("Expected score between %f and %f, got %f",
tt.minScore, tt.maxScore, score)
}
})
}
}
func TestPattern_MarshalJSON(t *testing.T) {
pattern := &Pattern{
ID: 1,
Name: "Test Pattern",
Type: PatternTypeBug,
Description: sql.NullString{String: "A description", Valid: true},
Signature: []string{"a", "b"},
Recommendation: sql.NullString{String: "Do this", Valid: true},
Frequency: 5,
Projects: []string{"proj1", "proj2"},
ObservationIDs: []int64{1, 2, 3},
Status: PatternStatusActive,
MergedIntoID: sql.NullInt64{Int64: 0, Valid: false},
Confidence: 0.8,
LastSeenAt: time.Now().Format(time.RFC3339),
LastSeenEpoch: time.Now().UnixMilli(),
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().UnixMilli(),
}
data, err := json.Marshal(pattern)
if err != nil {
t.Fatalf("Failed to marshal pattern: %v", err)
}
var result PatternJSON
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("Failed to unmarshal pattern: %v", err)
}
if result.Name != pattern.Name {
t.Errorf("Expected name %s, got %s", pattern.Name, result.Name)
}
if result.Description != pattern.Description.String {
t.Errorf("Expected description %s, got %s", pattern.Description.String, result.Description)
}
if result.Frequency != pattern.Frequency {
t.Errorf("Expected frequency %d, got %d", pattern.Frequency, result.Frequency)
}
if result.MergedIntoID != 0 {
t.Errorf("Expected merged_into_id 0 for invalid NullInt64, got %d", result.MergedIntoID)
}
}
func TestJSONInt64Array_Scan(t *testing.T) {
tests := []struct {
name string
input interface{}
expected JSONInt64Array
wantErr bool
}{
{"string_array", "[1, 2, 3]", JSONInt64Array{1, 2, 3}, false},
{"bytes_array", []byte("[4, 5, 6]"), JSONInt64Array{4, 5, 6}, false},
{"nil", nil, nil, false},
{"empty_string", "", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var arr JSONInt64Array
err := arr.Scan(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(arr) != len(tt.expected) {
t.Errorf("Expected length %d, got %d", len(tt.expected), len(arr))
}
})
}
}
func TestJSONInt64Array_Value(t *testing.T) {
arr := JSONInt64Array{1, 2, 3}
val, err := arr.Value()
if err != nil {
t.Fatalf("Value() error = %v", err)
}
bytes, ok := val.([]byte)
if !ok {
t.Fatalf("Expected []byte, got %T", val)
}
var result []int64
if err := json.Unmarshal(bytes, &result); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if len(result) != 3 || result[0] != 1 || result[1] != 2 || result[2] != 3 {
t.Errorf("Expected [1, 2, 3], got %v", result)
}
}
func TestPatternSignatureKeywords(t *testing.T) {
// Verify keywords exist for each type
types := []PatternType{
PatternTypeBug,
PatternTypeRefactor,
PatternTypeArchitecture,
PatternTypeAntiPattern,
PatternTypeBestPractice,
}
for _, pt := range types {
keywords := PatternSignatureKeywords[pt]
if len(keywords) == 0 {
t.Errorf("No keywords defined for pattern type %s", pt)
}
}
}
func TestUniqueStrings(t *testing.T) {
tests := []struct {
input []string
expected int
}{
{[]string{"a", "b", "c"}, 3},
{[]string{"a", "a", "b"}, 2},
{[]string{"a", "a", "a"}, 1},
{[]string{}, 0},
}
for _, tt := range tests {
result := uniqueStrings(tt.input)
if len(result) != tt.expected {
t.Errorf("uniqueStrings(%v) = %v (len=%d), expected len=%d",
tt.input, result, len(result), tt.expected)
}
}
}
func TestContainsIgnoreCase(t *testing.T) {
tests := []struct {
text string
substr string
expected bool
}{
{"Hello World", "hello", true},
{"Hello World", "WORLD", true},
{"Hello World", "xyz", false},
{"", "a", false},
{"a", "", true},
}
for _, tt := range tests {
result := containsIgnoreCase(tt.text, tt.substr)
if result != tt.expected {
t.Errorf("containsIgnoreCase(%q, %q) = %v, expected %v",
tt.text, tt.substr, result, tt.expected)
}
}
}
+489
@@ -0,0 +1,489 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"strings"
"time"
)
// RelationType represents the type of relationship between observations.
type RelationType string
const (
// RelationCauses means source observation caused target observation.
// Example: "This architectural decision caused this bug"
RelationCauses RelationType = "causes"
// RelationFixes means source observation fixes target observation.
// Example: "This bugfix addresses that discovered issue"
RelationFixes RelationType = "fixes"
// RelationSupersedes means source observation supersedes target observation.
// Example: "This new approach replaces the old workaround"
RelationSupersedes RelationType = "supersedes"
// RelationDependsOn means source observation depends on target observation.
// Example: "This feature relies on that architectural decision"
RelationDependsOn RelationType = "depends_on"
// RelationRelatesTo means observations are related but no causal relationship.
// Example: "Both deal with authentication"
RelationRelatesTo RelationType = "relates_to"
// RelationEvolvesFrom means source observation evolved from target observation.
// Example: "This refined pattern evolved from that initial discovery"
RelationEvolvesFrom RelationType = "evolves_from"
)
// AllRelationTypes is the list of all valid relation types.
var AllRelationTypes = []RelationType{
RelationCauses,
RelationFixes,
RelationSupersedes,
RelationDependsOn,
RelationRelatesTo,
RelationEvolvesFrom,
}
// RelationDetectionSource indicates how a relationship was detected.
type RelationDetectionSource string
const (
// DetectionSourceFileOverlap means relationship was detected via shared file references.
DetectionSourceFileOverlap RelationDetectionSource = "file_overlap"
// DetectionSourceEmbeddingSimilarity means relationship was detected via vector similarity.
DetectionSourceEmbeddingSimilarity RelationDetectionSource = "embedding_similarity"
// DetectionSourceTemporalProximity means relationship was detected via close timestamps.
DetectionSourceTemporalProximity RelationDetectionSource = "temporal_proximity"
// DetectionSourceNarrativeMention means relationship was detected via explicit mentions.
DetectionSourceNarrativeMention RelationDetectionSource = "narrative_mention"
// DetectionSourceConceptOverlap means relationship was detected via shared concepts.
DetectionSourceConceptOverlap RelationDetectionSource = "concept_overlap"
// DetectionSourceTypeProgression means relationship was detected via type progression pattern.
DetectionSourceTypeProgression RelationDetectionSource = "type_progression"
)
// ObservationRelation represents a directed relationship between two observations.
type ObservationRelation struct {
ID int64 `db:"id" json:"id"`
SourceID int64 `db:"source_id" json:"source_id"`
TargetID int64 `db:"target_id" json:"target_id"`
RelationType RelationType `db:"relation_type" json:"relation_type"`
Confidence float64 `db:"confidence" json:"confidence"`
DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"`
Reason string `db:"reason" json:"reason,omitempty"`
CreatedAt string `db:"created_at" json:"created_at"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
// NewObservationRelation creates a new observation relation.
func NewObservationRelation(sourceID, targetID int64, relType RelationType, confidence float64, source RelationDetectionSource, reason string) *ObservationRelation {
now := time.Now()
return &ObservationRelation{
SourceID: sourceID,
TargetID: targetID,
RelationType: relType,
Confidence: confidence,
DetectionSource: source,
Reason: reason,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
}
}
// RelationDetectionResult contains the result of relation detection.
type RelationDetectionResult struct {
SourceID int64
TargetID int64
RelationType RelationType
Confidence float64
DetectionSource RelationDetectionSource
Reason string
}
// DetectFileOverlapRelation checks if observations share file references and determines relationship type.
func DetectFileOverlapRelation(newer, older *Observation) *RelationDetectionResult {
// Check for overlapping modified files
newerModified := make(map[string]bool)
for _, f := range newer.FilesModified {
newerModified[f] = true
}
olderModified := make(map[string]bool)
for _, f := range older.FilesModified {
olderModified[f] = true
}
// Files modified by both
var sharedModified []string
for f := range newerModified {
if olderModified[f] {
sharedModified = append(sharedModified, f)
}
}
// Files that newer reads which older modified
var newerReadsOlderModified []string
for _, f := range newer.FilesRead {
if olderModified[f] {
newerReadsOlderModified = append(newerReadsOlderModified, f)
}
}
// Calculate overlap score
overlap := len(sharedModified) + len(newerReadsOlderModified)
if overlap == 0 {
return nil
}
// Determine relationship type based on observation types and file overlap
relType := RelationRelatesTo
confidence := 0.5 + float64(overlap)*0.1 // Base 0.5, +0.1 per overlapping file
// Type-based relationship inference
switch {
case newer.Type == ObsTypeBugfix && (older.Type == ObsTypeDecision || older.Type == ObsTypeFeature):
relType = RelationFixes
confidence += 0.2
case newer.Type == ObsTypeRefactor && older.Type == ObsTypeDiscovery:
relType = RelationEvolvesFrom
confidence += 0.15
case newer.Type == older.Type && len(sharedModified) > 0:
relType = RelationSupersedes
confidence += 0.1
case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
relType = RelationDependsOn
confidence += 0.15
}
if confidence > 1.0 {
confidence = 1.0
}
reason := buildFileOverlapReason(sharedModified, newerReadsOlderModified)
return &RelationDetectionResult{
SourceID: newer.ID,
TargetID: older.ID,
RelationType: relType,
Confidence: confidence,
DetectionSource: DetectionSourceFileOverlap,
Reason: reason,
}
}
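The overlap counting and confidence arithmetic above can be sketched in isolation. The `obs` struct, `overlapCount`, and `overlapConfidence` are hypothetical names used only for this illustration. The constants (0.5 base, 0.1 per overlapping file, a type-dependent bonus, cap at 1.0) are taken from DetectFileOverlapRelation.

```go
package main

import "fmt"

// obs is a hypothetical, reduced stand-in for Observation.
type obs struct {
	FilesModified []string
	FilesRead     []string
}

// overlapCount counts files modified by both observations plus files
// the newer observation reads that the older one modified.
func overlapCount(newer, older obs) int {
	olderModified := make(map[string]bool)
	for _, f := range older.FilesModified {
		olderModified[f] = true
	}
	n := 0
	counted := make(map[string]bool)
	for _, f := range newer.FilesModified {
		if olderModified[f] && !counted[f] {
			counted[f] = true
			n++
		}
	}
	for _, f := range newer.FilesRead {
		if olderModified[f] {
			n++
		}
	}
	return n
}

// overlapConfidence applies the base, per-file, and type bonus terms.
func overlapConfidence(n int, typeBonus float64) float64 {
	c := 0.5 + float64(n)*0.1 + typeBonus
	if c > 1.0 {
		c = 1.0
	}
	return c
}

func main() {
	newer := obs{FilesModified: []string{"a.go"}, FilesRead: []string{"b.go"}}
	older := obs{FilesModified: []string{"a.go", "b.go"}}
	n := overlapCount(newer, older)
	fmt.Printf("overlap=%d plain=%.2f bugfix=%.2f\n",
		n, overlapConfidence(n, 0), overlapConfidence(n, 0.2))
}
```

A file that the newer observation both modifies and reads, and that the older one modified, is counted in both lists, so it contributes twice to the overlap.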
// buildFileOverlapReason creates a human-readable reason for file overlap relation.
func buildFileOverlapReason(shared, readsModified []string) string {
parts := []string{}
if len(shared) > 0 {
parts = append(parts, "both modified: "+strings.Join(truncateList(shared, 3), ", "))
}
if len(readsModified) > 0 {
parts = append(parts, "reads files modified by older: "+strings.Join(truncateList(readsModified, 3), ", "))
}
return strings.Join(parts, "; ")
}
// DetectConceptOverlapRelation checks if observations share concepts.
func DetectConceptOverlapRelation(newer, older *Observation) *RelationDetectionResult {
newerConcepts := make(map[string]bool)
for _, c := range newer.Concepts {
newerConcepts[c] = true
}
var shared []string
for _, c := range older.Concepts {
if newerConcepts[c] {
shared = append(shared, c)
}
}
if len(shared) == 0 {
return nil
}
// Calculate confidence based on overlap ratio
totalUniqueConcepts := len(newerConcepts)
for _, c := range older.Concepts {
if !newerConcepts[c] {
totalUniqueConcepts++
}
}
overlapRatio := float64(len(shared)) / float64(totalUniqueConcepts)
confidence := 0.3 + overlapRatio*0.5 // Base 0.3, scale with overlap
// Boost for important concepts
for _, c := range shared {
if isHighValueConcept(c) {
confidence += 0.1
}
}
if confidence > 1.0 {
confidence = 1.0
}
return &RelationDetectionResult{
SourceID: newer.ID,
TargetID: older.ID,
RelationType: RelationRelatesTo,
Confidence: confidence,
DetectionSource: DetectionSourceConceptOverlap,
Reason: "shared concepts: " + strings.Join(truncateList(shared, 5), ", "),
}
}
// isHighValueConcept returns true for concepts that strongly indicate relationships.
func isHighValueConcept(concept string) bool {
highValue := map[string]bool{
"security": true,
"architecture": true,
"gotcha": true,
"anti-pattern": true,
"best-practice": true,
"error-handling": true,
}
return highValue[concept]
}
// DetectTypeProgressionRelation checks for natural type progressions.
// Example: discovery -> decision -> feature -> bugfix
func DetectTypeProgressionRelation(newer, older *Observation) *RelationDetectionResult {
// Define natural type progressions
progressions := map[ObservationType][]ObservationType{
ObsTypeBugfix: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeDecision},
ObsTypeFeature: {ObsTypeDiscovery, ObsTypeDecision},
ObsTypeRefactor: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeBugfix},
ObsTypeDecision: {ObsTypeDiscovery},
ObsTypeChange: {ObsTypeDiscovery, ObsTypeDecision},
}
validPredecessors, ok := progressions[newer.Type]
if !ok {
return nil
}
isValidProgression := false
for _, pred := range validPredecessors {
if older.Type == pred {
isValidProgression = true
break
}
}
if !isValidProgression {
return nil
}
// Determine relationship type based on progression
var relType RelationType
var confidence float64 = 0.4
switch {
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeDiscovery:
relType = RelationFixes
confidence = 0.6
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeFeature:
relType = RelationFixes
confidence = 0.5
case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
relType = RelationDependsOn
confidence = 0.6
case newer.Type == ObsTypeRefactor:
relType = RelationEvolvesFrom
confidence = 0.5
default:
relType = RelationRelatesTo
}
return &RelationDetectionResult{
SourceID: newer.ID,
TargetID: older.ID,
RelationType: relType,
Confidence: confidence,
DetectionSource: DetectionSourceTypeProgression,
Reason: string(older.Type) + " -> " + string(newer.Type) + " progression",
}
}
// DetectTemporalProximityRelation checks if observations are temporally close (same session).
func DetectTemporalProximityRelation(newer, older *Observation) *RelationDetectionResult {
// Only relate observations from the same session
if newer.SDKSessionID != older.SDKSessionID {
return nil
}
// Check temporal proximity (within 5 minutes)
timeDiffMs := newer.CreatedAtEpoch - older.CreatedAtEpoch
if timeDiffMs < 0 {
timeDiffMs = -timeDiffMs
}
fiveMinutesMs := int64(5 * 60 * 1000)
if timeDiffMs > fiveMinutesMs {
return nil
}
// Calculate confidence based on temporal proximity
// Closer = higher confidence
proximityRatio := 1.0 - (float64(timeDiffMs) / float64(fiveMinutesMs))
confidence := 0.3 + proximityRatio*0.4
return &RelationDetectionResult{
SourceID: newer.ID,
TargetID: older.ID,
RelationType: RelationRelatesTo,
Confidence: confidence,
DetectionSource: DetectionSourceTemporalProximity,
Reason: "same session, close timestamps",
}
}
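The temporal confidence above decays linearly from 0.7 (identical timestamps) to 0.3 (exactly five minutes apart). The sketch below isolates that calculation. The helper name `temporalConfidence` and its boolean return value are assumptions made for illustration; the session check is left out.

```go
package main

import "fmt"

const fiveMinutesMs = int64(5 * 60 * 1000)

// temporalConfidence is a hypothetical helper reproducing the linear decay
// used by DetectTemporalProximityRelation. The second return value is false
// when the two timestamps are more than five minutes apart.
func temporalConfidence(newerMs, olderMs int64) (float64, bool) {
	diff := newerMs - olderMs
	if diff < 0 {
		diff = -diff
	}
	if diff > fiveMinutesMs {
		return 0, false
	}
	proximity := 1.0 - float64(diff)/float64(fiveMinutesMs)
	return 0.3 + proximity*0.4, true
}

func main() {
	base := int64(1700000000000)
	for _, offset := range []int64{0, 60000, 300000, 600000} {
		conf, ok := temporalConfidence(base+offset, base)
		fmt.Printf("offset=%dms related=%v confidence=%.2f\n", offset, ok, conf)
	}
}
```

With a `minConfidence` of 0.4, as used in the tests, a pair of observations more than about 3 minutes 45 seconds apart falls below the threshold even though it is inside the five-minute window.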
// NarrativeMentionPatterns are patterns that indicate explicit relationships in narratives.
var NarrativeMentionPatterns = []struct {
Pattern string
RelationType RelationType
ConfBoost float64
}{
{" caused ", RelationCauses, 0.3},
{" causes ", RelationCauses, 0.3},
{" because of ", RelationCauses, 0.25},
{" due to ", RelationCauses, 0.2},
{" fixes ", RelationFixes, 0.3},
{" fixed ", RelationFixes, 0.3},
{" resolves ", RelationFixes, 0.3},
{" addresses ", RelationFixes, 0.25},
{" replaces ", RelationSupersedes, 0.3},
{" supersedes ", RelationSupersedes, 0.35},
{" instead of ", RelationSupersedes, 0.25},
{" depends on ", RelationDependsOn, 0.3},
{" requires ", RelationDependsOn, 0.25},
{" builds on ", RelationDependsOn, 0.25},
{" based on ", RelationDependsOn, 0.2},
{" related to ", RelationRelatesTo, 0.2},
{" similar to ", RelationRelatesTo, 0.2},
{" evolved from ", RelationEvolvesFrom, 0.3},
{" improved from ", RelationEvolvesFrom, 0.25},
{" refined from ", RelationEvolvesFrom, 0.25},
}
// DetectNarrativeMentionRelation checks if newer observation's narrative mentions relationship.
func DetectNarrativeMentionRelation(newer, older *Observation) *RelationDetectionResult {
if !newer.Narrative.Valid || newer.Narrative.String == "" {
return nil
}
narrative := strings.ToLower(newer.Narrative.String)
// Check for patterns
for _, p := range NarrativeMentionPatterns {
if strings.Contains(narrative, p.Pattern) {
// Found a pattern - this is a potential relationship
confidence := 0.4 + p.ConfBoost
if confidence > 1.0 {
confidence = 1.0
}
return &RelationDetectionResult{
SourceID: newer.ID,
TargetID: older.ID,
RelationType: p.RelationType,
Confidence: confidence,
DetectionSource: DetectionSourceNarrativeMention,
Reason: "narrative contains '" + strings.TrimSpace(p.Pattern) + "' language",
}
}
}
return nil
}
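One consequence of the space-delimited patterns is worth noting: a keyword at the very start or end of a narrative has no surrounding space and is not matched by `strings.Contains`. The sketch below shows the behaviour and one possible workaround (padding the text with spaces). `containsPhrase` is a hypothetical helper and is not part of this change.

```go
package main

import (
	"fmt"
	"strings"
)

// containsPhrase is a hypothetical helper: it pads the lowercased text with
// spaces so that a space-delimited pattern also matches at the start or end.
func containsPhrase(text, pattern string) bool {
	padded := " " + strings.ToLower(text) + " "
	return strings.Contains(padded, pattern)
}

func main() {
	narrative := "Fixed security issue in auth module"

	// Plain substring search: no leading space before "fixed", so no match.
	fmt.Println(strings.Contains(strings.ToLower(narrative), " fixed "))

	// Padded search treats the string boundaries as word boundaries.
	fmt.Println(containsPhrase(narrative, " fixed "))
}
```

Padding does not help with punctuation (for example "fixes." or "fixes,"); handling that would need tokenisation or a regular expression, which is a separate design decision.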
// DetectRelationsWithExisting checks a new observation against existing ones and returns detected relations.
// This is the main entry point for relation detection.
func DetectRelationsWithExisting(newer *Observation, existing []*Observation, minConfidence float64) []*RelationDetectionResult {
var results []*RelationDetectionResult
seen := make(map[int64]bool)
for _, older := range existing {
// Skip self
if older.ID == newer.ID {
continue
}
// Skip if already superseded
if older.IsSuperseded {
continue
}
// Only compare within the same project, unless either observation is global-scoped
if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
continue
}
// Run all detection methods and keep highest confidence result per target
var bestResult *RelationDetectionResult
// 1. File overlap detection
if result := DetectFileOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
if bestResult == nil || result.Confidence > bestResult.Confidence {
bestResult = result
}
}
// 2. Concept overlap detection
if result := DetectConceptOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
if bestResult == nil || result.Confidence > bestResult.Confidence {
bestResult = result
}
}
// 3. Type progression detection
if result := DetectTypeProgressionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
if bestResult == nil || result.Confidence > bestResult.Confidence {
bestResult = result
}
}
// 4. Temporal proximity detection
if result := DetectTemporalProximityRelation(newer, older); result != nil && result.Confidence >= minConfidence {
// Only use temporal proximity if no better detection found
if bestResult == nil {
bestResult = result
}
}
// 5. Narrative mention detection (can upgrade relation type)
if result := DetectNarrativeMentionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
if bestResult == nil || result.Confidence > bestResult.Confidence {
bestResult = result
}
}
// Add best result if found and not already seen
if bestResult != nil && !seen[older.ID] {
results = append(results, bestResult)
seen[older.ID] = true
}
}
return results
}
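The project filter inside the loop skips a pair only when the projects differ and neither observation is global. Written as a positive predicate, the rule reads as follows. The name `shouldCompare` and the boolean parameters are assumptions for illustration; the real code compares `Scope` against `ScopeGlobal`.

```go
package main

import "fmt"

// shouldCompare is a hypothetical restatement of the project filter in
// DetectRelationsWithExisting: compare when the projects match, or when
// at least one of the two observations is global-scoped.
func shouldCompare(newerProject, olderProject string, newerGlobal, olderGlobal bool) bool {
	if newerProject == olderProject {
		return true
	}
	return newerGlobal || olderGlobal
}

func main() {
	fmt.Println(shouldCompare("a", "a", false, false)) // same project
	fmt.Println(shouldCompare("a", "b", false, false)) // different, both project-scoped
	fmt.Println(shouldCompare("a", "b", true, false))  // different, newer is global
	fmt.Println(shouldCompare("a", "b", false, true))  // different, older is global
}
```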
// truncateList truncates a list to maxLen items.
func truncateList(items []string, maxLen int) []string {
if len(items) <= maxLen {
return items
}
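// Copy into a fresh slice so that appending "..." cannot overwrite
// items[maxLen] in the caller's backing array.
result := make([]string, 0, maxLen+1)
result = append(result, items[:maxLen]...)
return append(result, "...")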
}
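Truncation helpers of this shape are a common place for slice aliasing to appear in Go: appending to a sub-slice whose capacity exceeds its length writes into the original backing array. The sketch below demonstrates the effect and a copy-first variant. `truncateCopy` is an illustrative name, not a function from this package.

```go
package main

import "fmt"

// truncateCopy returns at most maxLen items followed by "...", copying
// into a fresh slice so the caller's data is left untouched.
func truncateCopy(items []string, maxLen int) []string {
	if len(items) <= maxLen {
		return items
	}
	out := make([]string, 0, maxLen+1)
	out = append(out, items[:maxLen]...)
	return append(out, "...")
}

func main() {
	// Appending to a sub-slice reuses spare capacity and overwrites items[2].
	items := []string{"a", "b", "c", "d"}
	aliased := append(items[:2], "...")
	fmt.Println(aliased, items)

	// Copying first leaves the input intact.
	items = []string{"a", "b", "c", "d"}
	fmt.Println(truncateCopy(items, 2), items)
}
```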
// RelationWithDetails contains a relation with its observation details.
type RelationWithDetails struct {
Relation *ObservationRelation `json:"relation"`
SourceTitle string `json:"source_title"`
TargetTitle string `json:"target_title"`
SourceType ObservationType `json:"source_type"`
TargetType ObservationType `json:"target_type"`
}
// RelationGraph represents a graph of related observations.
type RelationGraph struct {
CenterID int64 `json:"center_id"`
Relations []*RelationWithDetails `json:"relations"`
}
+473
View File
@@ -0,0 +1,473 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"database/sql"
"testing"
)
func TestDetectFileOverlapRelation(t *testing.T) {
tests := []struct {
name string
newer *Observation
older *Observation
wantRelation bool
wantRelType RelationType
wantMinConfid float64
}{
{
name: "no file overlap",
newer: &Observation{
ID: 1,
FilesModified: []string{"file1.go", "file2.go"},
},
older: &Observation{
ID: 2,
FilesModified: []string{"file3.go", "file4.go"},
},
wantRelation: false,
},
{
name: "shared modified files",
newer: &Observation{
ID: 1,
Type: ObsTypeRefactor,
FilesModified: []string{"shared.go", "file2.go"},
},
older: &Observation{
ID: 2,
Type: ObsTypeRefactor,
FilesModified: []string{"shared.go", "file4.go"},
},
wantRelation: true,
wantRelType: RelationSupersedes,
wantMinConfid: 0.5,
},
{
name: "bugfix on feature file",
newer: &Observation{
ID: 1,
Type: ObsTypeBugfix,
FilesModified: []string{"feature.go"},
},
older: &Observation{
ID: 2,
Type: ObsTypeFeature,
FilesModified: []string{"feature.go"},
},
wantRelation: true,
wantRelType: RelationFixes,
wantMinConfid: 0.6,
},
{
name: "newer reads older modified",
newer: &Observation{
ID: 1,
Type: ObsTypeChange,
FilesRead: []string{"dep.go"},
FilesModified: []string{"caller.go"},
},
older: &Observation{
ID: 2,
Type: ObsTypeDecision,
FilesModified: []string{"dep.go"},
},
wantRelation: true,
wantMinConfid: 0.5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DetectFileOverlapRelation(tt.newer, tt.older)
if tt.wantRelation {
if result == nil {
t.Fatal("expected relation, got nil")
}
if tt.wantRelType != "" && result.RelationType != tt.wantRelType {
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
}
if result.Confidence < tt.wantMinConfid {
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
}
if result.DetectionSource != DetectionSourceFileOverlap {
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceFileOverlap)
}
} else {
if result != nil {
t.Errorf("expected no relation, got %+v", result)
}
}
})
}
}
func TestDetectConceptOverlapRelation(t *testing.T) {
tests := []struct {
name string
newer *Observation
older *Observation
wantRelation bool
wantMinConfid float64
}{
{
name: "no concept overlap",
newer: &Observation{
ID: 1,
Concepts: []string{"auth", "api"},
},
older: &Observation{
ID: 2,
Concepts: []string{"database", "caching"},
},
wantRelation: false,
},
{
name: "shared concepts",
newer: &Observation{
ID: 1,
Concepts: []string{"security", "auth"},
},
older: &Observation{
ID: 2,
Concepts: []string{"security", "validation"},
},
wantRelation: true,
wantMinConfid: 0.4, // security is a high-value concept
},
{
name: "multiple shared concepts",
newer: &Observation{
ID: 1,
Concepts: []string{"auth", "api", "validation"},
},
older: &Observation{
ID: 2,
Concepts: []string{"auth", "api", "database"},
},
wantRelation: true,
wantMinConfid: 0.5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DetectConceptOverlapRelation(tt.newer, tt.older)
if tt.wantRelation {
if result == nil {
t.Fatal("expected relation, got nil")
}
if result.Confidence < tt.wantMinConfid {
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
}
if result.DetectionSource != DetectionSourceConceptOverlap {
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceConceptOverlap)
}
} else {
if result != nil {
t.Errorf("expected no relation, got %+v", result)
}
}
})
}
}
func TestDetectTypeProgressionRelation(t *testing.T) {
tests := []struct {
name string
newerType ObservationType
olderType ObservationType
wantRelation bool
wantRelType RelationType
}{
{
name: "bugfix fixes discovery",
newerType: ObsTypeBugfix,
olderType: ObsTypeDiscovery,
wantRelation: true,
wantRelType: RelationFixes,
},
{
name: "bugfix fixes feature",
newerType: ObsTypeBugfix,
olderType: ObsTypeFeature,
wantRelation: true,
wantRelType: RelationFixes,
},
{
name: "feature depends on decision",
newerType: ObsTypeFeature,
olderType: ObsTypeDecision,
wantRelation: true,
wantRelType: RelationDependsOn,
},
{
name: "refactor evolves from discovery",
newerType: ObsTypeRefactor,
olderType: ObsTypeDiscovery,
wantRelation: true,
wantRelType: RelationEvolvesFrom,
},
{
name: "no progression discovery to bugfix",
newerType: ObsTypeDiscovery,
olderType: ObsTypeBugfix,
wantRelation: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newer := &Observation{ID: 1, Type: tt.newerType}
older := &Observation{ID: 2, Type: tt.olderType}
result := DetectTypeProgressionRelation(newer, older)
if tt.wantRelation {
if result == nil {
t.Fatal("expected relation, got nil")
}
if result.RelationType != tt.wantRelType {
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
}
if result.DetectionSource != DetectionSourceTypeProgression {
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTypeProgression)
}
} else {
if result != nil {
t.Errorf("expected no relation, got %+v", result)
}
}
})
}
}
func TestDetectTemporalProximityRelation(t *testing.T) {
baseTime := int64(1700000000000) // some base epoch ms
tests := []struct {
name string
newerSession string
olderSession string
newerTime int64
olderTime int64
wantRelation bool
}{
{
name: "same session close time",
newerSession: "session-1",
olderSession: "session-1",
newerTime: baseTime + 60000, // 1 minute later
olderTime: baseTime,
wantRelation: true,
},
{
name: "same session far apart",
newerSession: "session-1",
olderSession: "session-1",
newerTime: baseTime + 600000, // 10 minutes later
olderTime: baseTime,
wantRelation: false, // > 5 minutes
},
{
name: "different sessions close time",
newerSession: "session-1",
olderSession: "session-2",
newerTime: baseTime + 30000,
olderTime: baseTime,
wantRelation: false, // different sessions
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newer := &Observation{
ID: 1,
SDKSessionID: tt.newerSession,
CreatedAtEpoch: tt.newerTime,
}
older := &Observation{
ID: 2,
SDKSessionID: tt.olderSession,
CreatedAtEpoch: tt.olderTime,
}
result := DetectTemporalProximityRelation(newer, older)
if tt.wantRelation {
if result == nil {
t.Fatal("expected relation, got nil")
}
if result.DetectionSource != DetectionSourceTemporalProximity {
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTemporalProximity)
}
} else {
if result != nil {
t.Errorf("expected no relation, got %+v", result)
}
}
})
}
}
func TestDetectNarrativeMentionRelation(t *testing.T) {
tests := []struct {
name string
narrative string
wantRelation bool
wantRelType RelationType
}{
{
name: "fixes language",
narrative: "This change fixes the issue with authentication",
wantRelation: true,
wantRelType: RelationFixes,
},
{
name: "causes language",
narrative: "This decision caused unexpected side effects",
wantRelation: true,
wantRelType: RelationCauses,
},
{
name: "supersedes language",
narrative: "This approach supersedes the previous workaround",
wantRelation: true,
wantRelType: RelationSupersedes,
},
{
name: "depends on language",
narrative: "This feature depends on the authentication module",
wantRelation: true,
wantRelType: RelationDependsOn,
},
{
name: "no relationship language",
narrative: "Added new feature for user management",
wantRelation: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newer := &Observation{
ID: 1,
Narrative: sql.NullString{String: tt.narrative, Valid: true},
}
older := &Observation{ID: 2}
result := DetectNarrativeMentionRelation(newer, older)
if tt.wantRelation {
if result == nil {
t.Fatal("expected relation, got nil")
}
if result.RelationType != tt.wantRelType {
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
}
if result.DetectionSource != DetectionSourceNarrativeMention {
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceNarrativeMention)
}
} else {
if result != nil {
t.Errorf("expected no relation, got %+v", result)
}
}
})
}
}
func TestDetectRelationsWithExisting(t *testing.T) {
newer := &Observation{
ID: 1,
SDKSessionID: "session-1",
Project: "test-project",
Type: ObsTypeBugfix,
FilesModified: []string{"auth.go"},
Concepts: []string{"security", "auth"},
Narrative: sql.NullString{String: "Fixed security issue in auth module", Valid: true},
}
existing := []*Observation{
{
ID: 2,
SDKSessionID: "session-1",
Project: "test-project",
Type: ObsTypeDiscovery,
FilesModified: []string{"auth.go"},
Concepts: []string{"security"},
},
{
ID: 3,
SDKSessionID: "session-2",
Project: "test-project",
Type: ObsTypeFeature,
FilesModified: []string{"other.go"},
Concepts: []string{"api"},
},
{
ID: 4,
SDKSessionID: "session-1",
Project: "other-project", // different project
Type: ObsTypeDiscovery,
},
}
results := DetectRelationsWithExisting(newer, existing, 0.4)
// Should find relation with observation 2 (file overlap + concept overlap + type progression)
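// Observation 3 shares no files or concepts, but type progression (bugfix after feature,
// confidence 0.5) can still relate it; this test makes no assertion about it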
// Should not find relation with observation 4 (different project)
if len(results) == 0 {
t.Fatal("expected at least one relation")
}
// Check that we found relation with observation 2
foundObs2 := false
for _, r := range results {
if r.TargetID == 2 {
foundObs2 = true
// Only the single best signal is kept per target; it should still clear 0.5
if r.Confidence < 0.5 {
t.Errorf("expected higher confidence for obs 2, got %v", r.Confidence)
}
}
// Should not find relation with obs 4
if r.TargetID == 4 {
t.Error("should not find relation with different project")
}
}
if !foundObs2 {
t.Error("expected to find relation with observation 2")
}
}
func TestNewObservationRelation(t *testing.T) {
rel := NewObservationRelation(1, 2, RelationFixes, 0.8, DetectionSourceFileOverlap, "test reason")
if rel.SourceID != 1 {
t.Errorf("SourceID = %v, want 1", rel.SourceID)
}
if rel.TargetID != 2 {
t.Errorf("TargetID = %v, want 2", rel.TargetID)
}
if rel.RelationType != RelationFixes {
t.Errorf("RelationType = %v, want %v", rel.RelationType, RelationFixes)
}
if rel.Confidence != 0.8 {
t.Errorf("Confidence = %v, want 0.8", rel.Confidence)
}
if rel.DetectionSource != DetectionSourceFileOverlap {
t.Errorf("DetectionSource = %v, want %v", rel.DetectionSource, DetectionSourceFileOverlap)
}
if rel.Reason != "test reason" {
t.Errorf("Reason = %v, want 'test reason'", rel.Reason)
}
if rel.CreatedAt == "" {
t.Error("CreatedAt should be set")
}
if rel.CreatedAtEpoch == 0 {
t.Error("CreatedAtEpoch should be set")
}
}
+112
View File
@@ -0,0 +1,112 @@
// Package models contains domain models for claude-mnemonic.
package models
// ConceptWeight represents a configurable weight for a concept.
type ConceptWeight struct {
Concept string `db:"concept" json:"concept"`
Weight float64 `db:"weight" json:"weight"`
UpdatedAt string `db:"updated_at" json:"updated_at"`
}
// UserFeedbackType represents the type of user feedback.
type UserFeedbackType int
const (
// FeedbackNegative represents a thumbs down.
FeedbackNegative UserFeedbackType = -1
// FeedbackNeutral represents no feedback.
FeedbackNeutral UserFeedbackType = 0
// FeedbackPositive represents a thumbs up.
FeedbackPositive UserFeedbackType = 1
)
// DefaultConceptWeights contains the default weights for concepts.
// Higher weights indicate more important concepts.
var DefaultConceptWeights = map[string]float64{
// Critical concepts (0.25-0.30)
"security": 0.30, // Security issues are most critical
// High importance (0.20-0.25)
"gotcha": 0.25, // Gotchas prevent future mistakes
"best-practice": 0.20, // Best practices guide development
"anti-pattern": 0.20, // Anti-patterns prevent bad code
// Medium importance (0.10-0.15)
"architecture": 0.15, // Architectural decisions have lasting impact
"performance": 0.15, // Performance patterns matter for scale
"error-handling": 0.15, // Error handling prevents failures
"pattern": 0.10, // General patterns are useful
"testing": 0.10, // Testing knowledge helps quality
"debugging": 0.10, // Debugging tips save time
"problem-solution": 0.10, // Problem-solution pairs are actionable
"trade-off": 0.10, // Trade-offs inform decisions
// Lower importance (0.05)
"workflow": 0.05, // Workflow optimizations are nice-to-have
"tooling": 0.05, // Tooling preferences are subjective
"how-it-works": 0.05, // Understanding is foundational but less urgent
"why-it-exists": 0.05, // Context is helpful but less actionable
"what-changed": 0.05, // Changes are informational
}
// TypeBaseScores contains the base importance multipliers for each observation type.
// These are multiplied with the core score to weight different observation types.
var TypeBaseScores = map[ObservationType]float64{
ObsTypeBugfix: 1.3, // Bugfixes are valuable - prevent regressions
ObsTypeFeature: 1.2, // New features expand capabilities
ObsTypeDiscovery: 1.1, // Discoveries inform future work
ObsTypeDecision: 1.1, // Architectural decisions guide development
ObsTypeRefactor: 1.0, // Refactoring is neutral
ObsTypeChange: 0.9, // Minor changes are slightly less important
}
// ScoringConfig contains all scoring weights and parameters.
type ScoringConfig struct {
// RecencyHalfLifeDays is the number of days for the importance score to halve.
// With 7 days, a 7-day old observation has 50% of a new observation's recency score.
RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
// FeedbackWeight scales the user feedback contribution to final score.
// With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30.
FeedbackWeight float64 `json:"feedback_weight"`
// ConceptWeight scales the concept boost contribution.
// The sum of matching concept weights is multiplied by this.
ConceptWeight float64 `json:"concept_weight"`
// RetrievalWeight scales the retrieval boost contribution.
// Popular observations get a logarithmic bonus.
RetrievalWeight float64 `json:"retrieval_weight"`
// ConceptWeights maps concept names to their importance weights.
ConceptWeights map[string]float64 `json:"concept_weights"`
// MinScore is the minimum allowed importance score.
// Prevents observations from completely disappearing.
MinScore float64 `json:"min_score"`
}
// DefaultScoringConfig returns the default scoring configuration.
func DefaultScoringConfig() *ScoringConfig {
conceptWeights := make(map[string]float64, len(DefaultConceptWeights))
for k, v := range DefaultConceptWeights {
conceptWeights[k] = v
}
return &ScoringConfig{
RecencyHalfLifeDays: 7.0, // Score halves every 7 days
FeedbackWeight: 0.30, // Feedback has moderate impact
ConceptWeight: 0.20, // Concept weights have smaller impact
RetrievalWeight: 0.15, // Retrieval has smallest impact
ConceptWeights: conceptWeights,
MinScore: 0.01, // Never completely disappear
}
}
// TypeBaseScore returns the base weight for an observation type.
func TypeBaseScore(t ObservationType) float64 {
if score, ok := TypeBaseScores[t]; ok {
return score
}
return 1.0 // Default for unknown types
}
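The half-life comment on `RecencyHalfLifeDays` implies exponential decay. The scoring function itself is not part of this hunk, so the following is only a sketch under that assumption: `recencyFactor` and `applyFloor` are hypothetical helpers, and how the factor is combined with feedback, concept and retrieval boosts is deliberately left out.

```go
package main

import (
	"fmt"
	"math"
)

// recencyFactor returns the fraction of the recency score that remains
// after ageDays, assuming plain exponential decay with the given half-life.
func recencyFactor(ageDays, halfLifeDays float64) float64 {
	return math.Pow(0.5, ageDays/halfLifeDays)
}

// applyFloor clamps a score to minScore, mirroring the documented purpose
// of ScoringConfig.MinScore (observations never disappear completely).
func applyFloor(score, minScore float64) float64 {
	if score < minScore {
		return minScore
	}
	return score
}

func main() {
	const halfLife = 7.0 // DefaultScoringConfig().RecencyHalfLifeDays
	const minScore = 0.01

	for _, age := range []float64{0, 7, 14, 60} {
		f := recencyFactor(age, halfLife)
		fmt.Printf("age=%2.0fd factor=%.4f floored=%.4f\n", age, f, applyFloor(f, minScore))
	}
}
```

With the default half-life of 7 days, the factor drops below the 0.01 floor after roughly 47 days, which is where `MinScore` would start to matter under this assumption.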
+121
View File
@@ -0,0 +1,121 @@
#!/bin/bash
# Download BGE-small-en-v1.5 model for embedding
# Usage: ./download-bge-model.sh [--force]
# Use --force to re-download even if files exist
set -e
MODEL_NAME="bge-small-en-v1.5"
MODEL_REPO="BAAI/bge-small-en-v1.5"
ASSETS_DIR="internal/embedding/assets"
VERSION_FILE="${ASSETS_DIR}/.model_version"
FORCE_DOWNLOAD=false
# Check for --force flag
for arg in "$@"; do
if [ "$arg" = "--force" ]; then
FORCE_DOWNLOAD=true
fi
done
# Temporary directory for downloads
TEMP_DIR=$(mktemp -d)
# Single quotes defer expansion until the trap fires; inner quotes protect paths with spaces
trap 'rm -rf "${TEMP_DIR}"' EXIT
# Check if model already exists
model_exists() {
[ -f "${ASSETS_DIR}/model.onnx" ] && [ -f "${ASSETS_DIR}/tokenizer.json" ]
}
# Get installed version
get_installed_version() {
if [ -f "$VERSION_FILE" ]; then
cat "$VERSION_FILE"
else
echo ""
fi
}
# Write version file
write_version_file() {
echo "${MODEL_NAME}" > "$VERSION_FILE"
}
download_model() {
echo "Downloading ${MODEL_NAME} from Hugging Face..."
# Create assets directory
mkdir -p "${ASSETS_DIR}"
# Download ONNX model
# BGE models have ONNX exports available in the repo
echo "Downloading ONNX model..."
curl -fsSL \
"https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \
-o "${TEMP_DIR}/model.onnx"
# Download tokenizer.json
echo "Downloading tokenizer..."
curl -fsSL \
"https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \
-o "${TEMP_DIR}/tokenizer.json"
# Verify files exist and have content
if [ ! -s "${TEMP_DIR}/model.onnx" ]; then
echo "Error: Failed to download model.onnx or file is empty"
exit 1
fi
if [ ! -s "${TEMP_DIR}/tokenizer.json" ]; then
echo "Error: Failed to download tokenizer.json or file is empty"
exit 1
fi
# Move to assets directory (backup old files first)
if [ -f "${ASSETS_DIR}/model.onnx" ]; then
mv "${ASSETS_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx.bak"
fi
if [ -f "${ASSETS_DIR}/tokenizer.json" ]; then
mv "${ASSETS_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json.bak"
fi
mv "${TEMP_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx"
mv "${TEMP_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json"
# Remove backups on success
rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak"
# Write version file
write_version_file
echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)"
echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)"
}
echo "BGE Model Downloader - ${MODEL_NAME}"
echo "=================================="
need_download=false
reason=""
if [ "$FORCE_DOWNLOAD" = true ]; then
need_download=true
reason="forced"
elif ! model_exists; then
need_download=true
reason="not found"
elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then
need_download=true
reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})"
fi
if [ "$need_download" = true ]; then
if [ -n "$reason" ] && [ "$reason" != "not found" ]; then
echo "Re-downloading: ${reason}"
fi
download_model
echo "Done! ${MODEL_NAME} installed successfully."
else
echo "Model ${MODEL_NAME} already exists, skipping download."
echo "Use --force to re-download."
fi
+108 -2
View File
@@ -1,13 +1,15 @@
{
"name": "claude-mnemonic-dashboard",
-"version": "v0.6.33-3-gf38ce5c-dirty",
+"version": "0ddacaa-dirty",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "claude-mnemonic-dashboard",
-"version": "v0.6.33-3-gf38ce5c-dirty",
+"version": "0ddacaa-dirty",
"dependencies": {
"vis-data": "^7.1.9",
"vis-network": "^9.1.9",
"vue": "^3.5.13"
},
"devDependencies": {
@@ -81,6 +83,19 @@
"node": ">=6.9.0"
}
},
"node_modules/@egjs/hammerjs": {
"version": "2.0.17",
"resolved": "https://registry.npmjs.org/@egjs/hammerjs/-/hammerjs-2.0.17.tgz",
"integrity": "sha512-XQsZgjm2EcVUiZQf11UBJQfmZeEmOW8DpI1gsFeln6w0ae0ii4dMQEQ0kjl6DspdWX1aGY1/loyXnP0JS06e/A==",
"license": "MIT",
"peer": true,
"dependencies": {
"@types/hammerjs": "^2.0.36"
},
"engines": {
"node": ">=0.8.0"
}
},
"node_modules/@esbuild/aix-ppc64": {
"version": "0.25.12",
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
@@ -924,6 +939,13 @@
"dev": true,
"license": "MIT"
},
"node_modules/@types/hammerjs": {
"version": "2.0.46",
"resolved": "https://registry.npmjs.org/@types/hammerjs/-/hammerjs-2.0.46.tgz",
"integrity": "sha512-ynRvcq6wvqexJ9brDMS4BnBLzmr0e14d6ZJTEShTBWKymQiHwlAyGu0ZPEFI2Fh1U53F7tN9ufClWM5KvqkKOw==",
"license": "MIT",
"peer": true
},
"node_modules/@types/node": {
"version": "22.19.2",
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.2.tgz",
@@ -1352,6 +1374,19 @@
"node": ">= 6"
}
},
"node_modules/component-emitter": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/component-emitter/-/component-emitter-2.0.0.tgz",
"integrity": "sha512-4m5s3Me2xxlVKG9PkZpQqHQR7bgpnN7joDMJ4yvVkVXngjoITG76IaZmzmywSeRTeTpc6N6r3H3+KyUurV8OYw==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/cssesc": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz",
@@ -1669,6 +1704,13 @@
"jiti": "bin/jiti.js"
}
},
"node_modules/keycharm": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/keycharm/-/keycharm-0.4.0.tgz",
"integrity": "sha512-TyQTtsabOVv3MeOpR92sIKk/br9wxS+zGj4BG7CR8YbK4jM3tyIBaF0zhzeBUMx36/Q/iQLOKKOT+3jOQtemRQ==",
"license": "(Apache-2.0 OR MIT)",
"peer": true
},
"node_modules/lilconfig": {
"version": "3.1.3",
"resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz",
@@ -2412,6 +2454,70 @@
"dev": true,
"license": "MIT"
},
"node_modules/uuid": {
"version": "11.1.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz",
"integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==",
"funding": [
"https://github.com/sponsors/broofa",
"https://github.com/sponsors/ctavan"
],
"license": "MIT",
"peer": true,
"bin": {
"uuid": "dist/esm/bin/uuid"
}
},
"node_modules/vis-data": {
"version": "7.1.10",
"resolved": "https://registry.npmjs.org/vis-data/-/vis-data-7.1.10.tgz",
"integrity": "sha512-23juM9tdCaHTX5vyIQ7XBzsfZU0Hny+gSTwniLrfFcmw9DOm7pi3+h9iEBsoZMp5rX6KNqWwc1MF0fkAmWVuoQ==",
"license": "(Apache-2.0 OR MIT)",
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/visjs"
},
"peerDependencies": {
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
"vis-util": "^5.0.1"
}
},
"node_modules/vis-network": {
"version": "9.1.13",
"resolved": "https://registry.npmjs.org/vis-network/-/vis-network-9.1.13.tgz",
"integrity": "sha512-HLeHd5KZS92qzO1kC59qMh1/FWAZxMUEwUWBwDMoj6RKj/Ajkrgy/heEYo0Zc8SZNQ2J+u6omvK2+a28GX1QuQ==",
"license": "(Apache-2.0 OR MIT)",
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/visjs"
},
"peerDependencies": {
"@egjs/hammerjs": "^2.0.0",
"component-emitter": "^1.3.0 || ^2.0.0",
"keycharm": "^0.2.0 || ^0.3.0 || ^0.4.0",
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
"vis-data": "^6.3.0 || ^7.0.0",
"vis-util": "^5.0.1"
}
},
"node_modules/vis-util": {
"version": "5.0.7",
"resolved": "https://registry.npmjs.org/vis-util/-/vis-util-5.0.7.tgz",
"integrity": "sha512-E3L03G3+trvc/X4LXvBfih3YIHcKS2WrP0XTdZefr6W6Qi/2nNCqZfe4JFfJU6DcQLm6Gxqj2Pfl+02859oL5A==",
"license": "(Apache-2.0 OR MIT)",
"peer": true,
"engines": {
"node": ">=8"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/visjs"
},
"peerDependencies": {
"@egjs/hammerjs": "^2.0.0",
"component-emitter": "^1.3.0 || ^2.0.0"
}
},
"node_modules/vite": {
"version": "6.4.1",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
+3 -1
View File
@@ -1,6 +1,6 @@
{
"name": "claude-mnemonic-dashboard",
-"version": "v0.6.33-3-gf38ce5c-dirty",
+"version": "0ddacaa-dirty",
"private": true,
"type": "module",
"scripts": {
@@ -10,6 +10,8 @@
"type-check": "vue-tsc --noEmit"
},
"dependencies": {
"vis-data": "^7.1.9",
"vis-network": "^9.1.9",
"vue": "^3.5.13"
},
"devDependencies": {
+2
View File
@@ -10,6 +10,8 @@
"type-check": "vue-tsc --noEmit"
},
"dependencies": {
"vis-data": "^7.1.9",
"vis-network": "^9.1.9",
"vue": "^3.5.13"
},
"devDependencies": {
+1 -1
View File
@@ -30,7 +30,7 @@ const {
// Pass currentProject ref to useStats for project-specific retrieval stats
const { stats } = useStats(currentProject)
-// Note: Timeline refresh is handled by useTimeline's SSE watcher
+// Note: Feedback is handled directly in ObservationCard component
</script>
<template>
+243 -2
View File
@@ -1,17 +1,66 @@
<script setup lang="ts">
-import type { ObservationFeedItem } from '@/types'
+import type { ObservationFeedItem, RelationWithDetails } from '@/types'
import { TYPE_CONFIG, CONCEPT_CONFIG } from '@/types/observation'
import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
import { formatRelativeTime } from '@/utils/formatters'
import { fetchObservationRelations } from '@/utils/api'
import Card from './Card.vue'
import IconBox from './IconBox.vue'
import Badge from './Badge.vue'
-import { computed } from 'vue'
+import RelationGraph from './RelationGraph.vue'
import { computed, ref, onMounted } from 'vue'
const props = defineProps<{
observation: ObservationFeedItem
highlight?: boolean
showFeedback?: boolean
}>()
const emit = defineEmits<{
navigateToObservation: [id: number]
}>()
// Local feedback and score state (optimistic updates)
const localFeedback = ref<number | null>(null)
const localScore = ref<number | null>(null)
const isSubmitting = ref(false)
const currentFeedback = computed(() =>
localFeedback.value !== null ? localFeedback.value : (props.observation.user_feedback || 0)
)
const currentScore = computed(() =>
localScore.value !== null ? localScore.value : (props.observation.importance_score || 1)
)
const submitFeedback = async (value: number) => {
if (isSubmitting.value) return
// Toggle off if clicking same button
const newValue = currentFeedback.value === value ? 0 : value
// Remember the prior local state so a failed request can be rolled back
const previousFeedback = localFeedback.value
localFeedback.value = newValue
isSubmitting.value = true
try {
const response = await fetch(`/api/observations/${props.observation.id}/feedback`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ feedback: newValue })
})
if (!response.ok) {
throw new Error(`Feedback request failed with status ${response.status}`)
}
const data = await response.json()
if (data.score !== undefined) {
localScore.value = data.score
}
} catch (error) {
// Revert the optimistic update so the UI does not show an unsaved vote
localFeedback.value = previousFeedback
console.error('Error submitting feedback:', error)
} finally {
isSubmitting.value = false
}
}
const config = computed(() => TYPE_CONFIG[props.observation.type] || TYPE_CONFIG.change)
const concepts = computed(() => {
@@ -40,6 +89,60 @@ const filesModified = computed(() => {
const hasFiles = computed(() => filesRead.value.length > 0 || filesModified.value.length > 0)
// Relations state
const relations = ref<RelationWithDetails[]>([])
const relationsLoading = ref(false)
const relationsExpanded = ref(false)
const showGraph = ref(false)
const hasRelations = computed(() => relations.value.length > 0)
const relationCount = computed(() => relations.value.length)
// Load relations on mount
const loadRelations = async () => {
relationsLoading.value = true
try {
relations.value = await fetchObservationRelations(props.observation.id)
} catch (err) {
console.error('Failed to load relations:', err)
} finally {
relationsLoading.value = false
}
}
onMounted(() => {
loadRelations()
})
// Toggle relations expansion
const toggleRelations = () => {
relationsExpanded.value = !relationsExpanded.value
}
// Open graph modal
const openGraph = (e: Event) => {
e.stopPropagation()
showGraph.value = true
}
// Handle navigation from graph
const handleNavigateTo = (id: number) => {
showGraph.value = false
emit('navigateToObservation', id)
}
// Get relation display info (whether we're source or target)
const getRelationDisplay = (rel: RelationWithDetails) => {
const isSource = rel.relation.source_id === props.observation.id
return {
type: rel.relation.relation_type,
otherTitle: isSource ? rel.target_title : rel.source_title,
otherId: isSource ? rel.relation.target_id : rel.relation.source_id,
direction: isSource ? 'outgoing' : 'incoming',
confidence: rel.relation.confidence
}
}
// Split path into project root and relative path for styling
// e.g., /Users/foo/project/src/file.go -> { root: 'project', path: 'src/file.go' }
const splitPath = (path: string, components = 3) => {
@@ -140,7 +243,145 @@ const splitPath = (path: string, components = 3) => {
</div>
</div>
</div>
<!-- Relations -->
<div v-if="hasRelations || relationsLoading" class="mt-3 pt-3 border-t border-slate-700/50">
<!-- Header with count and graph button -->
<div class="flex items-center justify-between">
<button
@click="toggleRelations"
class="flex items-center gap-2 text-xs text-slate-400 hover:text-slate-300 transition-colors"
:disabled="relationsLoading"
>
<i class="fas fa-diagram-project text-cyan-500/70" />
<span v-if="relationsLoading" class="text-slate-500">
<i class="fas fa-circle-notch fa-spin mr-1" />
Loading relations...
</span>
<span v-else>
{{ relationCount }} related observation{{ relationCount !== 1 ? 's' : '' }}
</span>
<i
v-if="!relationsLoading && hasRelations"
class="fas text-[10px] transition-transform"
:class="relationsExpanded ? 'fa-chevron-up' : 'fa-chevron-down'"
/>
</button>
<!-- View Graph button -->
<button
v-if="hasRelations"
@click="openGraph"
class="flex items-center gap-1.5 px-2 py-1 text-xs text-cyan-400 hover:text-cyan-300 bg-cyan-500/10 hover:bg-cyan-500/20 rounded transition-colors"
title="View knowledge graph"
>
<i class="fas fa-project-diagram" />
<span>View Graph</span>
</button>
</div>
<!-- Expanded relations list -->
<div
v-if="relationsExpanded && hasRelations"
class="mt-2 space-y-1.5"
>
<div
v-for="rel in relations"
:key="rel.relation.id"
class="flex items-center gap-2 text-xs p-1.5 rounded bg-slate-800/30 hover:bg-slate-800/50 transition-colors group"
>
<!-- Relation type icon -->
<i
class="fas w-4 text-center"
:class="[
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.icon || 'fa-link',
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.colorClass || 'text-slate-400'
]"
:title="RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.label"
/>
<!-- Direction arrow -->
<i
class="fas text-[10px] text-slate-600"
:class="getRelationDisplay(rel).direction === 'outgoing' ? 'fa-arrow-right' : 'fa-arrow-left'"
/>
<!-- Related observation title -->
<span
class="flex-1 truncate text-slate-300 cursor-pointer hover:text-amber-300 transition-colors"
:title="getRelationDisplay(rel).otherTitle"
@click="emit('navigateToObservation', getRelationDisplay(rel).otherId)"
>
{{ getRelationDisplay(rel).otherTitle || 'Untitled' }}
</span>
<!-- Confidence -->
<span
class="text-[10px] text-slate-500 font-mono"
:title="`${Math.round(getRelationDisplay(rel).confidence * 100)}% confidence`"
>
{{ Math.round(getRelationDisplay(rel).confidence * 100) }}%
</span>
<!-- Detection source icon -->
<i
class="fas text-[10px] text-slate-600"
:class="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.icon || 'fa-question'"
:title="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.label"
/>
</div>
</div>
</div>
</div>
<!-- Feedback buttons (right side) -->
<div v-if="showFeedback" class="flex flex-col items-center gap-1 ml-2 flex-shrink-0">
<button
@click="submitFeedback(1)"
:disabled="isSubmitting"
:class="[
'p-1.5 rounded-lg transition-all duration-200',
currentFeedback === 1
? 'bg-green-500/30 text-green-300 shadow-green-500/20 shadow-sm'
: 'text-slate-500 hover:text-green-400 hover:bg-green-500/10'
]"
title="Helpful"
>
<i class="fas fa-thumbs-up text-sm" />
</button>
<span
class="text-[10px] font-mono px-1.5 py-0.5 rounded bg-slate-800/50 text-slate-400 flex items-center gap-1 transition-all duration-300"
:class="{ 'text-green-400': localScore !== null && localScore > (observation.importance_score || 1), 'text-red-400': localScore !== null && localScore < (observation.importance_score || 1) }"
:title="`Importance Score: ${currentScore.toFixed(3)}\nRetrieval Count: ${observation.retrieval_count || 0}`"
>
<i class="fas fa-scale-balanced text-amber-500/60" />
{{ currentScore.toFixed(2) }}
</span>
<button
@click="submitFeedback(-1)"
:disabled="isSubmitting"
:class="[
'p-1.5 rounded-lg transition-all duration-200',
currentFeedback === -1
? 'bg-red-500/30 text-red-300 shadow-red-500/20 shadow-sm'
: 'text-slate-500 hover:text-red-400 hover:bg-red-500/10'
]"
title="Not helpful"
>
<i class="fas fa-thumbs-down text-sm" />
</button>
</div>
</div>
<!-- Relation Graph Modal -->
<RelationGraph
:observation-id="observation.id"
:observation-title="observation.title || 'Untitled'"
:show="showGraph"
@close="showGraph = false"
@navigate-to="handleNavigateTo"
/>
</Card>
</template>
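The direction logic in `getRelationDisplay` can be checked outside the component. The following is a minimal sketch, assuming the `RelationWithDetails` shape introduced in this commit; the trimmed-down `DisplayRel` interface, the `describeRelation` name, and the sample titles are illustrative and not part of the diff.

```typescript
// Trimmed-down stand-in for RelationWithDetails (illustrative only).
interface DisplayRel {
  relation: { source_id: number; target_id: number; relation_type: string; confidence: number }
  source_title: string
  target_title: string
}

interface RelationDisplay {
  type: string
  otherTitle: string
  otherId: number
  direction: 'outgoing' | 'incoming'
  confidencePercent: number
}

// Work out which end of the relation is "the other observation" relative to selfId.
function describeRelation(rel: DisplayRel, selfId: number): RelationDisplay {
  const isSource = rel.relation.source_id === selfId
  return {
    type: rel.relation.relation_type,
    otherTitle: (isSource ? rel.target_title : rel.source_title) || 'Untitled',
    otherId: isSource ? rel.relation.target_id : rel.relation.source_id,
    direction: isSource ? 'outgoing' : 'incoming',
    confidencePercent: Math.round(rel.relation.confidence * 100)
  }
}

// Hypothetical sample: observation 7 fixes observation 3.
const sampleRelation: DisplayRel = {
  relation: { source_id: 7, target_id: 3, relation_type: 'fixes', confidence: 0.82 },
  source_title: 'Patch nil pointer in worker',
  target_title: 'Worker crashes on empty queue'
}

const seenFromSource = describeRelation(sampleRelation, 7)
const seenFromTarget = describeRelation(sampleRelation, 3)
```

Computing the display object once per relation, rather than calling the helper several times per row in the template, is one way to keep the rendering cheap when an observation has many relations.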
+409
@@ -0,0 +1,409 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted, watch, computed } from 'vue'
import { Network, type Data, type Options } from 'vis-network'
import type { RelationGraph, RelationWithDetails } from '@/types'
import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
import { fetchObservationGraph } from '@/utils/api'
const props = defineProps<{
observationId: number
observationTitle: string
show: boolean
}>()
const emit = defineEmits<{
close: []
navigateTo: [id: number]
}>()
const graphContainer = ref<HTMLElement | null>(null)
const loading = ref(true)
const error = ref<string | null>(null)
const graphData = ref<RelationGraph | null>(null)
const selectedRelation = ref<RelationWithDetails | null>(null)
const depth = ref(2)
let network: Network | null = null
// Node colors based on observation type
const getNodeColor = (type: string) => {
const colors: Record<string, { background: string; border: string; highlight: { background: string; border: string } }> = {
bugfix: { background: '#ef4444', border: '#dc2626', highlight: { background: '#f87171', border: '#ef4444' } },
feature: { background: '#a855f7', border: '#9333ea', highlight: { background: '#c084fc', border: '#a855f7' } },
refactor: { background: '#3b82f6', border: '#2563eb', highlight: { background: '#60a5fa', border: '#3b82f6' } },
discovery: { background: '#06b6d4', border: '#0891b2', highlight: { background: '#22d3ee', border: '#06b6d4' } },
decision: { background: '#eab308', border: '#ca8a04', highlight: { background: '#facc15', border: '#eab308' } },
change: { background: '#64748b', border: '#475569', highlight: { background: '#94a3b8', border: '#64748b' } },
}
return colors[type] || colors.change
}
// Edge colors based on relation type
const getEdgeColor = (type: string) => {
const colors: Record<string, string> = {
causes: '#f97316',
fixes: '#22c55e',
supersedes: '#a855f7',
depends_on: '#3b82f6',
relates_to: '#64748b',
evolves_from: '#06b6d4',
}
return colors[type] || '#64748b'
}
const loadGraph = async () => {
loading.value = true
error.value = null
try {
graphData.value = await fetchObservationGraph(props.observationId, depth.value)
renderGraph()
} catch (err) {
error.value = err instanceof Error ? err.message : 'Failed to load graph'
} finally {
loading.value = false
}
}
const renderGraph = () => {
if (!graphContainer.value || !graphData.value) return
// Build nodes and edges from relations
const nodeMap = new Map<number, { id: number; label: string; type: string }>()
const edgesList: { from: number; to: number; label: string; color: string; arrows: string; relation: RelationWithDetails }[] = []
// Add center node
nodeMap.set(props.observationId, {
id: props.observationId,
label: truncateLabel(props.observationTitle),
type: 'center'
})
// Process relations
for (const rel of graphData.value.relations) {
// Add source node
if (!nodeMap.has(rel.relation.source_id)) {
nodeMap.set(rel.relation.source_id, {
id: rel.relation.source_id,
label: truncateLabel(rel.source_title),
type: rel.source_type
})
}
// Add target node
if (!nodeMap.has(rel.relation.target_id)) {
nodeMap.set(rel.relation.target_id, {
id: rel.relation.target_id,
label: truncateLabel(rel.target_title),
type: rel.target_type
})
}
// Add edge
edgesList.push({
from: rel.relation.source_id,
to: rel.relation.target_id,
label: rel.relation.relation_type.replace(/_/g, ' '),
color: getEdgeColor(rel.relation.relation_type),
arrows: 'to',
relation: rel
})
}
// Create vis-network data using plain arrays (simpler type compatibility)
const nodes = Array.from(nodeMap.values()).map(node => ({
id: node.id,
label: node.label,
color: node.id === props.observationId
? { background: '#f59e0b', border: '#d97706', highlight: { background: '#fbbf24', border: '#f59e0b' } }
: getNodeColor(node.type),
font: { color: '#fff', size: 12 },
shape: 'box' as const,
borderWidth: node.id === props.observationId ? 3 : 2,
margin: { top: 10, right: 10, bottom: 10, left: 10 },
shadow: true
}))
const edges = edgesList.map((edge, index) => ({
id: index,
from: edge.from,
to: edge.to,
label: edge.label,
color: { color: edge.color, highlight: edge.color },
font: { color: '#94a3b8', size: 10, strokeWidth: 0 },
arrows: edge.arrows,
width: 2,
smooth: { enabled: true, type: 'curvedCW' as const, roundness: 0.2 }
}))
// Cleanup existing network
if (network) {
network.destroy()
}
// Create network data
const data: Data = { nodes, edges }
const options: Options = {
physics: {
enabled: true,
solver: 'forceAtlas2Based',
forceAtlas2Based: {
gravitationalConstant: -50,
centralGravity: 0.01,
springLength: 150,
springConstant: 0.08
},
stabilization: { iterations: 100 }
},
interaction: {
hover: true,
tooltipDelay: 200,
zoomView: true,
dragView: true
},
layout: {
improvedLayout: true
}
}
// Create network
network = new Network(graphContainer.value, data, options)
// Handle edge click to show details
network.on('selectEdge', (params: { edges: (string | number)[] }) => {
if (params.edges.length > 0) {
const edgeId = params.edges[0] as number
selectedRelation.value = edgesList[edgeId]?.relation || null
}
})
// Handle node double-click to navigate
network.on('doubleClick', (params: { nodes: (string | number)[] }) => {
if (params.nodes.length > 0) {
const nodeId = params.nodes[0] as number
if (nodeId !== props.observationId) {
emit('navigateTo', nodeId)
}
}
})
// Clear selection when clicking background
network.on('click', (params: { nodes: (string | number)[]; edges: (string | number)[] }) => {
if (params.nodes.length === 0 && params.edges.length === 0) {
selectedRelation.value = null
}
})
}
const truncateLabel = (label: string, maxLen = 30) => {
if (label.length <= maxLen) return label
return label.substring(0, maxLen - 3) + '...'
}
const relationCount = computed(() => graphData.value?.relations.length || 0)
const closeModal = () => {
emit('close')
}
// Watch for show prop changes
watch(() => props.show, (newVal) => {
if (newVal) {
loadGraph()
} else {
selectedRelation.value = null
if (network) {
network.destroy()
network = null
}
}
})
// Watch for depth changes
watch(depth, () => {
if (props.show) {
loadGraph()
}
})
onMounted(() => {
if (props.show) {
loadGraph()
}
})
onUnmounted(() => {
if (network) {
network.destroy()
network = null
}
})
</script>
<template>
<!-- Modal backdrop -->
<Teleport to="body">
<div
v-if="show"
class="fixed inset-0 z-50 flex items-center justify-center p-4"
@click.self="closeModal"
>
<!-- Backdrop -->
<div class="absolute inset-0 bg-black/70 backdrop-blur-sm" @click="closeModal" />
<!-- Modal content -->
<div class="relative bg-slate-900 border border-slate-700 rounded-xl shadow-2xl w-full max-w-5xl max-h-[90vh] flex flex-col overflow-hidden">
<!-- Header -->
<div class="flex items-center justify-between p-4 border-b border-slate-700">
<div class="flex items-center gap-3">
<div class="p-2 rounded-lg bg-amber-500/20">
<i class="fas fa-diagram-project text-amber-400" />
</div>
<div>
<h2 class="text-lg font-semibold text-amber-100">Knowledge Graph</h2>
<p class="text-sm text-slate-400">
{{ relationCount }} relation{{ relationCount !== 1 ? 's' : '' }} for "{{ truncateLabel(observationTitle, 50) }}"
</p>
</div>
</div>
<!-- Controls -->
<div class="flex items-center gap-4">
<!-- Depth selector -->
<div class="flex items-center gap-2">
<label class="text-xs text-slate-400">Depth:</label>
<select
v-model="depth"
class="bg-slate-800 border border-slate-600 rounded px-2 py-1 text-sm text-slate-200 focus:outline-none focus:border-amber-500"
>
<option :value="1">1</option>
<option :value="2">2</option>
<option :value="3">3</option>
</select>
</div>
<!-- Close button -->
<button
@click="closeModal"
class="p-2 rounded-lg text-slate-400 hover:text-slate-200 hover:bg-slate-800 transition-colors"
>
<i class="fas fa-times" />
</button>
</div>
</div>
<!-- Graph container -->
<div class="relative" style="height: 60vh; min-height: 500px;">
<!-- Loading state -->
<div v-if="loading" class="absolute inset-0 flex items-center justify-center bg-slate-900/50">
<div class="flex items-center gap-3 text-amber-400">
<i class="fas fa-circle-notch fa-spin text-xl" />
<span>Loading graph...</span>
</div>
</div>
<!-- Error state -->
<div v-else-if="error" class="absolute inset-0 flex items-center justify-center">
<div class="text-center">
<i class="fas fa-exclamation-triangle text-3xl text-red-400 mb-2" />
<p class="text-red-300">{{ error }}</p>
<button
@click="loadGraph"
class="mt-3 px-4 py-2 bg-slate-800 hover:bg-slate-700 rounded-lg text-sm text-slate-200 transition-colors"
>
Retry
</button>
</div>
</div>
<!-- Empty state -->
<div v-else-if="graphData && graphData.relations.length === 0" class="absolute inset-0 flex items-center justify-center">
<div class="text-center">
<i class="fas fa-diagram-project text-4xl text-slate-600 mb-3" />
<p class="text-slate-400">No relations found for this observation</p>
<p class="text-sm text-slate-500 mt-1">Relations are detected automatically when observations share files, concepts, or patterns</p>
</div>
</div>
<!-- Graph -->
<div ref="graphContainer" class="absolute inset-0" />
</div>
<!-- Relation details panel -->
<div v-if="selectedRelation" class="border-t border-slate-700 p-4 bg-slate-800/50">
<div class="flex items-start gap-4">
<!-- Relation type icon -->
<div
class="p-3 rounded-lg"
:class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.bgClass"
>
<i
class="fas"
:class="[
RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.icon,
RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass
]"
/>
</div>
<!-- Details -->
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2 mb-1">
<span class="font-medium" :class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass">
{{ RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.label }}
</span>
<span class="text-xs text-slate-500">
({{ Math.round(selectedRelation.relation.confidence * 100) }}% confidence)
</span>
</div>
<div class="flex items-center gap-2 text-sm text-slate-300 mb-2">
<span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.source_title, 40) }}</span>
<i class="fas fa-arrow-right text-slate-500 text-xs" />
<span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.target_title, 40) }}</span>
</div>
<div class="flex items-center gap-4 text-xs text-slate-500">
<span class="flex items-center gap-1">
<i :class="['fas', DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.icon]" />
{{ DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.label }}
</span>
<span v-if="selectedRelation.relation.reason">
{{ selectedRelation.relation.reason }}
</span>
</div>
</div>
</div>
</div>
<!-- Legend -->
<div class="border-t border-slate-700 p-3 bg-slate-800/30">
<div class="flex flex-wrap items-center justify-center gap-4 text-xs text-slate-400">
<span class="font-medium text-slate-300">Legend:</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-amber-500" /> Center
</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-red-500" /> Bugfix
</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-purple-500" /> Feature
</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-blue-500" /> Refactor
</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-cyan-500" /> Discovery
</span>
<span class="flex items-center gap-1">
<span class="w-3 h-3 rounded bg-yellow-500" /> Decision
</span>
<span class="text-slate-500">|</span>
<span>Double-click node to navigate</span>
</div>
</div>
</div>
</div>
</Teleport>
</template>
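The node and edge construction inside `renderGraph` does not depend on vis-network, so it can be sketched as a pure function. This is a simplified illustration under stated assumptions: `GraphRel`, `GraphNode`, `GraphEdge`, and `buildGraph` are hypothetical names, and colours, truncation, and styling are left out.

```typescript
// Minimal stand-in for RelationWithDetails (illustrative only).
interface GraphRel {
  relation: { source_id: number; target_id: number; relation_type: string }
  source_title: string
  target_title: string
  source_type: string
  target_type: string
}

interface GraphNode { id: number; label: string; type: string }
interface GraphEdge { id: number; from: number; to: number; label: string }

// Deduplicate observations into nodes and turn every relation into one directed edge.
function buildGraph(
  centerId: number,
  centerTitle: string,
  relations: GraphRel[]
): { nodes: GraphNode[]; edges: GraphEdge[] } {
  const nodes = new Map<number, GraphNode>()
  // The center node is registered first so later relations cannot overwrite it.
  nodes.set(centerId, { id: centerId, label: centerTitle, type: 'center' })

  const edges: GraphEdge[] = []
  relations.forEach((rel, index) => {
    const { source_id, target_id, relation_type } = rel.relation
    if (!nodes.has(source_id)) {
      nodes.set(source_id, { id: source_id, label: rel.source_title, type: rel.source_type })
    }
    if (!nodes.has(target_id)) {
      nodes.set(target_id, { id: target_id, label: rel.target_title, type: rel.target_type })
    }
    // The edge id is the relation's index, which lets a click handler map back to the relation.
    edges.push({ id: index, from: source_id, to: target_id, label: relation_type.replace(/_/g, ' ') })
  })

  return { nodes: Array.from(nodes.values()), edges }
}

// Hypothetical sample data: three observations linked by three relations.
const sampleGraph = buildGraph(1, 'Fix cache eviction', [
  { relation: { source_id: 1, target_id: 2, relation_type: 'fixes' }, source_title: 'Fix cache eviction', target_title: 'Cache grows unbounded', source_type: 'bugfix', target_type: 'discovery' },
  { relation: { source_id: 2, target_id: 3, relation_type: 'depends_on' }, source_title: 'Cache grows unbounded', target_title: 'Choose LRU policy', source_type: 'discovery', target_type: 'decision' },
  { relation: { source_id: 1, target_id: 3, relation_type: 'relates_to' }, source_title: 'Fix cache eviction', target_title: 'Choose LRU policy', source_type: 'bugfix', target_type: 'decision' }
])
```

Keeping this step separate from the `Network` constructor would make the graph shape testable without a DOM container.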
+1
@@ -31,6 +31,7 @@ defineProps<{
v-if="item.itemType === 'observation'"
:observation="item"
:highlight="index === 0"
:show-feedback="true"
/>
<PromptCard
v-else-if="item.itemType === 'prompt'"
+1
@@ -2,3 +2,4 @@ export * from './observation'
export * from './prompt'
export * from './summary'
export * from './api'
export * from './relation'
+6
@@ -30,6 +30,12 @@ export interface Observation {
created_at: string
created_at_epoch: number
is_stale?: boolean
// Importance scoring fields
importance_score: number
user_feedback: number // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
retrieval_count: number
last_retrieved_at_epoch?: number
score_updated_at_epoch?: number
}
export const OBSERVATION_TYPES: ObservationType[] = ['bugfix', 'feature', 'refactor', 'discovery', 'decision', 'change']
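The `user_feedback` field takes the values -1, 0, and 1, and the dashboard clears a vote when the active button is clicked again. The sketch below shows that toggle rule and one possible way to reconcile an optimistic update with the server reply. The names `Feedback`, `FeedbackState`, `nextFeedback`, and `reconcile` are hypothetical, and reverting the vote on a failed request is a design choice assumed here rather than something the types prescribe.

```typescript
type Feedback = -1 | 0 | 1

interface FeedbackState {
  feedback: Feedback
  score: number
}

// Clicking the already-active button clears the vote; any other click sets it.
function nextFeedback(current: Feedback, clicked: Feedback): Feedback {
  return current === clicked ? 0 : clicked
}

// Decide the final state once the request settles.
// `response` is null when the request itself failed (for example a network error).
function reconcile(
  previous: FeedbackState,
  optimistic: FeedbackState,
  response: { ok: boolean; score?: number } | null
): FeedbackState {
  if (!response || !response.ok) {
    return previous // roll back the optimistic vote
  }
  return { feedback: optimistic.feedback, score: response.score ?? optimistic.score }
}

// Hypothetical walk-through: a thumbs-up on a neutral observation.
const before: FeedbackState = { feedback: 0, score: 1 }
const optimistic: FeedbackState = { feedback: nextFeedback(before.feedback, 1), score: before.score }
const afterSuccess = reconcile(before, optimistic, { ok: true, score: 1.25 })
const afterFailure = reconcile(before, optimistic, null)
```

Keeping both steps as pure functions makes the voting behaviour easy to unit test independently of `fetch`.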
+90
@@ -0,0 +1,90 @@
export type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'
export type DetectionSource = 'file_overlap' | 'embedding_similarity' | 'temporal_proximity' | 'narrative_mention' | 'concept_overlap' | 'type_progression'
export interface ObservationRelation {
id: number
source_id: number
target_id: number
relation_type: RelationType
confidence: number
detection_source: DetectionSource
reason: string
created_at: string
created_at_epoch: number
}
export interface RelationWithDetails {
relation: ObservationRelation
source_title: string
target_title: string
source_type: string
target_type: string
}
export interface RelationGraph {
center_id: number
relations: RelationWithDetails[]
}
export interface RelationStats {
total_count: number
high_confidence: number
by_type: Record<RelationType, number>
min_confidence_used: number
}
// Configuration for relation type display
export const RELATION_TYPE_CONFIG: Record<RelationType, { icon: string; label: string; colorClass: string; bgClass: string; description: string }> = {
causes: {
icon: 'fa-arrow-right',
label: 'Causes',
colorClass: 'text-orange-300',
bgClass: 'bg-orange-500/20',
description: 'This observation caused the related issue'
},
fixes: {
icon: 'fa-wrench',
label: 'Fixes',
colorClass: 'text-green-300',
bgClass: 'bg-green-500/20',
description: 'This observation fixes the related issue'
},
supersedes: {
icon: 'fa-layer-group',
label: 'Supersedes',
colorClass: 'text-purple-300',
bgClass: 'bg-purple-500/20',
description: 'This observation replaces the older one'
},
depends_on: {
icon: 'fa-link',
label: 'Depends On',
colorClass: 'text-blue-300',
bgClass: 'bg-blue-500/20',
description: 'This observation depends on the related one'
},
relates_to: {
icon: 'fa-arrows-left-right',
label: 'Related',
colorClass: 'text-slate-300',
bgClass: 'bg-slate-500/20',
description: 'These observations are related'
},
evolves_from: {
icon: 'fa-code-branch',
label: 'Evolves From',
colorClass: 'text-cyan-300',
bgClass: 'bg-cyan-500/20',
description: 'This observation evolved from the related one'
}
}
// Configuration for detection source display
export const DETECTION_SOURCE_CONFIG: Record<DetectionSource, { icon: string; label: string }> = {
file_overlap: { icon: 'fa-file-code', label: 'Shared files' },
embedding_similarity: { icon: 'fa-brain', label: 'Semantic similarity' },
temporal_proximity: { icon: 'fa-clock', label: 'Close in time' },
narrative_mention: { icon: 'fa-quote-left', label: 'Mentioned in text' },
concept_overlap: { icon: 'fa-tags', label: 'Shared concepts' },
type_progression: { icon: 'fa-diagram-next', label: 'Natural progression' }
}
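Because `RELATION_TYPE_CONFIG` is keyed by `RelationType`, a value arriving from the API that is not one of the six known types yields `undefined` at runtime, which is why the templates fall back to defaults such as `fa-link`. A small type guard makes that fallback explicit. This is a sketch only: `RELATION_TYPES`, `isRelationType`, `RELATION_ICONS`, and `iconFor` are hypothetical helpers, with icon names copied from the configuration above.

```typescript
type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'

const RELATION_TYPES: readonly RelationType[] = [
  'causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from'
]

// Narrow an arbitrary value (for example a field from a JSON response) to RelationType.
function isRelationType(value: unknown): value is RelationType {
  return typeof value === 'string' && (RELATION_TYPES as readonly string[]).includes(value)
}

// Icon subset of the display configuration, repeated here to keep the sketch self-contained.
const RELATION_ICONS: Record<RelationType, string> = {
  causes: 'fa-arrow-right',
  fixes: 'fa-wrench',
  supersedes: 'fa-layer-group',
  depends_on: 'fa-link',
  relates_to: 'fa-arrows-left-right',
  evolves_from: 'fa-code-branch'
}

// Look up the icon, falling back to the generic link icon for unknown types.
function iconFor(value: unknown): string {
  return isRelationType(value) ? RELATION_ICONS[value] : 'fa-link'
}
```

A guard like this could be applied once when the API response is received, so that components can index the configuration without optional chaining.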
+18 -1
@@ -1,4 +1,4 @@
- import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem } from '@/types'
+ import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem, RelationWithDetails, RelationGraph, RelationStats } from '@/types'
const API_BASE = '/api'
const DEFAULT_TIMEOUT = 10000 // 10 seconds
@@ -147,3 +147,20 @@ export function combineTimeline(
return [...obsItems, ...promptItems, ...summaryItems]
.sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime())
}
// Relation API functions
export async function fetchObservationRelations(observationId: number, signal?: AbortSignal): Promise<RelationWithDetails[]> {
return fetchWithRetry<RelationWithDetails[]>(`${API_BASE}/observations/${observationId}/relations`, { signal })
}
export async function fetchObservationGraph(observationId: number, depth: number = 2, signal?: AbortSignal): Promise<RelationGraph> {
return fetchWithRetry<RelationGraph>(`${API_BASE}/observations/${observationId}/graph?depth=${depth}`, { signal })
}
export async function fetchRelatedObservations(observationId: number, minConfidence: number = 0.4, signal?: AbortSignal): Promise<Observation[]> {
return fetchWithRetry<Observation[]>(`${API_BASE}/observations/${observationId}/related?min_confidence=${minConfidence}`, { signal })
}
export async function fetchRelationStats(signal?: AbortSignal): Promise<RelationStats> {
return fetchWithRetry<RelationStats>(`${API_BASE}/relations/stats`, { signal })
}
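Each relation fetcher accepts an optional `AbortSignal`, which allows a caller to cancel a request that has been superseded, for example when the graph depth changes while a previous load is still in flight. The helper below is a minimal sketch of that pattern; `createSupersedingSignal` is a hypothetical name, and it assumes that `fetchWithRetry` (whose body is not shown in this diff) forwards `signal` to `fetch`.

```typescript
// Returns a function that hands out a fresh AbortSignal on every call
// and aborts the signal it handed out previously.
function createSupersedingSignal(): () => AbortSignal {
  let controller: AbortController | null = null
  return () => {
    if (controller) {
      controller.abort() // cancel the request that is being replaced
    }
    controller = new AbortController()
    return controller.signal
  }
}

// Hypothetical usage with the API functions from this file:
//   const nextSignal = createSupersedingSignal()
//   const graph = await fetchObservationGraph(observationId, depth, nextSignal())
const nextSignal = createSupersedingSignal()
const firstSignal = nextSignal()
const firstAbortedBefore = firstSignal.aborted
const secondSignal = nextSignal()
const firstAbortedAfter = firstSignal.aborted
```

With this approach only the most recent request can resolve normally, so a slow response for an earlier depth cannot overwrite the graph for the current one.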
+1 -1
@@ -1 +1 @@
- {"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
+ {"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}