mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Release dec 2025 (#15)
* Resolves issue #13
  - Switched model to bge-small-en-v1.5
  - Added lazy re-embedding
  - Added model version tracking per vector
  - Added conversion of vectors to the new model
* Add LFS support to the workflow.
* Implements importance scoring with decay + voting #6
* Resolves issue #5 by marking observations as superseded and scheduled for deletion
* Implement pattern detection #7
* Improve injections and observations accuracy
  - Session start: Recent observations for project context (recency-based)
  - User prompt: Semantically relevant observations (similarity-based with threshold)
* Added two-stage retrieval with bi- and cross-encoder #8
* Implement query expansion and reformulation #9
* Knowledge graph and relationships (resolves #4)
  - File Overlap Detection: Detects relationships when observations modify/read the same files
  - Concept Overlap Detection: Detects relationships based on shared semantic concepts
  - Type Progression Detection: Infers relationships from natural observation type progressions (e.g., discovery → bugfix = "fixes")
  - Temporal Proximity Detection: Detects relationships between observations in the same session within 5 minutes
  - Narrative Mention Detection: Detects explicit relationship language in narratives (e.g., "fixes", "depends on", "supersedes")
* Add visualisation of the relations to the dashboard.
* fixup! Add visualisation of the relations to the dashboard.
* Update documentation with new settings and screenshots.
@@ -20,4 +20,5 @@ jobs:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: ">=1.24"
|
||||
lfs: true
|
||||
secrets: inherit
|
||||
|
||||
@@ -12,7 +12,21 @@ Claude Code forgets everything when your session ends. Claude Mnemonic fixes tha
|
||||
|
||||
It captures what Claude learns during your coding sessions - bug fixes, architecture decisions, patterns that work - and brings that knowledge back in future conversations. No more re-explaining your codebase.
|
||||
|
||||
## What's New in v0.6
|
||||

|
||||
|
||||
## What's New in v0.7
|
||||
|
||||
- **Two-Stage Retrieval** - Cross-encoder reranking for dramatically improved search relevance
- **Knowledge Graph** - Automatic relationship detection between observations with visual graph in dashboard

- **Pattern Detection** - Identifies recurring patterns across sessions and projects
- **Importance Scoring** - Time decay and voting system to surface the most valuable memories
- **Query Expansion** - Reformulates searches to find semantically related content
- **Conflict Detection** - Identifies and resolves contradictory observations
- **Observation Lifecycle** - Memories can be superseded when better information arrives
|
||||
<details>
|
||||
<summary>Previous: v0.6</summary>
|
||||
|
||||
- **Auto-Updates** - Automatically stays up-to-date with the latest version
|
||||
- **Slash Command: `/restart`** - Restart the worker directly from Claude Code
|
||||
@@ -20,6 +34,7 @@ It captures what Claude learns during your coding sessions - bug fixes, architec
|
||||
- **Async Queue Processing** - Non-blocking observation capture for faster sessions
|
||||
- **Smarter Storage** - Filters out system/agent summaries to keep knowledge relevant
|
||||
- **Improved Reliability** - Better handling of connectivity issues and dead connections
|
||||
</details>
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -110,16 +125,39 @@ Config file: `~/.claude-mnemonic/settings.json`
|
||||
{
|
||||
"CLAUDE_MNEMONIC_WORKER_PORT": 37777,
|
||||
"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
|
||||
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25
|
||||
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25,
|
||||
"CLAUDE_MNEMONIC_RERANKING_ENABLED": true
|
||||
}
|
||||
```
|
||||
|
||||
### Core Settings
|
||||
|
||||
| Variable | Default | What it does |
|
||||
|----------|---------|--------------|
|
||||
| `WORKER_PORT` | `37777` | Dashboard & API port |
|
||||
| `CONTEXT_OBSERVATIONS` | `100` | Max memories per session |
|
||||
| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed to title only) |
|
||||
| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed) |
|
||||
| `CONTEXT_SESSION_COUNT` | `10` | Recent sessions to reference |
|
||||
| `CONTEXT_RELEVANCE_THRESHOLD` | `0.3` | Minimum similarity score (0.0-1.0) for inclusion |
|
||||
| `CONTEXT_MAX_PROMPT_RESULTS` | `10` | Max results per prompt search |
|
||||
|
||||
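How `CONTEXT_RELEVANCE_THRESHOLD` and `CONTEXT_MAX_PROMPT_RESULTS` interact can be sketched as below. The config comments in this commit state that a cap of `0` means "threshold only". Everything else here is illustrative rather than taken from the server code: the `Result` type, the `filterResults` name, and the assumption that results arrive sorted by descending score.

```go
package main

import "fmt"

// Result is a hypothetical search hit with a similarity score in [0, 1].
type Result struct {
	Title string
	Score float64
}

// filterResults keeps results whose score meets the threshold, then caps the
// list at maxResults. A maxResults of 0 disables the cap (threshold only).
// The input is assumed to be sorted by descending score already.
func filterResults(results []Result, threshold float64, maxResults int) []Result {
	kept := make([]Result, 0, len(results))
	for _, r := range results {
		if r.Score >= threshold {
			kept = append(kept, r)
		}
	}
	if maxResults > 0 && len(kept) > maxResults {
		kept = kept[:maxResults]
	}
	return kept
}

func main() {
	results := []Result{
		{"JWT expiry bug", 0.82},
		{"Auth middleware refactor", 0.55},
		{"Session cookie flags", 0.31},
		{"CSS grid tweak", 0.12},
	}
	// Defaults from the table above: threshold 0.3, cap 10.
	for _, r := range filterResults(results, 0.3, 10) {
		fmt.Printf("%.2f %s\n", r.Score, r.Title)
	}
}
```

With the defaults, the last entry (score 0.12) is dropped by the threshold; lowering the cap to 2 would additionally drop the third.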
### Reranking Settings (Two-Stage Retrieval)
|
||||
|
||||
| Variable | Default | What it does |
|----------|---------|--------------|
| `RERANKING_ENABLED` | `true` | Enable cross-encoder reranking |
| `RERANKING_CANDIDATES` | `100` | Candidates to retrieve before reranking |
| `RERANKING_RESULTS` | `10` | Final results after reranking |
| `RERANKING_ALPHA` | `0.7` | Score blend: alpha×rerank + (1-alpha)×original |
| `RERANKING_PURE_MODE` | `false` | Use pure cross-encoder scores only |
|
||||
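The score blend controlled by `RERANKING_ALPHA` and `RERANKING_PURE_MODE` can be illustrated with a small sketch. It assumes both scores are already on a comparable 0-1 scale and that the cross-encoder score has been computed elsewhere; `Candidate`, `blendScore`, and `rerankTop` are hypothetical names, not the project's API.

```go
package main

import (
	"fmt"
	"sort"
)

// Candidate is a hypothetical search hit carrying both stage scores.
type Candidate struct {
	ID       string
	Original float64 // bi-encoder (embedding) similarity
	Rerank   float64 // cross-encoder relevance score
}

// blendScore applies alpha*rerank + (1-alpha)*original, or returns the
// cross-encoder score alone when pureMode is set.
func blendScore(c Candidate, alpha float64, pureMode bool) float64 {
	if pureMode {
		return c.Rerank
	}
	return alpha*c.Rerank + (1-alpha)*c.Original
}

// rerankTop sorts candidates by blended score, highest first, and keeps at
// most limit of them (limit <= 0 keeps all).
func rerankTop(cands []Candidate, alpha float64, pureMode bool, limit int) []Candidate {
	out := append([]Candidate(nil), cands...) // copy so the caller's slice is untouched
	sort.SliceStable(out, func(i, j int) bool {
		return blendScore(out[i], alpha, pureMode) > blendScore(out[j], alpha, pureMode)
	})
	if limit > 0 && len(out) > limit {
		out = out[:limit]
	}
	return out
}

func main() {
	cands := []Candidate{
		{ID: "a", Original: 0.9, Rerank: 0.2},
		{ID: "b", Original: 0.5, Rerank: 0.9},
		{ID: "c", Original: 0.4, Rerank: 0.1},
	}
	for _, c := range rerankTop(cands, 0.7, false, 2) {
		fmt.Printf("%s %.2f\n", c.ID, blendScore(c, 0.7, false))
	}
}
```

With alpha at 0.7, candidate `b` overtakes `a` even though `a` had the higher embedding similarity, which is the point of the second stage.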
### Embedding Settings
|
||||
|
||||
| Variable | Default | What it does |
|----------|---------|--------------|
| `EMBEDDING_MODEL` | `bge-v1.5` | Embedding model for semantic search |
|
||||
All variables are prefixed with `CLAUDE_MNEMONIC_` in the config file.
|
||||
|
||||
## Project vs Global scope
|
||||
|
||||
@@ -202,10 +240,11 @@ curl -sSL https://raw.githubusercontent.com/lukaszraczylo/claude-mnemonic/main/s
|
||||
|
||||
- **SQLite + FTS5** - Full-text search for exact matches
|
||||
- **sqlite-vec** - Vector database embedded in SQLite
|
||||
- **all-MiniLM-L6-v2** - Local embedding model (384 dimensions) via ONNX Runtime
|
||||
- **Two-Stage Retrieval** - Bi-encoder (embedding) + cross-encoder (reranking) for high accuracy
|
||||
- **Local Models** - all-MiniLM-L6-v2 for embeddings, BGE reranker for relevance scoring
|
||||
- **Go** - Single binary, no external dependencies
|
||||
|
||||
Everything runs locally. No Python. No external vector database. No API calls for embeddings.
|
||||
Everything runs locally. No Python. No external vector database. No API calls.
|
||||
|
||||
## Platform support
|
||||
|
||||
|
||||
@@ -61,12 +61,8 @@ func main() {
|
||||
|
||||
searchResult, _ := hooks.GET(port, searchURL)
|
||||
if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
|
||||
// Limit to top 5 most relevant observations
|
||||
maxObs := 5
|
||||
if len(observations) < maxObs {
|
||||
maxObs = len(observations)
|
||||
}
|
||||
observations = observations[:maxObs]
|
||||
// Results are already filtered by relevance threshold and capped by max_results
|
||||
// from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults)
|
||||
observationCount = len(observations)
|
||||
|
||||
// Build context from search results
|
||||
|
||||
Binary file not shown. Size after change: 252 KiB
Binary file not shown. Size after change: 123 KiB
+37
-7
@@ -29,6 +29,21 @@
|
||||
subtitle="Stop re-explaining your codebase. Claude Mnemonic captures bug fixes, architecture decisions, and coding patterns - then brings them back exactly when you need them."
|
||||
/>
|
||||
|
||||
<!-- Dashboard Preview -->
|
||||
<section class="py-12 lg:py-16 px-4 sm:px-6 relative">
|
||||
<div class="max-w-6xl mx-auto">
|
||||
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-amber-500/5">
|
||||
<div class="absolute inset-0 bg-gradient-to-t from-slate-950 via-transparent to-transparent pointer-events-none z-10"></div>
|
||||
<img
|
||||
src="/claude-mnemonic.jpg"
|
||||
alt="Claude Mnemonic Dashboard"
|
||||
class="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
<p class="text-center text-slate-500 text-sm mt-4">The dashboard at localhost:37777 - browse, search, and manage your memories</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Problem Section -->
|
||||
<section class="py-20 lg:py-28 px-4 sm:px-6 relative">
|
||||
<div class="max-w-6xl mx-auto grid lg:grid-cols-2 gap-8 lg:gap-16 items-center">
|
||||
@@ -85,6 +100,21 @@
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Knowledge Graph Preview -->
|
||||
<section class="py-16 lg:py-20 px-4 sm:px-6 relative">
|
||||
<div class="max-w-5xl mx-auto">
|
||||
<SectionHeader title="See how knowledge connects" subtitle="The knowledge graph reveals relationships between your memories automatically" />
|
||||
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-purple-500/5">
|
||||
<img
|
||||
src="/observation-relation-graph.jpg"
|
||||
alt="Knowledge Graph Visualization"
|
||||
class="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
<p class="text-center text-slate-500 text-sm mt-4">Click any observation to explore its relationships - see what causes what, what fixes what, and how concepts evolve</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Before/After Section -->
|
||||
<section class="py-20 lg:py-28 px-4 sm:px-6">
|
||||
<div class="max-w-6xl mx-auto">
|
||||
@@ -288,8 +318,8 @@
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Embedded vector database. No external services required.</p>
|
||||
</div>
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">MiniLM</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Local embeddings via ONNX. "Fix auth" finds "JWT issue" automatically.</p>
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">BGE</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Two-stage retrieval: bi-encoder embeddings + cross-encoder reranking for high accuracy.</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -386,10 +416,10 @@ const activeTab = ref('macos')
|
||||
|
||||
const features = [
|
||||
{ icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
|
||||
{ icon: 'fas fa-search', title: 'Two-stage retrieval', description: 'Cross-encoder reranking delivers highly relevant results. Finds what you need even with vague queries like "that auth thing".' },
|
||||
{ icon: 'fas fa-project-diagram', title: 'Knowledge graph', description: 'Automatically discovers relationships between memories. See how concepts connect in the visual graph dashboard.' },
|
||||
{ icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
|
||||
{ icon: 'fas fa-globe', title: 'Shared best practices', description: 'Security patterns, performance tips, and universal learnings automatically available across all your projects.' },
|
||||
{ icon: 'fas fa-search', title: 'Finds what matters', description: 'Semantic search finds relevant memories even when you don\'t remember the exact words. "That auth thing" just works.' },
|
||||
{ icon: 'fas fa-chart-line', title: 'Live statusline', description: 'Real-time metrics right in Claude Code: memories served, searches performed, and project memory count at a glance.' },
|
||||
{ icon: 'fas fa-chart-line', title: 'Smart scoring', description: 'Importance decay, pattern detection, and conflict resolution ensure the most valuable memories surface first.' },
|
||||
{ icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
|
||||
]
|
||||
|
||||
@@ -415,8 +445,8 @@ const installCommands = {
|
||||
const configOptions = [
|
||||
{ name: 'CLAUDE_MNEMONIC_WORKER_PORT', description: 'HTTP port for the worker service (default: 37777)', icon: 'fas fa-network-wired' },
|
||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
|
||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT', description: 'Observations with full narrative detail, rest are condensed (default: 25)', icon: 'fas fa-expand' },
|
||||
{ name: 'CLAUDE_MNEMONIC_MODEL', description: 'Model for processing observations (default: haiku)', icon: 'fas fa-microchip' },
|
||||
{ name: 'CLAUDE_MNEMONIC_RERANKING_ENABLED', description: 'Enable cross-encoder reranking for improved search relevance (default: true)', icon: 'fas fa-sort-amount-down' },
|
||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD', description: 'Minimum similarity score for inclusion, 0.0-1.0 (default: 0.3)', icon: 'fas fa-filter' },
|
||||
]
|
||||
|
||||
const requiredDeps = [
|
||||
|
||||
+74
-22
@@ -48,16 +48,29 @@ type Config struct {
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
|
||||
// Embedding settings
|
||||
EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
|
||||
|
||||
// Reranking settings (cross-encoder)
|
||||
RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking
|
||||
RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100)
|
||||
RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10)
|
||||
RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7)
|
||||
RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank)
|
||||
RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false)
|
||||
|
||||
// Context injection settings
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
ContextSessionCount int `json:"context_session_count"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
ContextSessionCount int `json:"context_session_count"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -119,22 +132,33 @@ func EnsureAll() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultEmbeddingModel is the default embedding model to use.
|
||||
const DefaultEmbeddingModel = "bge-v1.5"
|
||||
|
||||
// Default returns a Config with default values.
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
WorkerPort: DefaultWorkerPort,
|
||||
DBPath: DBPath(),
|
||||
MaxConns: 4,
|
||||
Model: DefaultModel,
|
||||
ContextObservations: 100,
|
||||
ContextFullCount: 25,
|
||||
ContextSessionCount: 10,
|
||||
ContextShowReadTokens: true,
|
||||
ContextShowWorkTokens: true,
|
||||
ContextFullField: "narrative",
|
||||
ContextShowLastSummary: true,
|
||||
ContextObsTypes: DefaultObservationTypes,
|
||||
ContextObsConcepts: DefaultObservationConcepts,
|
||||
WorkerPort: DefaultWorkerPort,
|
||||
DBPath: DBPath(),
|
||||
MaxConns: 4,
|
||||
Model: DefaultModel,
|
||||
EmbeddingModel: DefaultEmbeddingModel,
|
||||
RerankingEnabled: true, // Enable by default for improved relevance
|
||||
RerankingCandidates: 100, // Retrieve top 100 candidates
|
||||
RerankingResults: 10, // Return top 10 after reranking
|
||||
RerankingAlpha: 0.7, // Favor cross-encoder score
|
||||
RerankingMinImprovement: 0, // Always apply reranking
|
||||
ContextObservations: 100,
|
||||
ContextFullCount: 25,
|
||||
ContextSessionCount: 10,
|
||||
ContextShowReadTokens: true,
|
||||
ContextShowWorkTokens: true,
|
||||
ContextFullField: "narrative",
|
||||
ContextShowLastSummary: true,
|
||||
ContextObsTypes: DefaultObservationTypes,
|
||||
ContextObsConcepts: DefaultObservationConcepts,
|
||||
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
|
||||
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +190,28 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
|
||||
cfg.ClaudeCodePath = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" {
|
||||
cfg.EmbeddingModel = v
|
||||
}
|
||||
// Reranking settings
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ENABLED"].(bool); ok {
|
||||
cfg.RerankingEnabled = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_CANDIDATES"].(float64); ok && v > 0 {
|
||||
cfg.RerankingCandidates = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_RESULTS"].(float64); ok && v > 0 {
|
||||
cfg.RerankingResults = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ALPHA"].(float64); ok && v >= 0 && v <= 1 {
|
||||
cfg.RerankingAlpha = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_MIN_IMPROVEMENT"].(float64); ok && v >= 0 {
|
||||
cfg.RerankingMinImprovement = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_PURE_MODE"].(bool); ok {
|
||||
cfg.RerankingPureMode = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
|
||||
cfg.ContextObservations = int(v)
|
||||
}
|
||||
@@ -181,6 +227,12 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS"].(string); ok && v != "" {
|
||||
cfg.ContextObsConcepts = splitTrim(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD"].(float64); ok && v >= 0 && v <= 1 {
|
||||
cfg.ContextRelevanceThreshold = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 {
|
||||
cfg.ContextMaxPromptResults = int(v)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
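The `.(float64)` assertions in `Load` follow from how `encoding/json` decodes into a `map[string]interface{}`: every JSON number becomes a `float64`, so an `.(int)` assertion would never succeed. A self-contained sketch of the same pattern follows; `intSetting` is a hypothetical helper, not part of the package.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// intSetting returns settings[key] as an int when it is a positive JSON
// number; for a missing key, a wrong type, or a non-positive value it
// returns fallback instead.
func intSetting(settings map[string]interface{}, key string, fallback int) int {
	if v, ok := settings[key].(float64); ok && v > 0 {
		return int(v)
	}
	return fallback
}

func main() {
	raw := []byte(`{
		"CLAUDE_MNEMONIC_RERANKING_CANDIDATES": 50,
		"CLAUDE_MNEMONIC_RERANKING_RESULTS": "10"
	}`)

	var settings map[string]interface{}
	if err := json.Unmarshal(raw, &settings); err != nil {
		panic(err)
	}

	// A number overrides the default; a quoted string is ignored.
	fmt.Println(intSetting(settings, "CLAUDE_MNEMONIC_RERANKING_CANDIDATES", 100))
	fmt.Println(intSetting(settings, "CLAUDE_MNEMONIC_RERANKING_RESULTS", 10))
}
```

One consequence of this pattern is that a mistyped value (for example a quoted number) is silently ignored and the default stays in effect.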
@@ -0,0 +1,276 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// SupersededRetentionDays is the number of days to keep superseded observations before deletion.
|
||||
const SupersededRetentionDays = 3
|
||||
|
||||
// ConflictStore provides conflict-related database operations.
|
||||
type ConflictStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewConflictStore creates a new conflict store.
|
||||
func NewConflictStore(store *Store) *ConflictStore {
|
||||
return &ConflictStore{store: store}
|
||||
}
|
||||
|
||||
// StoreConflict stores a new observation conflict.
|
||||
func (s *ConflictStore) StoreConflict(ctx context.Context, conflict *models.ObservationConflict) (int64, error) {
|
||||
const query = `
|
||||
INSERT INTO observation_conflicts
|
||||
(newer_obs_id, older_obs_id, conflict_type, resolution, reason, detected_at, detected_at_epoch, resolved, resolved_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
conflict.NewerObsID, conflict.OlderObsID,
|
||||
string(conflict.ConflictType), string(conflict.Resolution),
|
||||
conflict.Reason, conflict.DetectedAt, conflict.DetectedAtEpoch,
|
||||
conflict.Resolved, conflict.ResolvedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// MarkObservationSuperseded marks an observation as superseded.
|
||||
func (s *ConflictStore) MarkObservationSuperseded(ctx context.Context, obsID int64) error {
|
||||
const query = `UPDATE observations SET is_superseded = 1 WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkObservationsSuperseded marks multiple observations as superseded.
|
||||
func (s *ConflictStore) MarkObservationsSuperseded(ctx context.Context, obsIDs []int64) error {
|
||||
if len(obsIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `UPDATE observations SET is_superseded = 1 WHERE id IN (?` + repeatPlaceholders(len(obsIDs)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||
args := int64SliceToInterface(obsIDs)
|
||||
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetConflictsByObservationID retrieves all conflicts involving an observation.
|
||||
func (s *ConflictStore) GetConflictsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationConflict, error) {
|
||||
const query = `
|
||||
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||
detected_at, detected_at_epoch, resolved, resolved_at
|
||||
FROM observation_conflicts
|
||||
WHERE newer_obs_id = ? OR older_obs_id = ?
|
||||
ORDER BY detected_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanConflictRows(rows)
|
||||
}
|
||||
|
||||
// GetUnresolvedConflicts retrieves all unresolved conflicts.
|
||||
func (s *ConflictStore) GetUnresolvedConflicts(ctx context.Context, limit int) ([]*models.ObservationConflict, error) {
|
||||
const query = `
|
||||
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||
detected_at, detected_at_epoch, resolved, resolved_at
|
||||
FROM observation_conflicts
|
||||
WHERE resolved = 0
|
||||
ORDER BY detected_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanConflictRows(rows)
|
||||
}
|
||||
|
||||
// GetSupersededObservationIDs returns IDs of all observations that have been superseded.
|
||||
func (s *ConflictStore) GetSupersededObservationIDs(ctx context.Context, project string) ([]int64, error) {
|
||||
const query = `
|
||||
SELECT DISTINCT older_obs_id
|
||||
FROM observation_conflicts oc
|
||||
JOIN observations o ON o.id = oc.older_obs_id
|
||||
WHERE oc.resolution = 'prefer_newer'
|
||||
AND (o.project = ? OR o.scope = 'global')
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
|
||||
// ResolveConflict marks a conflict as resolved.
|
||||
func (s *ConflictStore) ResolveConflict(ctx context.Context, conflictID int64, resolution models.ConflictResolution) error {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
const query = `
|
||||
UPDATE observation_conflicts
|
||||
SET resolved = 1, resolved_at = ?, resolution = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, now, string(resolution), conflictID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteConflictsByObservationID deletes all conflicts involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *ConflictStore) DeleteConflictsByObservationID(ctx context.Context, obsID int64) error {
|
||||
const query = `DELETE FROM observation_conflicts WHERE newer_obs_id = ? OR older_obs_id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ConflictWithDetails contains a conflict with its observation details.
|
||||
type ConflictWithDetails struct {
|
||||
Conflict *models.ObservationConflict
|
||||
NewerObsTitle string
|
||||
OlderObsTitle string
|
||||
}
|
||||
|
||||
// CleanupSupersededObservations deletes observations that have been superseded for longer than
|
||||
// SupersededRetentionDays. Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
// Calculate cutoff time (3 days ago in milliseconds)
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -SupersededRetentionDays).UnixMilli()
|
||||
|
||||
// First, find the IDs that will be deleted
|
||||
// We delete observations that:
|
||||
// 1. Are marked as superseded
|
||||
// 2. Have a conflict record where they are the older observation
|
||||
// 3. The conflict was detected more than 3 days ago
|
||||
const selectQuery = `
|
||||
SELECT DISTINCT o.id FROM observations o
|
||||
JOIN observation_conflicts oc ON o.id = oc.older_obs_id
|
||||
WHERE o.is_superseded = 1
|
||||
AND o.project = ?
|
||||
AND oc.detected_at_epoch < ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, selectQuery, project, cutoffEpoch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var toDelete []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toDelete = append(toDelete, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(toDelete) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete the conflict records first (due to foreign key constraints)
|
||||
for _, obsID := range toDelete {
|
||||
if err := s.DeleteConflictsByObservationID(ctx, obsID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the observations
|
||||
deleteQuery := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(toDelete)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||
args := int64SliceToInterface(toDelete)
|
||||
_, err = s.store.db.ExecContext(ctx, deleteQuery, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDelete, nil
|
||||
}
|
||||
|
||||
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
|
||||
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
|
||||
const query = `
|
||||
SELECT oc.id, oc.newer_obs_id, oc.older_obs_id, oc.conflict_type, oc.resolution, oc.reason,
|
||||
oc.detected_at, oc.detected_at_epoch, oc.resolved, oc.resolved_at,
|
||||
COALESCE(newer.title, '') as newer_title,
|
||||
COALESCE(older.title, '') as older_title
|
||||
FROM observation_conflicts oc
|
||||
JOIN observations newer ON newer.id = oc.newer_obs_id
|
||||
JOIN observations older ON older.id = oc.older_obs_id
|
||||
WHERE newer.project = ? OR older.project = ?
|
||||
ORDER BY oc.detected_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*ConflictWithDetails
|
||||
for rows.Next() {
|
||||
var c models.ObservationConflict
|
||||
var cwd ConflictWithDetails
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.NewerObsID, &c.OlderObsID,
|
||||
&c.ConflictType, &c.Resolution, &c.Reason,
|
||||
&c.DetectedAt, &c.DetectedAtEpoch,
|
||||
&c.Resolved, &c.ResolvedAt,
|
||||
&cwd.NewerObsTitle, &cwd.OlderObsTitle,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cwd.Conflict = &c
|
||||
results = append(results, &cwd)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// scanConflictRows scans multiple conflicts from rows.
|
||||
func (s *ConflictStore) scanConflictRows(rows interface {
|
||||
Next() bool
|
||||
Scan(...interface{}) error
|
||||
Err() error
|
||||
}) ([]*models.ObservationConflict, error) {
|
||||
var conflicts []*models.ObservationConflict
|
||||
for rows.Next() {
|
||||
var c models.ObservationConflict
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.NewerObsID, &c.OlderObsID,
|
||||
&c.ConflictType, &c.Resolution, &c.Reason,
|
||||
&c.DetectedAt, &c.DetectedAtEpoch,
|
||||
&c.Resolved, &c.ResolvedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conflicts = append(conflicts, &c)
|
||||
}
|
||||
return conflicts, rows.Err()
|
||||
}
|
||||
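`repeatPlaceholders` and `int64SliceToInterface` are called in this file but are not defined in this diff. Judging only from the call sites (`IN (?` followed by `repeatPlaceholders(n-1)` and a closing parenthesis), they might look roughly like the sketch below. The bodies are an assumption, not the repository's code.

```go
package main

import (
	"fmt"
	"strings"
)

// repeatPlaceholders returns n additional ", ?" placeholders, to be appended
// after a first "?" that the caller writes itself. (Assumed behaviour.)
func repeatPlaceholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat(", ?", n)
}

// int64SliceToInterface converts IDs into the []interface{} form expected by
// variadic database/sql argument lists. (Assumed behaviour.)
func int64SliceToInterface(ids []int64) []interface{} {
	args := make([]interface{}, len(ids))
	for i, id := range ids {
		args[i] = id
	}
	return args
}

func main() {
	ids := []int64{4, 8, 15}
	query := "UPDATE observations SET is_superseded = 1 WHERE id IN (?" +
		repeatPlaceholders(len(ids)-1) + ")"
	fmt.Println(query)
	fmt.Println(int64SliceToInterface(ids)...)
}
```

Only the number of placeholders is derived from the input; the values themselves are still bound as parameters, which is what the `#nosec G202` annotations in the diff rely on.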
@@ -283,6 +283,213 @@ var Migrations = []Migration{
|
||||
ON user_prompts(claude_session_id, prompt_number);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 19,
|
||||
Name: "vectors_with_model_version",
|
||||
SQL: `
|
||||
-- Drop old vectors table (virtual tables cannot be altered)
|
||||
DROP TABLE IF EXISTS vectors;
|
||||
|
||||
-- Recreate vectors table with model_version column
|
||||
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 20,
|
||||
Name: "importance_scoring",
|
||||
SQL: `
|
||||
-- Importance scoring system for observations
|
||||
-- Implements multi-factor scoring: type weight, recency decay, user feedback, concept weights, retrieval boost
|
||||
|
||||
-- Cached importance score (recalculated periodically)
|
||||
ALTER TABLE observations ADD COLUMN importance_score REAL DEFAULT 1.0;
|
||||
|
||||
-- User feedback: -1 = thumbs down, 0 = neutral, 1 = thumbs up
|
||||
ALTER TABLE observations ADD COLUMN user_feedback INTEGER DEFAULT 0;
|
||||
|
||||
-- Retrieval tracking: how many times this observation was returned in searches
|
||||
ALTER TABLE observations ADD COLUMN retrieval_count INTEGER DEFAULT 0;
|
||||
|
||||
-- Last time this observation was retrieved (for analytics)
|
||||
ALTER TABLE observations ADD COLUMN last_retrieved_at_epoch INTEGER;
|
||||
|
||||
-- Timestamp of last score recalculation
|
||||
ALTER TABLE observations ADD COLUMN score_updated_at_epoch INTEGER;
|
||||
|
||||
-- Index for importance-based sorting (primary ordering strategy)
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_importance
|
||||
ON observations(importance_score DESC, created_at_epoch DESC);
|
||||
|
||||
-- Index for finding observations needing score recalculation
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_score_updated
|
||||
ON observations(score_updated_at_epoch);
|
||||
|
||||
-- Configurable concept weights table
|
||||
-- Allows runtime tuning of how much each concept contributes to importance
|
||||
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||
concept TEXT PRIMARY KEY,
|
||||
weight REAL NOT NULL DEFAULT 0.1,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Seed default concept weights (security highest, tooling lowest)
|
||||
INSERT OR IGNORE INTO concept_weights (concept, weight, updated_at) VALUES
|
||||
('security', 0.30, datetime('now')),
|
||||
('gotcha', 0.25, datetime('now')),
|
||||
('best-practice', 0.20, datetime('now')),
|
||||
('anti-pattern', 0.20, datetime('now')),
|
||||
('architecture', 0.15, datetime('now')),
|
||||
('performance', 0.15, datetime('now')),
|
||||
('error-handling', 0.15, datetime('now')),
|
||||
('pattern', 0.10, datetime('now')),
|
||||
('testing', 0.10, datetime('now')),
|
||||
('debugging', 0.10, datetime('now')),
|
||||
('workflow', 0.05, datetime('now')),
|
||||
('tooling', 0.05, datetime('now'));
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 21,
|
||||
Name: "observation_conflicts",
|
||||
SQL: `
|
||||
-- Observation conflicts table for tracking contradictions and superseded observations
|
||||
-- Implements Issue #5: Contradiction & Obsolescence Detection
|
||||
CREATE TABLE IF NOT EXISTS observation_conflicts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
newer_obs_id INTEGER NOT NULL,
|
||||
older_obs_id INTEGER NOT NULL,
|
||||
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
|
||||
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
|
||||
reason TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
detected_at_epoch INTEGER NOT NULL,
|
||||
resolved INTEGER DEFAULT 0,
|
||||
resolved_at TEXT,
|
||||
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Index for looking up conflicts by observation ID
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_newer ON observation_conflicts(newer_obs_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_older ON observation_conflicts(older_obs_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conflicts_unresolved ON observation_conflicts(resolved, detected_at_epoch DESC);
|
||||
|
||||
-- Add is_superseded column to observations for quick filtering
|
||||
-- Set to 1 when this observation has been superseded by a newer one
|
||||
ALTER TABLE observations ADD COLUMN is_superseded INTEGER DEFAULT 0;
|
||||
|
||||
-- Index for filtering out superseded observations in queries
|
||||
CREATE INDEX IF NOT EXISTS idx_observations_superseded ON observations(is_superseded, importance_score DESC);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 22,
|
||||
Name: "patterns_table",
|
||||
SQL: `
|
||||
-- Pattern Recognition Engine (Issue #7)
|
||||
-- Tracks recurring patterns detected across observations
|
||||
-- Enables Claude to reference historical insights: "I've encountered this pattern 12 times."
|
||||
CREATE TABLE IF NOT EXISTS patterns (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
|
||||
description TEXT,
|
||||
signature TEXT, -- JSON array of keywords/concepts for detection
|
||||
recommendation TEXT, -- What works for this pattern
|
||||
frequency INTEGER DEFAULT 1, -- How many times encountered
|
||||
projects TEXT, -- JSON array of projects where seen
|
||||
observation_ids TEXT, -- JSON array of source observation IDs
|
||||
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
|
||||
merged_into_id INTEGER, -- If status is 'merged', which pattern it merged into
|
||||
confidence REAL DEFAULT 0.5, -- Detection confidence (0.0-1.0)
|
||||
last_seen_at TEXT NOT NULL,
|
||||
last_seen_at_epoch INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
-- Indexes for efficient pattern queries
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(type);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_status ON patterns(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_confidence ON patterns(confidence DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns_last_seen ON patterns(last_seen_at_epoch DESC);
|
||||
|
||||
-- FTS5 virtual table for pattern search
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS patterns_fts USING fts5(
|
||||
name, description, recommendation,
|
||||
content='patterns',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- Triggers for FTS5 sync
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_ai AFTER INSERT ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_ad AFTER DELETE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS patterns_au AFTER UPDATE ON patterns BEGIN
|
||||
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||
END;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 23,
|
||||
Name: "observation_relations",
|
||||
SQL: `
|
||||
-- Knowledge Graph: Observation Relations (Issue #4)
|
||||
-- Tracks explicit relationships between observations for knowledge graph navigation.
|
||||
-- Enables queries like "What caused this bug?" or "What depends on this decision?"
|
||||
CREATE TABLE IF NOT EXISTS observation_relations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_id INTEGER NOT NULL,
|
||||
target_id INTEGER NOT NULL,
|
||||
relation_type TEXT NOT NULL CHECK(relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from')),
|
||||
confidence REAL NOT NULL DEFAULT 0.5,
|
||||
detection_source TEXT NOT NULL CHECK(detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression')),
|
||||
reason TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(source_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(target_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
UNIQUE(source_id, target_id, relation_type)
|
||||
);
|
||||
|
||||
-- Index for finding relations by source observation
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_source ON observation_relations(source_id);
|
||||
|
||||
-- Index for finding relations by target observation
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_target ON observation_relations(target_id);
|
||||
|
||||
-- Index for relation type queries
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_type ON observation_relations(relation_type);
|
||||
|
||||
-- Index for confidence-based filtering
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_confidence ON observation_relations(confidence DESC);
|
||||
|
||||
-- Composite index for looking up the relation between a specific source/target pair
|
||||
CREATE INDEX IF NOT EXISTS idx_relations_both ON observation_relations(source_id, target_id);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationManager handles database schema migrations.
|
||||
|
||||
@@ -11,14 +11,27 @@ import (
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// observationColumns is the standard list of columns to select for observations.
|
||||
// This ensures consistency across all observation queries and includes importance scoring fields.
|
||||
const observationColumns = `id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
|
||||
title, subtitle, facts, narrative, concepts, files_read, files_modified, file_mtimes,
|
||||
prompt_number, discovery_tokens, created_at, created_at_epoch,
|
||||
COALESCE(importance_score, 1.0) as importance_score,
|
||||
COALESCE(user_feedback, 0) as user_feedback,
|
||||
COALESCE(retrieval_count, 0) as retrieval_count,
|
||||
last_retrieved_at_epoch, score_updated_at_epoch,
|
||||
COALESCE(is_superseded, 0) as is_superseded`
|
||||
|
||||
// CleanupFunc is a callback for when observations are cleaned up.
|
||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// ObservationStore provides observation-related database operations.
|
||||
type ObservationStore struct {
|
||||
-	store       *Store
-	cleanupFunc CleanupFunc
+	store         *Store
+	cleanupFunc   CleanupFunc
+	conflictStore *ConflictStore
+	relationStore *RelationStore
|
||||
}
|
||||
|
||||
// NewObservationStore creates a new observation store.
|
||||
@@ -31,6 +44,16 @@ func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// SetConflictStore sets the conflict store for conflict detection.
|
||||
func (s *ObservationStore) SetConflictStore(conflictStore *ConflictStore) {
|
||||
s.conflictStore = conflictStore
|
||||
}
|
||||
|
||||
// SetRelationStore sets the relation store for relationship detection.
|
||||
func (s *ObservationStore) SetRelationStore(relationStore *RelationStore) {
|
||||
s.relationStore = relationStore
|
||||
}
|
||||
|
||||
// StoreObservation stores a new observation.
|
||||
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||
now := time.Now()
|
||||
@@ -86,9 +109,112 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
||||
}(project)
|
||||
}
|
||||
|
||||
// Detect conflicts with existing observations (async to not block handler)
|
||||
if s.conflictStore != nil && project != "" {
|
||||
go func(newObsID int64, proj string, parsedObs *models.ParsedObservation) {
|
||||
conflictCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
s.detectAndStoreConflicts(conflictCtx, newObsID, proj, parsedObs)
|
||||
}(id, project, obs)
|
||||
}
|
||||
|
||||
// Detect relationships with existing observations (async to not block handler)
|
||||
if s.relationStore != nil && project != "" {
|
||||
go func(newObsID int64, proj string) {
|
||||
relationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
s.detectAndStoreRelations(relationCtx, newObsID, proj)
|
||||
}(id, project)
|
||||
}
|
||||
|
||||
return id, nowEpoch, nil
|
||||
}
|
||||
|
||||
// detectAndStoreConflicts detects conflicts between a new observation and existing ones.
|
||||
func (s *ObservationStore) detectAndStoreConflicts(ctx context.Context, newObsID int64, project string, parsedObs *models.ParsedObservation) {
|
||||
// Fetch the newly stored observation
|
||||
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||
if err != nil || newObs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch recent observations from the same project to check for conflicts
|
||||
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Detect conflicts
|
||||
conflicts := models.DetectConflictsWithExisting(newObs, existing)
|
||||
|
||||
// Store conflicts and mark superseded observations
|
||||
var supersededIDs []int64
|
||||
for _, result := range conflicts {
|
||||
for _, olderID := range result.OlderObsIDs {
|
||||
conflict := models.NewObservationConflict(
|
||||
newObsID, olderID,
|
||||
result.Type, result.Resolution, result.Reason,
|
||||
)
|
||||
if _, err := s.conflictStore.StoreConflict(ctx, conflict); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If resolution is to prefer newer, mark older as superseded
|
||||
if result.Resolution == models.ResolutionPreferNewer {
|
||||
supersededIDs = append(supersededIDs, olderID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark superseded observations
|
||||
if len(supersededIDs) > 0 {
|
||||
_ = s.conflictStore.MarkObservationsSuperseded(ctx, supersededIDs)
|
||||
}
|
||||
|
||||
// Clean up old superseded observations (older than 3 days)
|
||||
deletedIDs, _ := s.conflictStore.CleanupSupersededObservations(ctx, project)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, deletedIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// MinRelationConfidence is the minimum confidence threshold for storing relations.
|
||||
const MinRelationConfidence = 0.4
|
||||
|
||||
// detectAndStoreRelations detects relationships between a new observation and existing ones.
|
||||
func (s *ObservationStore) detectAndStoreRelations(ctx context.Context, newObsID int64, project string) {
|
||||
// Fetch the newly stored observation
|
||||
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||
if err != nil || newObs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch recent observations from the same project to check for relations
|
||||
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Detect relationships using the models package detection logic
|
||||
results := models.DetectRelationsWithExisting(newObs, existing, MinRelationConfidence)
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert detection results to relation objects
|
||||
relations := make([]*models.ObservationRelation, len(results))
|
||||
for i, r := range results {
|
||||
relations[i] = models.NewObservationRelation(
|
||||
r.SourceID, r.TargetID,
|
||||
r.RelationType, r.Confidence,
|
||||
r.DetectionSource, r.Reason,
|
||||
)
|
||||
}
|
||||
|
||||
// Store all relations
|
||||
_ = s.relationStore.StoreRelations(ctx, relations)
|
||||
}
|
||||
|
||||
// ensureSessionExists creates a session if it doesn't exist.
|
||||
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
|
||||
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
|
||||
@@ -96,13 +222,7 @@ func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID
|
||||
|
||||
// GetObservationByID retrieves an observation by ID.
|
||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||
-	const query = `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
-		       concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
-		       created_at, created_at_epoch
-		FROM observations
-		WHERE id = ?
-	`
+	query := `SELECT ` + observationColumns + ` FROM observations WHERE id = ?`
|
||||
|
||||
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -112,6 +232,7 @@ func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*m
|
||||
}
|
||||
|
||||
// GetObservationsByIDs retrieves observations by a list of IDs.
|
||||
// Results are ordered by importance_score DESC by default, with created_at_epoch as secondary sort.
|
||||
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
@@ -119,18 +240,22 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
-	query := `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
-		       concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
-		       created_at, created_at_epoch
+	query := `SELECT ` + observationColumns + `
 		FROM observations
 		WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
-		ORDER BY created_at_epoch `
+		ORDER BY `
 
-	if orderBy == "date_asc" {
-		query += "ASC"
-	} else {
-		query += "DESC"
+	// Default to importance-based ordering
+	switch orderBy {
+	case "date_asc":
+		query += "created_at_epoch ASC"
+	case "date_desc":
+		query += "created_at_epoch DESC"
+	case "importance":
+		query += "importance_score DESC, created_at_epoch DESC"
+	default:
+		// Default: importance first, then recency
+		query += "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
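The `switch` in this hunk only ever appends fixed strings to the query, so the caller-supplied `orderBy` value never reaches the SQL text, while the IDs travel as bound parameters. A self-contained sketch of the same idea follows. `orderClause`, `placeholders` and `buildQuery` are hypothetical helpers written for this example (the repository's own `repeatPlaceholders` is not shown in the diff), and the column list is shortened to `id`:

```go
package main

import (
	"fmt"
	"strings"
)

// orderClause maps a caller-supplied key to a fixed ORDER BY fragment.
// Unknown keys fall back to importance-first ordering, so arbitrary input
// can never be spliced into the query.
func orderClause(orderBy string) string {
	switch orderBy {
	case "date_asc":
		return "created_at_epoch ASC"
	case "date_desc":
		return "created_at_epoch DESC"
	default:
		return "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
	}
}

// placeholders returns n comma-separated "?" markers (hypothetical helper).
func placeholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
}

// buildQuery assembles the SQL text and the matching argument list.
func buildQuery(ids []int64, orderBy string, limit int) (string, []any) {
	args := make([]any, 0, len(ids)+1)
	for _, id := range ids {
		args = append(args, id)
	}
	q := "SELECT id FROM observations WHERE id IN (" + placeholders(len(ids)) +
		") ORDER BY " + orderClause(orderBy)
	if limit > 0 {
		q += " LIMIT ?"
		args = append(args, limit)
	}
	return q, args
}

func main() {
	q, args := buildQuery([]int64{3, 7, 9}, "date_asc", 2)
	fmt.Println(q)
	fmt.Println(len(args), "bound arguments")
}
```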
@@ -154,14 +279,56 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
|
||||
// GetRecentObservations retrieves recent observations for a project.
|
||||
// This includes project-scoped observations for the specified project AND global observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
-	const query = `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
-		       concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
-		       created_at, created_at_epoch
+	query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
|
||||
OR scope = 'global'
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||
// This excludes observations that have been marked as superseded by newer ones.
|
||||
// Use this for context injection where you want to avoid outdated/contradicted advice.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE ((project = ? AND (scope IS NULL OR scope = 'project'))
|
||||
OR scope = 'global')
|
||||
AND COALESCE(is_superseded, 0) = 0
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetSupersededObservations retrieves observations that have been superseded by newer ones.
|
||||
// Use this for verification/debugging to see which observations were marked as outdated.
|
||||
// Results are ordered by created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE project = ?
|
||||
AND COALESCE(is_superseded, 0) = 1
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
@@ -178,14 +345,12 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
|
||||
// GetObservationsByProjectStrict retrieves observations strictly for a specific project.
|
||||
// Unlike GetRecentObservations, this does NOT include global observations from other projects.
|
||||
// Use this for dashboard filtering where the user expects to see only that project's data.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
-	const query = `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
-		       concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
-		       created_at, created_at_epoch
+	query := `SELECT ` + observationColumns + `
 		FROM observations
 		WHERE project = ?
-		ORDER BY created_at_epoch DESC
+		ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
@@ -210,13 +375,11 @@ func (s *ObservationStore) GetObservationCount(ctx context.Context, project stri
|
||||
}
|
||||
|
||||
// GetAllRecentObservations retrieves recent observations across all projects.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
|
||||
-	const query = `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
-		       concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
-		       created_at, created_at_epoch
+	query := `SELECT ` + observationColumns + `
 		FROM observations
-		ORDER BY created_at_epoch DESC
+		ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
@@ -229,7 +392,24 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations.
|
||||
// Results are ordered by FTS rank (relevance), then by importance_score.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
@@ -245,15 +425,21 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
||||
ftsTerms := strings.Join(keywords, " OR ")
|
||||
|
||||
// Use FTS5 to search title, subtitle, and narrative
|
||||
-	const ftsQuery = `
+	// Include importance scoring columns and order by rank then importance
+	ftsQuery := `
 		SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
 		       o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
-		       o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch
+		       o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
+		       COALESCE(o.importance_score, 1.0) as importance_score,
+		       COALESCE(o.user_feedback, 0) as user_feedback,
+		       COALESCE(o.retrieval_count, 0) as retrieval_count,
+		       o.last_retrieved_at_epoch, o.score_updated_at_epoch,
+		       COALESCE(o.is_superseded, 0) as is_superseded
 		FROM observations o
 		JOIN observations_fts fts ON o.id = fts.rowid
 		WHERE observations_fts MATCH ?
 		  AND (o.project = ? OR o.scope = 'global')
-		ORDER BY rank
+		ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
@@ -278,6 +464,7 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
||||
}
|
||||
|
||||
// searchObservationsLike performs fallback LIKE search on observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
@@ -294,14 +481,11 @@ func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords
|
||||
}
|
||||
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
-	query := `
-		SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
-		       title, subtitle, facts, narrative, concepts, files_read, files_modified,
-		       file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch
+	query := `SELECT ` + observationColumns + `
 		FROM observations
 		WHERE (` + strings.Join(conditions, " OR ") + `)
 		  AND (project = ? OR scope = 'global')
-		ORDER BY created_at_epoch DESC
+		ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
@@ -445,6 +629,11 @@ func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.O
|
||||
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
|
||||
&obs.PromptNumber, &obs.DiscoveryTokens,
|
||||
&obs.CreatedAt, &obs.CreatedAtEpoch,
|
||||
// Importance scoring fields
|
||||
&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
|
||||
&obs.LastRetrievedAt, &obs.ScoreUpdatedAt,
|
||||
// Conflict detection fields
|
||||
&obs.IsSuperseded,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,370 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// patternColumns is the standard list of columns to select for patterns.
|
||||
const patternColumns = `id, name, type, description, signature, recommendation,
|
||||
frequency, projects, observation_ids, status, merged_into_id, confidence,
|
||||
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch`
|
||||
|
||||
// PatternCleanupFunc is a callback for when patterns are deleted.
|
||||
type PatternCleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
|
||||
// PatternStore provides pattern-related database operations.
|
||||
type PatternStore struct {
|
||||
store *Store
|
||||
cleanupFunc PatternCleanupFunc
|
||||
}
|
||||
|
||||
// NewPatternStore creates a new pattern store.
|
||||
func NewPatternStore(store *Store) *PatternStore {
|
||||
return &PatternStore{store: store}
|
||||
}
|
||||
|
||||
// SetCleanupFunc sets the callback for when patterns are deleted.
|
||||
func (s *PatternStore) SetCleanupFunc(fn PatternCleanupFunc) {
|
||||
s.cleanupFunc = fn
|
||||
}
|
||||
|
||||
// StorePattern stores a new pattern.
|
||||
func (s *PatternStore) StorePattern(ctx context.Context, pattern *models.Pattern) (int64, error) {
|
||||
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||
|
||||
const query = `
|
||||
INSERT INTO patterns
|
||||
(name, type, description, signature, recommendation, frequency, projects,
|
||||
observation_ids, status, merged_into_id, confidence,
|
||||
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := s.store.ExecContext(ctx, query,
|
||||
pattern.Name, string(pattern.Type),
|
||||
nullString(pattern.Description.String), string(signatureJSON),
|
||||
nullString(pattern.Recommendation.String),
|
||||
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||
pattern.CreatedAt, pattern.CreatedAtEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// UpdatePattern updates an existing pattern.
|
||||
func (s *PatternStore) UpdatePattern(ctx context.Context, pattern *models.Pattern) error {
|
||||
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||
|
||||
const query = `
|
||||
UPDATE patterns SET
|
||||
name = ?, type = ?, description = ?, signature = ?, recommendation = ?,
|
||||
frequency = ?, projects = ?, observation_ids = ?, status = ?,
|
||||
merged_into_id = ?, confidence = ?, last_seen_at = ?, last_seen_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := s.store.ExecContext(ctx, query,
|
||||
pattern.Name, string(pattern.Type),
|
||||
nullString(pattern.Description.String), string(signatureJSON),
|
||||
nullString(pattern.Recommendation.String),
|
||||
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||
pattern.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPatternByID retrieves a pattern by ID.
|
||||
func (s *PatternStore) GetPatternByID(ctx context.Context, id int64) (*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + ` FROM patterns WHERE id = ?`
|
||||
|
||||
row := s.store.QueryRowContext(ctx, query, id)
|
||||
return scanPattern(row)
|
||||
}
|
||||
|
||||
// GetPatternByName retrieves a pattern by name.
|
||||
func (s *PatternStore) GetPatternByName(ctx context.Context, name string) (*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + ` FROM patterns WHERE name = ? AND status = 'active'`
|
||||
|
||||
row := s.store.QueryRowContext(ctx, query, name)
|
||||
pattern, err := scanPattern(row)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return pattern, err
|
||||
}
|
||||
|
||||
// GetActivePatterns retrieves all active patterns.
|
||||
func (s *PatternStore) GetActivePatterns(ctx context.Context, limit int) ([]*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
WHERE status = 'active'
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
}
|
||||
|
||||
// GetPatternsByType retrieves patterns of a specific type.
|
||||
func (s *PatternStore) GetPatternsByType(ctx context.Context, patternType models.PatternType, limit int) ([]*models.Pattern, error) {
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
WHERE type = ? AND status = 'active'
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, string(patternType), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
}
|
||||
|
||||
// GetPatternsByProject retrieves patterns that have been observed in a specific project.
|
||||
func (s *PatternStore) GetPatternsByProject(ctx context.Context, project string, limit int) ([]*models.Pattern, error) {
|
||||
// Use json_each to search within the projects JSON array
|
||||
query := `SELECT ` + patternColumns + `
|
||||
FROM patterns
|
||||
WHERE status = 'active'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM json_each(projects)
|
||||
WHERE json_each.value = ?
|
||||
)
|
||||
ORDER BY frequency DESC, confidence DESC
|
||||
LIMIT ?`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
}
|
||||
|
||||
// FindMatchingPatterns searches for patterns that match a given signature.
|
||||
func (s *PatternStore) FindMatchingPatterns(ctx context.Context, signature []string, minScore float64) ([]*models.Pattern, error) {
|
||||
	// Fetch up to 100 active patterns (highest frequency first) and filter by signature match in Go
|
||||
// This is simpler than complex SQL for JSON array matching
|
||||
patterns, err := s.GetActivePatterns(ctx, 100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var matches []*models.Pattern
|
||||
for _, pattern := range patterns {
|
||||
score := models.CalculateMatchScore(signature, pattern.Signature)
|
||||
if score >= minScore {
|
||||
matches = append(matches, pattern)
|
||||
}
|
||||
}
|
||||
return matches, nil
|
||||
}
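
The body of models.CalculateMatchScore is not part of this diff. As a rough illustration of how a signature match score in [0, 1] could be computed, here is a minimal sketch that assumes a Jaccard overlap between the two keyword sets. The name jaccardScore and the choice of Jaccard are assumptions made for illustration, not the project's actual scoring function.

```go
package main

import "fmt"

// jaccardScore returns |a ∩ b| / |a ∪ b| over the distinct elements of a and b.
// Hypothetical stand-in for illustration; not models.CalculateMatchScore.
func jaccardScore(a, b []string) float64 {
	setA := make(map[string]struct{}, len(a))
	for _, s := range a {
		setA[s] = struct{}{}
	}
	setB := make(map[string]struct{}, len(b))
	for _, s := range b {
		setB[s] = struct{}{}
	}
	// Two empty signatures share nothing; avoid dividing by zero.
	if len(setA) == 0 && len(setB) == 0 {
		return 0
	}

	shared := 0
	for s := range setA {
		if _, ok := setB[s]; ok {
			shared++
		}
	}
	union := len(setA) + len(setB) - shared
	return float64(shared) / float64(union)
}

func main() {
	query := []string{"nil", "error"}
	// Two of three distinct keywords overlap.
	fmt.Println(jaccardScore(query, []string{"nil", "error", "handling"}))
	// No keywords overlap.
	fmt.Println(jaccardScore(query, []string{"refactor", "extract", "method"}))
}
```

Under this assumed scoring, a caller such as FindMatchingPatterns would keep the first candidate at a threshold of 0.3 and drop the second.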
|
||||
|
||||
// MarkPatternDeprecated marks a pattern as deprecated.
|
||||
func (s *PatternStore) MarkPatternDeprecated(ctx context.Context, id int64) error {
|
||||
const query = `UPDATE patterns SET status = 'deprecated' WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
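// MergePatterns merges a source pattern into a target pattern.
// The target and source updates are issued separately rather than in one transaction,
// so a failure on the second update leaves the target merged while the source stays unchanged.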
|
||||
func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int64) error {
|
||||
// Get both patterns
|
||||
source, err := s.GetPatternByID(ctx, sourceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target, err := s.GetPatternByID(ctx, targetID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Merge source into target
|
||||
target.Frequency += source.Frequency
|
||||
for _, proj := range source.Projects {
|
||||
found := false
|
||||
for _, existing := range target.Projects {
|
||||
if existing == proj {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
target.Projects = append(target.Projects, proj)
|
||||
}
|
||||
}
|
||||
for _, obsID := range source.ObservationIDs {
|
||||
found := false
|
||||
for _, existing := range target.ObservationIDs {
|
||||
if existing == obsID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
target.ObservationIDs = append(target.ObservationIDs, obsID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update target
|
||||
if err := s.UpdatePattern(ctx, target); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark source as merged
|
||||
source.Status = models.PatternStatusMerged
|
||||
source.MergedIntoID = sql.NullInt64{Int64: targetID, Valid: true}
|
||||
return s.UpdatePattern(ctx, source)
|
||||
}
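
The two nested loops in MergePatterns deduplicate projects and observation IDs by scanning the target slice for every source element. A possible alternative, sketched here under the assumption that a small generic helper is acceptable in this package, tracks seen values in a map. The name mergeUnique is hypothetical and does not exist in the project.

```go
package main

import "fmt"

// mergeUnique appends to dst every element of src that dst does not already
// contain, preserving first-seen order. Hypothetical helper for illustration.
func mergeUnique[T comparable](dst, src []T) []T {
	seen := make(map[T]struct{}, len(dst)+len(src))
	for _, v := range dst {
		seen[v] = struct{}{}
	}
	for _, v := range src {
		if _, ok := seen[v]; !ok {
			seen[v] = struct{}{}
			dst = append(dst, v)
		}
	}
	return dst
}

func main() {
	// Same inputs as TestPatternStore_MergePatterns: target first, then source.
	projects := mergeUnique([]string{"proj2", "proj3"}, []string{"proj1", "proj2"})
	fmt.Println(projects) // [proj2 proj3 proj1]

	ids := mergeUnique([]int64{4, 5}, []int64{1, 2, 3})
	fmt.Println(len(ids)) // 5
}
```

The result matches the loop-based version for these inputs; the difference only matters for cost when the slices grow large.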
|
||||
|
||||
// DeletePattern deletes a pattern by ID and, if the delete succeeds, invokes the cleanup callback when one is set.
|
||||
func (s *PatternStore) DeletePattern(ctx context.Context, id int64) error {
|
||||
const query = `DELETE FROM patterns WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, id)
|
||||
if err == nil && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, []int64{id})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SearchPatternsFTS performs full-text search on patterns.
|
||||
func (s *PatternStore) SearchPatternsFTS(ctx context.Context, searchQuery string, limit int) ([]*models.Pattern, error) {
|
||||
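	// Note: "p." qualifies only the first name in patternColumns; the remaining
	// columns are unqualified and must not be ambiguous with patterns_fts.
	query := `SELECT p.` + patternColumns + `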
|
||||
FROM patterns p
|
||||
JOIN patterns_fts fts ON p.id = fts.rowid
|
||||
WHERE patterns_fts MATCH ?
|
||||
AND p.status = 'active'
|
||||
ORDER BY rank
|
||||
LIMIT ?`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, searchQuery, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPatternRows(rows)
|
||||
}
|
||||
|
||||
// GetPatternStats returns statistics about patterns.
|
||||
func (s *PatternStore) GetPatternStats(ctx context.Context) (*PatternStats, error) {
|
||||
const query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN status = 'active' THEN 1 END) as active,
|
||||
COUNT(CASE WHEN status = 'deprecated' THEN 1 END) as deprecated,
|
||||
COUNT(CASE WHEN status = 'merged' THEN 1 END) as merged,
|
||||
COALESCE(SUM(frequency), 0) as total_occurrences,
|
||||
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||
COUNT(CASE WHEN type = 'bug' THEN 1 END) as bugs,
|
||||
COUNT(CASE WHEN type = 'refactor' THEN 1 END) as refactors,
|
||||
COUNT(CASE WHEN type = 'architecture' THEN 1 END) as architectures,
|
||||
COUNT(CASE WHEN type = 'anti-pattern' THEN 1 END) as anti_patterns,
|
||||
COUNT(CASE WHEN type = 'best-practice' THEN 1 END) as best_practices
|
||||
FROM patterns
|
||||
`
|
||||
|
||||
var stats PatternStats
|
||||
err := s.store.QueryRowContext(ctx, query).Scan(
|
||||
&stats.Total, &stats.Active, &stats.Deprecated, &stats.Merged,
|
||||
&stats.TotalOccurrences, &stats.AvgConfidence,
|
||||
&stats.Bugs, &stats.Refactors, &stats.Architectures,
|
||||
&stats.AntiPatterns, &stats.BestPractices,
|
||||
)
|
||||
return &stats, err
|
||||
}
|
||||
|
||||
// PatternStats contains aggregate statistics about patterns.
|
||||
type PatternStats struct {
|
||||
Total int `json:"total"`
|
||||
Active int `json:"active"`
|
||||
Deprecated int `json:"deprecated"`
|
||||
Merged int `json:"merged"`
|
||||
TotalOccurrences int `json:"total_occurrences"`
|
||||
AvgConfidence float64 `json:"avg_confidence"`
|
||||
Bugs int `json:"bugs"`
|
||||
Refactors int `json:"refactors"`
|
||||
Architectures int `json:"architectures"`
|
||||
AntiPatterns int `json:"anti_patterns"`
|
||||
BestPractices int `json:"best_practices"`
|
||||
}
|
||||
|
||||
// scanPattern scans a single pattern from a row scanner.
|
||||
func scanPattern(scanner interface{ Scan(...interface{}) error }) (*models.Pattern, error) {
|
||||
var pattern models.Pattern
|
||||
if err := scanner.Scan(
|
||||
&pattern.ID, &pattern.Name, &pattern.Type,
|
||||
&pattern.Description, &pattern.Signature, &pattern.Recommendation,
|
||||
&pattern.Frequency, &pattern.Projects, &pattern.ObservationIDs,
|
||||
&pattern.Status, &pattern.MergedIntoID, &pattern.Confidence,
|
||||
&pattern.LastSeenAt, &pattern.LastSeenEpoch,
|
||||
&pattern.CreatedAt, &pattern.CreatedAtEpoch,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pattern, nil
|
||||
}
|
||||
|
||||
// scanPatternRows scans multiple patterns from rows.
|
||||
func scanPatternRows(rows *sql.Rows) ([]*models.Pattern, error) {
|
||||
var patterns []*models.Pattern
|
||||
for rows.Next() {
|
||||
pattern, err := scanPattern(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
patterns = append(patterns, pattern)
|
||||
}
|
||||
return patterns, rows.Err()
|
||||
}
|
||||
|
||||
// nullInt64 returns the int64 value of n when it is valid, or nil so that the driver stores NULL.
|
||||
func nullInt64(n sql.NullInt64) interface{} {
|
||||
if n.Valid {
|
||||
return n.Int64
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
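// IncrementPatternFrequency records a new occurrence of a pattern and updates last_seen.
// It reads the pattern, modifies it in memory, and writes it back, so it is not atomic under concurrent writers.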
|
||||
func (s *PatternStore) IncrementPatternFrequency(ctx context.Context, id int64, project string, observationID int64) error {
|
||||
now := time.Now()
|
||||
|
||||
// Get current pattern
|
||||
pattern, err := s.GetPatternByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add occurrence
|
||||
pattern.AddOccurrence(project, observationID)
|
||||
pattern.LastSeenAt = now.Format(time.RFC3339)
|
||||
pattern.LastSeenEpoch = now.UnixMilli()
|
||||
|
||||
return s.UpdatePattern(ctx, pattern)
|
||||
}
|
||||
@@ -0,0 +1,507 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// setupPatternTestStore creates a test store with patterns table.
|
||||
func setupPatternTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
db, _, cleanup := testDB(t)
|
||||
t.Cleanup(cleanup)
|
||||
createBaseTables(t, db)
|
||||
return newStoreFromDB(db)
|
||||
}
|
||||
|
||||
func TestPatternStore_StoreAndGet(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test pattern
|
||||
pattern := &models.Pattern{
|
||||
Name: "Test Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "A test pattern", Valid: true},
|
||||
Signature: []string{"nil", "error"},
|
||||
Recommendation: sql.NullString{String: "Always check for nil", Valid: true},
|
||||
Frequency: 1,
|
||||
Projects: []string{"project1"},
|
||||
ObservationIDs: []int64{1, 2},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||
LastSeenEpoch: time.Now().UnixMilli(),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// Store pattern
|
||||
id, err := patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("StorePattern() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Errorf("Expected positive ID, got %d", id)
|
||||
}
|
||||
|
||||
// Get pattern by ID
|
||||
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByID() error = %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != pattern.Name {
|
||||
t.Errorf("Expected name %s, got %s", pattern.Name, retrieved.Name)
|
||||
}
|
||||
if retrieved.Type != pattern.Type {
|
||||
t.Errorf("Expected type %s, got %s", pattern.Type, retrieved.Type)
|
||||
}
|
||||
if len(retrieved.Signature) != len(pattern.Signature) {
|
||||
t.Errorf("Expected %d signature elements, got %d",
|
||||
len(pattern.Signature), len(retrieved.Signature))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetByName(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
pattern := createTestPattern("Unique Name Pattern")
|
||||
_, err := patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("StorePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Get by name
|
||||
retrieved, err := patternStore.GetPatternByName(ctx, "Unique Name Pattern")
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByName() error = %v", err)
|
||||
}
|
||||
if retrieved == nil {
|
||||
t.Fatal("Expected pattern, got nil")
|
||||
}
|
||||
if retrieved.Name != "Unique Name Pattern" {
|
||||
t.Errorf("Expected name 'Unique Name Pattern', got '%s'", retrieved.Name)
|
||||
}
|
||||
|
||||
// Get non-existent pattern
|
||||
nonExistent, err := patternStore.GetPatternByName(ctx, "Non Existent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternByName() error = %v", err)
|
||||
}
|
||||
if nonExistent != nil {
|
||||
t.Errorf("Expected nil for non-existent pattern, got %v", nonExistent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetActivePatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple patterns with different statuses
|
||||
active1 := createTestPattern("Active 1")
|
||||
active1.Frequency = 5
|
||||
active2 := createTestPattern("Active 2")
|
||||
active2.Frequency = 3
|
||||
deprecated := createTestPattern("Deprecated")
|
||||
deprecated.Status = models.PatternStatusDeprecated
|
||||
|
||||
patternStore.StorePattern(ctx, active1)
|
||||
patternStore.StorePattern(ctx, active2)
|
||||
patternStore.StorePattern(ctx, deprecated)
|
||||
|
||||
// Get active patterns
|
||||
patterns, err := patternStore.GetActivePatterns(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetActivePatterns() error = %v", err)
|
||||
}
|
||||
|
||||
if len(patterns) != 2 {
|
||||
t.Errorf("Expected 2 active patterns, got %d", len(patterns))
|
||||
}
|
||||
|
||||
// Check order (should be by frequency descending)
|
||||
if len(patterns) >= 2 {
|
||||
if patterns[0].Frequency < patterns[1].Frequency {
|
||||
t.Errorf("Patterns not ordered by frequency descending")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternsByType(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns of different types
|
||||
bugPattern := createTestPattern("Bug Pattern")
|
||||
bugPattern.Type = models.PatternTypeBug
|
||||
|
||||
refactorPattern := createTestPattern("Refactor Pattern")
|
||||
refactorPattern.Type = models.PatternTypeRefactor
|
||||
|
||||
patternStore.StorePattern(ctx, bugPattern)
|
||||
patternStore.StorePattern(ctx, refactorPattern)
|
||||
|
||||
// Get by type
|
||||
bugs, err := patternStore.GetPatternsByType(ctx, models.PatternTypeBug, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||
}
|
||||
if len(bugs) != 1 {
|
||||
t.Errorf("Expected 1 bug pattern, got %d", len(bugs))
|
||||
}
|
||||
|
||||
refactors, err := patternStore.GetPatternsByType(ctx, models.PatternTypeRefactor, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||
}
|
||||
if len(refactors) != 1 {
|
||||
t.Errorf("Expected 1 refactor pattern, got %d", len(refactors))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternsByProject(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with different projects
|
||||
pattern1 := createTestPattern("Pattern 1")
|
||||
pattern1.Projects = []string{"project-a", "project-b"}
|
||||
|
||||
pattern2 := createTestPattern("Pattern 2")
|
||||
pattern2.Projects = []string{"project-b", "project-c"}
|
||||
|
||||
patternStore.StorePattern(ctx, pattern1)
|
||||
patternStore.StorePattern(ctx, pattern2)
|
||||
|
||||
// Get by project
|
||||
projectA, err := patternStore.GetPatternsByProject(ctx, "project-a", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||
}
|
||||
if len(projectA) != 1 {
|
||||
t.Errorf("Expected 1 pattern for project-a, got %d", len(projectA))
|
||||
}
|
||||
|
||||
projectB, err := patternStore.GetPatternsByProject(ctx, "project-b", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||
}
|
||||
if len(projectB) != 2 {
|
||||
t.Errorf("Expected 2 patterns for project-b, got %d", len(projectB))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_UpdatePattern(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("Original Name")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Update pattern
|
||||
pattern.ID = id
|
||||
pattern.Name = "Updated Name"
|
||||
pattern.Frequency = 10
|
||||
pattern.Confidence = 0.9
|
||||
|
||||
err := patternStore.UpdatePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdatePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if updated.Name != "Updated Name" {
|
||||
t.Errorf("Expected name 'Updated Name', got '%s'", updated.Name)
|
||||
}
|
||||
if updated.Frequency != 10 {
|
||||
t.Errorf("Expected frequency 10, got %d", updated.Frequency)
|
||||
}
|
||||
if updated.Confidence != 0.9 {
|
||||
t.Errorf("Expected confidence 0.9, got %f", updated.Confidence)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_DeletePattern(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("To Delete")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Delete pattern
|
||||
err := patternStore.DeletePattern(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("DeletePattern() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
deleted, err := patternStore.GetPatternByID(ctx, id)
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("Expected ErrNoRows, got %v", err)
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Errorf("Expected nil for deleted pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_MarkPatternDeprecated(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and store pattern
|
||||
pattern := createTestPattern("To Deprecate")
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Mark as deprecated
|
||||
err := patternStore.MarkPatternDeprecated(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkPatternDeprecated() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify status
|
||||
deprecated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if deprecated.Status != models.PatternStatusDeprecated {
|
||||
t.Errorf("Expected status 'deprecated', got '%s'", deprecated.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_MergePatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create source and target patterns
|
||||
source := createTestPattern("Source Pattern")
|
||||
source.Frequency = 3
|
||||
source.Projects = []string{"proj1", "proj2"}
|
||||
source.ObservationIDs = []int64{1, 2, 3}
|
||||
|
||||
target := createTestPattern("Target Pattern")
|
||||
target.Frequency = 2
|
||||
target.Projects = []string{"proj2", "proj3"}
|
||||
target.ObservationIDs = []int64{4, 5}
|
||||
|
||||
sourceID, _ := patternStore.StorePattern(ctx, source)
|
||||
targetID, _ := patternStore.StorePattern(ctx, target)
|
||||
|
||||
// Merge
|
||||
err := patternStore.MergePatterns(ctx, sourceID, targetID)
|
||||
if err != nil {
|
||||
t.Fatalf("MergePatterns() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify source is marked as merged
|
||||
mergedSource, _ := patternStore.GetPatternByID(ctx, sourceID)
|
||||
if mergedSource.Status != models.PatternStatusMerged {
|
||||
t.Errorf("Expected source status 'merged', got '%s'", mergedSource.Status)
|
||||
}
|
||||
if !mergedSource.MergedIntoID.Valid || mergedSource.MergedIntoID.Int64 != targetID {
|
||||
t.Errorf("Expected source merged_into_id to be %d", targetID)
|
||||
}
|
||||
|
||||
// Verify target has combined data
|
||||
mergedTarget, _ := patternStore.GetPatternByID(ctx, targetID)
|
||||
expectedFrequency := 5 // 3 + 2
|
||||
if mergedTarget.Frequency != expectedFrequency {
|
||||
t.Errorf("Expected merged frequency %d, got %d", expectedFrequency, mergedTarget.Frequency)
|
||||
}
|
||||
// Should have 3 unique projects: proj1, proj2, proj3
|
||||
if len(mergedTarget.Projects) != 3 {
|
||||
t.Errorf("Expected 3 projects after merge, got %d", len(mergedTarget.Projects))
|
||||
}
|
||||
// Should have 5 observation IDs
|
||||
if len(mergedTarget.ObservationIDs) != 5 {
|
||||
t.Errorf("Expected 5 observation IDs after merge, got %d", len(mergedTarget.ObservationIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_FindMatchingPatterns(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with known signatures
|
||||
pattern1 := createTestPattern("Pattern 1")
|
||||
pattern1.Signature = []string{"nil", "error", "handling"}
|
||||
|
||||
pattern2 := createTestPattern("Pattern 2")
|
||||
pattern2.Signature = []string{"nil", "pointer", "check"}
|
||||
|
||||
pattern3 := createTestPattern("Pattern 3")
|
||||
pattern3.Signature = []string{"refactor", "extract", "method"}
|
||||
|
||||
patternStore.StorePattern(ctx, pattern1)
|
||||
patternStore.StorePattern(ctx, pattern2)
|
||||
patternStore.StorePattern(ctx, pattern3)
|
||||
|
||||
// Find patterns matching "nil" related signature
|
||||
matches, err := patternStore.FindMatchingPatterns(ctx, []string{"nil", "error"}, 0.3)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||
}
|
||||
|
||||
if len(matches) < 1 {
|
||||
t.Errorf("Expected at least 1 match, got %d", len(matches))
|
||||
}
|
||||
|
||||
// Verify no match for unrelated signature
|
||||
noMatches, err := patternStore.FindMatchingPatterns(ctx, []string{"completely", "different"}, 0.5)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||
}
|
||||
if len(noMatches) != 0 {
|
||||
t.Errorf("Expected 0 matches for unrelated signature, got %d", len(noMatches))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_IncrementPatternFrequency(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create pattern
|
||||
pattern := createTestPattern("Frequency Test")
|
||||
pattern.Frequency = 1
|
||||
pattern.Projects = []string{"proj1"}
|
||||
pattern.ObservationIDs = []int64{1}
|
||||
|
||||
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Increment frequency
|
||||
err := patternStore.IncrementPatternFrequency(ctx, id, "proj2", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("IncrementPatternFrequency() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||
if updated.Frequency != 2 {
|
||||
t.Errorf("Expected frequency 2, got %d", updated.Frequency)
|
||||
}
|
||||
if len(updated.Projects) != 2 {
|
||||
t.Errorf("Expected 2 projects, got %d", len(updated.Projects))
|
||||
}
|
||||
if len(updated.ObservationIDs) != 2 {
|
||||
t.Errorf("Expected 2 observation IDs, got %d", len(updated.ObservationIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_GetPatternStats(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create patterns with different types and statuses
|
||||
bug := createTestPattern("Bug")
|
||||
bug.Type = models.PatternTypeBug
|
||||
bug.Frequency = 5
|
||||
|
||||
refactor := createTestPattern("Refactor")
|
||||
refactor.Type = models.PatternTypeRefactor
|
||||
refactor.Frequency = 3
|
||||
|
||||
deprecated := createTestPattern("Deprecated")
|
||||
deprecated.Type = models.PatternTypeArchitecture
|
||||
deprecated.Status = models.PatternStatusDeprecated
|
||||
|
||||
patternStore.StorePattern(ctx, bug)
|
||||
patternStore.StorePattern(ctx, refactor)
|
||||
patternStore.StorePattern(ctx, deprecated)
|
||||
|
||||
// Get stats
|
||||
stats, err := patternStore.GetPatternStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPatternStats() error = %v", err)
|
||||
}
|
||||
|
||||
if stats.Total != 3 {
|
||||
t.Errorf("Expected total 3, got %d", stats.Total)
|
||||
}
|
||||
if stats.Active != 2 {
|
||||
t.Errorf("Expected 2 active, got %d", stats.Active)
|
||||
}
|
||||
if stats.Deprecated != 1 {
|
||||
t.Errorf("Expected 1 deprecated, got %d", stats.Deprecated)
|
||||
}
|
||||
if stats.Bugs != 1 {
|
||||
t.Errorf("Expected 1 bug, got %d", stats.Bugs)
|
||||
}
|
||||
if stats.Refactors != 1 {
|
||||
t.Errorf("Expected 1 refactor, got %d", stats.Refactors)
|
||||
}
|
||||
if stats.TotalOccurrences != 9 { // 5 + 3 + 1
|
||||
t.Errorf("Expected 9 total occurrences, got %d", stats.TotalOccurrences)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternStore_CleanupCallback(t *testing.T) {
|
||||
store := setupPatternTestStore(t)
|
||||
|
||||
patternStore := NewPatternStore(store)
|
||||
ctx := context.Background()
|
||||
|
||||
var deletedIDs []int64
|
||||
patternStore.SetCleanupFunc(func(ctx context.Context, ids []int64) {
|
||||
deletedIDs = ids
|
||||
})
|
||||
|
||||
// Create and delete pattern
|
||||
pattern := createTestPattern("Cleanup Test")
|
||||
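	id, err := patternStore.StorePattern(ctx, pattern)
	if err != nil {
		t.Fatalf("StorePattern() error = %v", err)
	}

	if err := patternStore.DeletePattern(ctx, id); err != nil {
		t.Fatalf("DeletePattern() error = %v", err)
	}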
|
||||
|
||||
if len(deletedIDs) != 1 || deletedIDs[0] != id {
|
||||
t.Errorf("Expected cleanup callback with ID %d, got %v", id, deletedIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a test pattern
|
||||
func createTestPattern(name string) *models.Pattern {
|
||||
now := time.Now()
|
||||
return &models.Pattern{
|
||||
Name: name,
|
||||
Type: models.PatternTypeBug,
|
||||
Description: sql.NullString{String: "Test description", Valid: true},
|
||||
Signature: []string{"test", "pattern"},
|
||||
Recommendation: sql.NullString{String: "Test recommendation", Valid: true},
|
||||
Frequency: 1,
|
||||
Projects: []string{"test-project"},
|
||||
ObservationIDs: []int64{1},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.5,
|
||||
LastSeenAt: now.Format(time.RFC3339),
|
||||
LastSeenEpoch: now.UnixMilli(),
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
}
|
||||
@@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||
const query = `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
ORDER BY up.id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
|
||||
// This is used to detect duplicate hook invocations.
|
||||
// Returns (promptID, promptNumber, found).
|
||||
|
||||
@@ -0,0 +1,377 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// RelationStore provides relation-related database operations.
|
||||
type RelationStore struct {
|
||||
store *Store
|
||||
}
|
||||
|
||||
// NewRelationStore creates a new relation store.
|
||||
func NewRelationStore(store *Store) *RelationStore {
|
||||
return &RelationStore{store: store}
|
||||
}
|
||||
|
||||
// StoreRelation stores a new observation relation.
// Uses INSERT OR IGNORE to handle duplicate (source_id, target_id, relation_type) combinations.
// It returns 0 with a nil error when the relation already exists and the insert is ignored.
func (s *RelationStore) StoreRelation(ctx context.Context, relation *models.ObservationRelation) (int64, error) {
	const query = `
		INSERT OR IGNORE INTO observation_relations
		(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
	`

	result, err := s.store.ExecContext(ctx, query,
		relation.SourceID, relation.TargetID,
		string(relation.RelationType), relation.Confidence,
		string(relation.DetectionSource), relation.Reason,
		relation.CreatedAt, relation.CreatedAtEpoch,
	)
	if err != nil {
		return 0, err
	}

	// An ignored duplicate affects zero rows; LastInsertId would then report the
	// rowid of an earlier insert on the same connection, so return 0 instead.
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	if affected == 0 {
		return 0, nil
	}

	return result.LastInsertId()
|
||||
}
|
||||
|
||||
// StoreRelations stores multiple relations in a single transaction.
|
||||
func (s *RelationStore) StoreRelations(ctx context.Context, relations []*models.ObservationRelation) error {
|
||||
if len(relations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
const query = `
|
||||
INSERT OR IGNORE INTO observation_relations
|
||||
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, rel := range relations {
|
||||
_, err = stmt.ExecContext(ctx,
|
||||
rel.SourceID, rel.TargetID,
|
||||
string(rel.RelationType), rel.Confidence,
|
||||
string(rel.DetectionSource), rel.Reason,
|
||||
rel.CreatedAt, rel.CreatedAtEpoch,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetRelationsByObservationID retrieves all relations involving an observation (as source or target).
|
||||
func (s *RelationStore) GetRelationsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE source_id = ? OR target_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetOutgoingRelations retrieves relations where the observation is the source.
|
||||
func (s *RelationStore) GetOutgoingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE source_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetIncomingRelations retrieves relations where the observation is the target.
|
||||
func (s *RelationStore) GetIncomingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE target_id = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetRelationsByType retrieves all relations of a specific type.
|
||||
func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType models.RelationType, limit int) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE relation_type = ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, string(relationType), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// GetRelationsWithDetails retrieves relations with observation titles for display.
|
||||
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
|
||||
const query = `
|
||||
SELECT r.id, r.source_id, r.target_id, r.relation_type, r.confidence, r.detection_source, r.reason,
|
||||
r.created_at, r.created_at_epoch,
|
||||
COALESCE(src.title, '') as source_title,
|
||||
COALESCE(tgt.title, '') as target_title,
|
||||
src.type as source_type,
|
||||
tgt.type as target_type
|
||||
FROM observation_relations r
|
||||
JOIN observations src ON src.id = r.source_id
|
||||
JOIN observations tgt ON tgt.id = r.target_id
|
||||
WHERE r.source_id = ? OR r.target_id = ?
|
||||
ORDER BY r.confidence DESC, r.created_at_epoch DESC
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*models.RelationWithDetails
|
||||
for rows.Next() {
|
||||
var r models.ObservationRelation
|
||||
var rwd models.RelationWithDetails
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.SourceID, &r.TargetID,
|
||||
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||
&rwd.SourceTitle, &rwd.TargetTitle,
|
||||
&rwd.SourceType, &rwd.TargetType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason.Valid {
|
||||
r.Reason = reason.String
|
||||
}
|
||||
rwd.Relation = &r
|
||||
results = append(results, &rwd)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// GetRelationGraph retrieves a relation graph centered on an observation.
|
||||
// It returns every relation reachable within maxDepth hops of the center observation.
|
||||
func (s *RelationStore) GetRelationGraph(ctx context.Context, centerID int64, maxDepth int) (*models.RelationGraph, error) {
|
||||
// Get all relations involving the center observation
|
||||
relations, err := s.GetRelationsWithDetails(ctx, centerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
graph := &models.RelationGraph{
|
||||
CenterID: centerID,
|
||||
Relations: relations,
|
||||
}
|
||||
|
||||
// If depth > 1, recursively get relations for connected observations
|
||||
if maxDepth > 1 {
|
||||
visited := map[int64]bool{centerID: true}
|
||||
toVisit := make([]int64, 0)
|
||||
|
||||
// Collect IDs of directly connected observations
|
||||
for _, r := range relations {
|
||||
if !visited[r.Relation.SourceID] {
|
||||
toVisit = append(toVisit, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
toVisit = append(toVisit, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Get relations for connected observations (depth - 1)
|
||||
for depth := 1; depth < maxDepth && len(toVisit) > 0; depth++ {
|
||||
nextLevel := make([]int64, 0)
|
||||
for _, obsID := range toVisit {
|
||||
moreRelations, err := s.GetRelationsWithDetails(ctx, obsID)
|
||||
if err != nil {
|
||||
continue // best-effort: skip observations whose relations fail to load
|
||||
}
|
||||
for _, r := range moreRelations {
|
||||
// Avoid duplicates
|
||||
exists := false
|
||||
for _, existing := range graph.Relations {
|
||||
if existing.Relation.ID == r.Relation.ID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
graph.Relations = append(graph.Relations, r)
|
||||
}
|
||||
|
||||
// Queue next level
|
||||
if !visited[r.Relation.SourceID] {
|
||||
nextLevel = append(nextLevel, r.Relation.SourceID)
|
||||
visited[r.Relation.SourceID] = true
|
||||
}
|
||||
if !visited[r.Relation.TargetID] {
|
||||
nextLevel = append(nextLevel, r.Relation.TargetID)
|
||||
visited[r.Relation.TargetID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
toVisit = nextLevel
|
||||
}
|
||||
}
|
||||
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
// DeleteRelationsByObservationID deletes all relations involving an observation.
|
||||
// Called when an observation is deleted.
|
||||
func (s *RelationStore) DeleteRelationsByObservationID(ctx context.Context, obsID int64) error {
|
||||
const query = `DELETE FROM observation_relations WHERE source_id = ? OR target_id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRelationCount returns the count of relations for an observation.
|
||||
func (s *RelationStore) GetRelationCount(ctx context.Context, obsID int64) (int, error) {
|
||||
const query = `
|
||||
SELECT COUNT(*) FROM observation_relations
|
||||
WHERE source_id = ? OR target_id = ?
|
||||
`
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query, obsID, obsID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetTotalRelationCount returns the total count of all relations.
|
||||
func (s *RelationStore) GetTotalRelationCount(ctx context.Context) (int, error) {
|
||||
const query = `SELECT COUNT(*) FROM observation_relations`
|
||||
var count int
|
||||
err := s.store.QueryRowContext(ctx, query).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetHighConfidenceRelations retrieves relations with confidence above threshold.
|
||||
func (s *RelationStore) GetHighConfidenceRelations(ctx context.Context, minConfidence float64, limit int) ([]*models.ObservationRelation, error) {
|
||||
const query = `
|
||||
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||
created_at, created_at_epoch
|
||||
FROM observation_relations
|
||||
WHERE confidence >= ?
|
||||
ORDER BY confidence DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, minConfidence, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanRelationRows(rows)
|
||||
}
|
||||
|
||||
// UpdateRelationConfidence updates the confidence of a relation.
|
||||
func (s *RelationStore) UpdateRelationConfidence(ctx context.Context, relationID int64, newConfidence float64) error {
|
||||
const query = `UPDATE observation_relations SET confidence = ? WHERE id = ?`
|
||||
_, err := s.store.ExecContext(ctx, query, newConfidence, relationID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRelatedObservationIDs returns IDs of observations related to the given one.
|
||||
// This is useful for expanding search results.
|
||||
func (s *RelationStore) GetRelatedObservationIDs(ctx context.Context, obsID int64, minConfidence float64) ([]int64, error) {
|
||||
const query = `
|
||||
SELECT DISTINCT CASE WHEN source_id = ? THEN target_id ELSE source_id END as related_id
|
||||
FROM observation_relations
|
||||
WHERE (source_id = ? OR target_id = ?) AND confidence >= ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, obsID, obsID, obsID, minConfidence)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
|
||||
// scanRelationRows scans multiple relations from rows.
|
||||
func (s *RelationStore) scanRelationRows(rows *sql.Rows) ([]*models.ObservationRelation, error) {
|
||||
var relations []*models.ObservationRelation
|
||||
for rows.Next() {
|
||||
var r models.ObservationRelation
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.SourceID, &r.TargetID,
|
||||
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason.Valid {
|
||||
r.Reason = reason.String
|
||||
}
|
||||
relations = append(relations, &r)
|
||||
}
|
||||
return relations, rows.Err()
|
||||
}
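Review note on GetRelationGraph above: the traversal is a breadth-first expansion with a visited set, but it checks for duplicate relations by scanning graph.Relations linearly for every candidate. The sketch below shows the same level-by-level walk with a second map keyed by relation ID, so the duplicate check is a constant-time lookup. It is an illustration under stated assumptions, not the project's code: `Edge`, `neighbors`, and `collect` are hypothetical in-memory stand-ins for the relation rows and for GetRelationsWithDetails.

```go
package main

import "fmt"

// Edge is a hypothetical stand-in for one observation_relations row.
type Edge struct {
	ID, Source, Target int64
}

// neighbors is a hypothetical in-memory replacement for the database
// lookup: it returns every edge that touches the given observation ID.
func neighbors(edges []Edge, id int64) []Edge {
	var out []Edge
	for _, e := range edges {
		if e.Source == id || e.Target == id {
			out = append(out, e)
		}
	}
	return out
}

// collect walks the graph breadth-first from center and returns each edge
// found within maxDepth hops exactly once, in discovery order.
func collect(edges []Edge, center int64, maxDepth int) []Edge {
	if maxDepth < 1 {
		maxDepth = 1 // mirror the original: the center's own relations are always returned
	}
	visited := map[int64]bool{center: true} // observation IDs already queued
	seen := map[int64]bool{}                // relation IDs already collected
	var result []Edge
	frontier := []int64{center}

	for depth := 0; depth < maxDepth && len(frontier) > 0; depth++ {
		var next []int64
		for _, id := range frontier {
			for _, e := range neighbors(edges, id) {
				if !seen[e.ID] {
					seen[e.ID] = true
					result = append(result, e)
				}
				// Queue both endpoints; the visited map filters out known nodes.
				for _, n := range [2]int64{e.Source, e.Target} {
					if !visited[n] {
						visited[n] = true
						next = append(next, n)
					}
				}
			}
		}
		frontier = next
	}
	return result
}

func main() {
	edges := []Edge{{10, 1, 2}, {11, 2, 3}, {12, 3, 4}}
	for depth := 1; depth <= 3; depth++ {
		fmt.Println(depth, len(collect(edges, 1, depth)))
	}
}
```

With a chain 1-2-3-4, each extra hop adds one edge to the result. The second map trades a little memory for avoiding the quadratic scan, which only matters once graphs grow beyond a handful of relations.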
|
||||
@@ -0,0 +1,324 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// UpdateObservationFeedback updates the user feedback for an observation.
|
||||
// Feedback values: -1 (thumbs down), 0 (neutral), 1 (thumbs up).
|
||||
func (s *ObservationStore) UpdateObservationFeedback(ctx context.Context, id int64, feedback int) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET user_feedback = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, feedback, time.Now().UnixMilli(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// IncrementRetrievalCount increments the retrieval counter for the given observation IDs.
|
||||
// This is called when observations are returned in search results.
|
||||
func (s *ObservationStore) IncrementRetrievalCount(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Build query with placeholders
|
||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||
query := `
|
||||
UPDATE observations
|
||||
SET retrieval_count = COALESCE(retrieval_count, 0) + 1,
|
||||
last_retrieved_at_epoch = ?
|
||||
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||
`
|
||||
|
||||
args := make([]interface{}, 0, len(ids)+1)
|
||||
args = append(args, now)
|
||||
for _, id := range ids {
|
||||
args = append(args, id)
|
||||
}
|
||||
|
||||
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateImportanceScore updates the importance score for a single observation.
|
||||
func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, score float64) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET importance_score = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, score, time.Now().UnixMilli(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||
// This is more efficient than individual updates for batch recalculation.
|
||||
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
if len(scores) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
UPDATE observations
|
||||
SET importance_score = ?, score_updated_at_epoch = ?
|
||||
WHERE id = ?
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for id, score := range scores {
|
||||
if _, err := stmt.ExecContext(ctx, score, now, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
cutoff := time.Now().Add(-threshold).UnixMilli()
|
||||
|
||||
query := `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE score_updated_at_epoch IS NULL OR score_updated_at_epoch < ?
|
||||
ORDER BY created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, cutoff, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetConceptWeights returns all concept weights from the database.
|
||||
func (s *ObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||
const query = `SELECT concept, weight FROM concept_weights`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
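// Note: QueryContext does not return sql.ErrNoRows (only Row.Scan does), so this
// guard does not catch a missing table in older databases; that error falls
// through to the return below.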
|
||||
if err == sql.ErrNoRows {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
weights := make(map[string]float64)
|
||||
for rows.Next() {
|
||||
var concept string
|
||||
var weight float64
|
||||
if err := rows.Scan(&concept, &weight); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
weights[concept] = weight
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no weights found, use defaults
|
||||
if len(weights) == 0 {
|
||||
return models.DefaultConceptWeights, nil
|
||||
}
|
||||
|
||||
return weights, nil
|
||||
}
|
||||
|
||||
// UpdateConceptWeight updates a single concept weight.
|
||||
func (s *ObservationStore) UpdateConceptWeight(ctx context.Context, concept string, weight float64) error {
|
||||
const query = `
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query, concept, weight)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateConceptWeights bulk updates multiple concept weights.
|
||||
func (s *ObservationStore) UpdateConceptWeights(ctx context.Context, weights map[string]float64) error {
|
||||
if len(weights) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||
VALUES (?, ?, datetime('now'))
|
||||
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for concept, weight := range weights {
|
||||
if _, err := stmt.ExecContext(ctx, concept, weight); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetObservationFeedbackStats returns statistics about user feedback.
|
||||
func (s *ObservationStore) GetObservationFeedbackStats(ctx context.Context, project string) (*FeedbackStats, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||
FROM observations
|
||||
`
|
||||
} else {
|
||||
query = `
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||
FROM observations
|
||||
WHERE project = ? OR scope = 'global'
|
||||
`
|
||||
args = append(args, project)
|
||||
}
|
||||
|
||||
var stats FeedbackStats
|
||||
err := s.store.QueryRowContext(ctx, query, args...).Scan(
|
||||
&stats.Total,
|
||||
&stats.Positive,
|
||||
&stats.Negative,
|
||||
&stats.Neutral,
|
||||
&stats.AvgScore,
|
||||
&stats.AvgRetrieval,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// FeedbackStats contains statistics about observation feedback and scoring.
|
||||
type FeedbackStats struct {
|
||||
Total int `json:"total"`
|
||||
Positive int `json:"positive"`
|
||||
Negative int `json:"negative"`
|
||||
Neutral int `json:"neutral"`
|
||||
AvgScore float64 `json:"avg_score"`
|
||||
AvgRetrieval float64 `json:"avg_retrieval"`
|
||||
}
|
||||
|
||||
// GetTopScoringObservations returns the highest-scoring observations.
|
||||
func (s *ObservationStore) GetTopScoringObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
} else {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE project = ? OR scope = 'global'
|
||||
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetMostRetrievedObservations returns the most frequently retrieved observations.
|
||||
func (s *ObservationStore) GetMostRetrievedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if project == "" {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE retrieval_count > 0
|
||||
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, limit)
|
||||
} else {
|
||||
query = `SELECT ` + observationColumns + `
|
||||
FROM observations
|
||||
WHERE (project = ? OR scope = 'global') AND retrieval_count > 0
|
||||
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||
LIMIT ?
|
||||
`
|
||||
args = append(args, project, limit)
|
||||
}
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// ResetObservationScores resets all observation scores to their default values.
|
||||
// This is useful for testing or when changing the scoring algorithm.
|
||||
func (s *ObservationStore) ResetObservationScores(ctx context.Context) error {
|
||||
const query = `
|
||||
UPDATE observations
|
||||
SET importance_score = 1.0, score_updated_at_epoch = NULL
|
||||
`
|
||||
_, err := s.store.ExecContext(ctx, query)
|
||||
return err
|
||||
}
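Review note on IncrementRetrievalCount above: it builds its IN clause with a repeatPlaceholders helper that is not part of this diff. The sketch below shows how such a query and its argument slice line up, assuming the helper returns ", ?" repeated n times. Both `repeatPlaceholders` as written here and `buildIncrementQuery` are hypothetical reconstructions for illustration; the real helper may differ.

```go
package main

import (
	"fmt"
	"strings"
)

// repeatPlaceholders is an assumed implementation: it returns ", ?"
// repeated n times, or an empty string when n is zero or negative.
func repeatPlaceholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat(", ?", n)
}

// buildIncrementQuery (hypothetical) returns the UPDATE statement and the
// matching arguments: the timestamp first, then one argument per ID.
// Callers are expected to skip the call when ids is empty.
func buildIncrementQuery(now int64, ids []int64) (string, []interface{}) {
	query := "UPDATE observations" +
		" SET retrieval_count = COALESCE(retrieval_count, 0) + 1, last_retrieved_at_epoch = ?" +
		" WHERE id IN (?" + repeatPlaceholders(len(ids)-1) + ")"

	args := make([]interface{}, 0, len(ids)+1)
	args = append(args, now)
	for _, id := range ids {
		args = append(args, id)
	}
	return query, args
}

func main() {
	q, args := buildIncrementQuery(1700000000000, []int64{7, 8, 9})
	fmt.Println(q)
	fmt.Println(len(args))
}
```

The number of `?` markers must equal the length of the argument slice, which is why the first ID placeholder is written inline and only len(ids)-1 more are appended. SQLite also caps the number of bound parameters per statement (999 by default before version 3.32), so very large ID batches may need to be split into chunks.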
|
||||
@@ -0,0 +1,698 @@
|
||||
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// testScoringObservationStore creates an ObservationStore with scoring columns for testing.
|
||||
func testScoringObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, _, cleanup := testDB(t)
|
||||
createBaseTables(t, db)
|
||||
createConceptWeightsTable(t, db)
|
||||
|
||||
// Add importance index if not exists (columns already in createBaseTables)
|
||||
if _, err := db.Exec(`CREATE INDEX IF NOT EXISTS idx_observations_importance ON observations(importance_score DESC, created_at_epoch DESC)`); err != nil {
|
||||
t.Fatalf("create importance index: %v", err)
|
||||
}
|
||||
|
||||
store := newStoreFromDB(db)
|
||||
obsStore := NewObservationStore(store)
|
||||
|
||||
return obsStore, store, cleanup
|
||||
}
|
||||
|
||||
// createConceptWeightsTable creates the concept_weights table for testing.
|
||||
func createConceptWeightsTable(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||
concept TEXT PRIMARY KEY,
|
||||
weight REAL NOT NULL DEFAULT 0.1,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create concept_weights: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ScoringStoreSuite is a test suite for scoring-related database operations.
|
||||
type ScoringStoreSuite struct {
|
||||
suite.Suite
|
||||
obsStore *ObservationStore
|
||||
store *Store
|
||||
cleanup func()
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) SetupTest() {
|
||||
s.obsStore, s.store, s.cleanup = testScoringObservationStore(s.T())
|
||||
s.ctx = context.Background()
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoringStoreSuite(t *testing.T) {
|
||||
suite.Run(t, new(ScoringStoreSuite))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FEEDBACK TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Positive() {
|
||||
// Create observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Test feedback",
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Update feedback to positive
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.UserFeedback)
|
||||
s.True(retrieved.ScoreUpdatedAt.Valid)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Negative() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(-1, retrieved.UserFeedback)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Neutral() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// First set to positive
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
s.NoError(err)
|
||||
|
||||
// Then reset to neutral
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 0)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(0, retrieved.UserFeedback)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_NonExistent() {
|
||||
// Updating non-existent observation should not fail (just no rows affected)
|
||||
err := s.obsStore.UpdateObservationFeedback(s.ctx, 99999, 1)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RETRIEVAL COUNT TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Single() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.RetrievalCount)
|
||||
s.True(retrieved.LastRetrievedAt.Valid)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Multiple() {
|
||||
var ids []int64
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
err := s.obsStore.IncrementRetrievalCount(s.ctx, ids)
|
||||
s.NoError(err)
|
||||
|
||||
for _, id := range ids {
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(1, retrieved.RetrievalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Cumulative() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Increment multiple times
|
||||
for i := 0; i < 5; i++ {
|
||||
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.Equal(5, retrieved.RetrievalCount)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Empty() {
|
||||
err := s.obsStore.IncrementRetrievalCount(s.ctx, []int64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// IMPORTANCE SCORE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScore_Single() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(1.5, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Batch() {
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
scores := map[int64]float64{
|
||||
ids[0]: 1.5,
|
||||
ids[1]: 0.8,
|
||||
ids[2]: 1.2,
|
||||
ids[3]: 0.5,
|
||||
ids[4]: 2.0,
|
||||
}
|
||||
|
||||
err := s.obsStore.UpdateImportanceScores(s.ctx, scores)
|
||||
s.NoError(err)
|
||||
|
||||
for id, expectedScore := range scores {
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(expectedScore, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Empty() {
|
||||
err := s.obsStore.UpdateImportanceScores(s.ctx, map[int64]float64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// OBSERVATIONS NEEDING SCORE UPDATE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_NeverUpdated() {
|
||||
// Observations without score_updated_at_epoch should need update
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||
s.NoError(err)
|
||||
s.Len(observations, 1)
|
||||
s.Equal(id, observations[0].ID)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_RecentlyUpdated() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Update score (this sets score_updated_at_epoch)
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||
s.NoError(err)
|
||||
|
||||
// Should not need update (just updated)
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||
s.NoError(err)
|
||||
s.Empty(observations)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_Limit() {
|
||||
// Create 10 observations
|
||||
for i := 0; i < 10; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Request only 5
|
||||
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 5)
|
||||
s.NoError(err)
|
||||
s.Len(observations, 5)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONCEPT WEIGHTS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetConceptWeights_Empty() {
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(models.DefaultConceptWeights, weights)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeight_NewConcept() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "new-concept", 0.42)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.42, weights["new-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeight_UpdateExisting() {
|
||||
// Insert first
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.1)
|
||||
s.NoError(err)
|
||||
|
||||
// Update
|
||||
err = s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.9)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.9, weights["test-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Batch() {
|
||||
weightsToSet := map[string]float64{
|
||||
"security": 0.5,
|
||||
"performance": 0.3,
|
||||
"testing": 0.2,
|
||||
}
|
||||
|
||||
err := s.obsStore.UpdateConceptWeights(s.ctx, weightsToSet)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
|
||||
for concept, expected := range weightsToSet {
|
||||
s.Equal(expected, retrieved[concept])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Empty() {
|
||||
err := s.obsStore.UpdateConceptWeights(s.ctx, map[string]float64{})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FEEDBACK STATS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_Empty() {
|
||||
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||
s.NoError(err)
|
||||
s.Equal(0, stats.Total)
|
||||
s.Equal(0, stats.Positive)
|
||||
s.Equal(0, stats.Negative)
|
||||
s.Equal(0, stats.Neutral)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_WithData() {
|
||||
// Create observations with different feedback
|
||||
feedbacks := []int{1, 1, 1, -1, -1, 0, 0, 0, 0, 0}
|
||||
for i, fb := range feedbacks {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
if fb != 0 {
|
||||
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, fb)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||
s.NoError(err)
|
||||
s.Equal(10, stats.Total)
|
||||
s.Equal(3, stats.Positive)
|
||||
s.Equal(2, stats.Negative)
|
||||
s.Equal(5, stats.Neutral)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_ByProject() {
|
||||
// Project A observations
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||
}
|
||||
|
||||
// Project B observations
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeFeature,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
|
||||
}
|
||||
|
||||
// Check project A stats
|
||||
statsA, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-a")
|
||||
s.NoError(err)
|
||||
s.Equal(5, statsA.Total)
|
||||
s.Equal(5, statsA.Positive)
|
||||
|
||||
// Check project B stats
|
||||
statsB, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-b")
|
||||
s.NoError(err)
|
||||
s.Equal(3, statsB.Total)
|
||||
s.Equal(3, statsB.Negative)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TOP SCORING OBSERVATIONS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetTopScoringObservations() {
|
||||
// Create observations with different scores
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
// Set different scores
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1)*0.5)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Get top 3
|
||||
top, err := s.obsStore.GetTopScoringObservations(s.ctx, "", 3)
|
||||
s.NoError(err)
|
||||
s.Len(top, 3)
|
||||
|
||||
// Verify ordered by score descending
|
||||
s.GreaterOrEqual(top[0].ImportanceScore, top[1].ImportanceScore)
|
||||
s.GreaterOrEqual(top[1].ImportanceScore, top[2].ImportanceScore)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetTopScoringObservations_ByProject() {
|
||||
// Project A with high scores
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 2.0)
|
||||
}
|
||||
|
||||
// Project B with low scores
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.5)
|
||||
}
|
||||
|
||||
// Get top for project A
|
||||
topA, err := s.obsStore.GetTopScoringObservations(s.ctx, "project-a", 10)
|
||||
s.NoError(err)
|
||||
s.Len(topA, 3)
|
||||
for _, obs := range topA {
|
||||
s.Equal("project-a", obs.Project)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MOST RETRIEVED OBSERVATIONS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations() {
|
||||
// Create observations with different retrieval counts
|
||||
var ids []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Set different retrieval counts
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[0]}) // 10 retrievals
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[1]}) // 5 retrievals
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[2]}) // 3 retrievals
|
||||
}
|
||||
// ids[3] and ids[4] have 0 retrievals
|
||||
|
||||
// Get top 3
|
||||
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 3)
|
||||
s.NoError(err)
|
||||
s.Len(most, 3)
|
||||
|
||||
// Verify ordered by retrieval count descending
|
||||
s.Equal(10, most[0].RetrievalCount)
|
||||
s.Equal(5, most[1].RetrievalCount)
|
||||
s.Equal(3, most[2].RetrievalCount)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations_NoRetrievals() {
|
||||
// Create observations without any retrievals
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 10)
|
||||
s.NoError(err)
|
||||
s.Empty(most) // No observations with retrieval_count > 0
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RESET OBSERVATION SCORES TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestResetObservationScores() {
|
||||
// Create observations with various scores
|
||||
for i := 0; i < 5; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||
s.NoError(err)
|
||||
_ = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1))
|
||||
}
|
||||
|
||||
// Reset all scores
|
||||
err := s.obsStore.ResetObservationScores(s.ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify all scores are reset to 1.0
|
||||
observations, err := s.obsStore.GetAllRecentObservations(s.ctx, 100)
|
||||
s.NoError(err)
|
||||
for _, obs := range observations {
|
||||
s.InDelta(1.0, obs.ImportanceScore, 0.001)
|
||||
s.False(obs.ScoreUpdatedAt.Valid)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EDGE CASES
|
||||
// =============================================================================
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_ZeroScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set score to 0
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.0)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(0.0, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_NegativeScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set negative score (calculator shouldn't produce this, but test DB handling)
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, -0.5)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(-0.5, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestScoring_LargeScore() {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||
s.NoError(err)
|
||||
|
||||
// Set very large score
|
||||
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 999.999)
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||
s.NoError(err)
|
||||
s.InDelta(999.999, retrieved.ImportanceScore, 0.001)
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestConceptWeight_ZeroWeight() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "zero-concept", 0.0)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0.0, weights["zero-concept"])
|
||||
}
|
||||
|
||||
func (s *ScoringStoreSuite) TestConceptWeight_ExactBoundary() {
|
||||
err := s.obsStore.UpdateConceptWeight(s.ctx, "max-concept", 1.0)
|
||||
s.NoError(err)
|
||||
|
||||
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(1.0, weights["max-concept"])
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STANDALONE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestFeedbackStats_Structure(t *testing.T) {
|
||||
stats := FeedbackStats{
|
||||
Total: 100,
|
||||
Positive: 30,
|
||||
Negative: 10,
|
||||
Neutral: 60,
|
||||
AvgScore: 1.5,
|
||||
AvgRetrieval: 5.0,
|
||||
}
|
||||
|
||||
assert.Equal(t, 100, stats.Total)
|
||||
assert.Equal(t, 30, stats.Positive)
|
||||
assert.Equal(t, 10, stats.Negative)
|
||||
assert.Equal(t, 60, stats.Neutral)
|
||||
assert.Equal(t, 1.5, stats.AvgScore)
|
||||
assert.Equal(t, 5.0, stats.AvgRetrieval)
|
||||
}
|
||||
|
||||
func TestScoringStore_Integration(t *testing.T) {
|
||||
obsStore, _, cleanup := testScoringObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Full integration test: store, feedback, retrieval, score update
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: "Integration test observation",
|
||||
Concepts: []string{"security"},
|
||||
}
|
||||
id, _, err := obsStore.StoreObservation(ctx, "session-int", "project-int", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add feedback
|
||||
err = obsStore.UpdateObservationFeedback(ctx, id, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Increment retrieval
|
||||
err = obsStore.IncrementRetrievalCount(ctx, []int64{id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update score
|
||||
err = obsStore.UpdateImportanceScore(ctx, id, 1.75)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify final state
|
||||
retrieved, err := obsStore.GetObservationByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, retrieved.UserFeedback)
|
||||
assert.Equal(t, 1, retrieved.RetrievalCount)
|
||||
assert.InDelta(t, 1.75, retrieved.ImportanceScore, 0.001)
|
||||
assert.True(t, retrieved.ScoreUpdatedAt.Valid)
|
||||
assert.True(t, retrieved.LastRetrievedAt.Valid)
|
||||
}
|
||||
@@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
// GetAllSummaries retrieves all summaries (for vector rebuild).
|
||||
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
@@ -103,6 +103,12 @@ func createBaseTables(t *testing.T, db *sql.DB) {
|
||||
discovery_tokens INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
importance_score REAL DEFAULT 1.0,
|
||||
user_feedback INTEGER DEFAULT 0,
|
||||
retrieval_count INTEGER DEFAULT 0,
|
||||
last_retrieved_at_epoch INTEGER,
|
||||
score_updated_at_epoch INTEGER,
|
||||
is_superseded INTEGER DEFAULT 0,
|
||||
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
@@ -110,6 +116,27 @@ func createBaseTables(t *testing.T, db *sql.DB) {
|
||||
t.Fatalf("create observations: %v", err)
|
||||
}
|
||||
|
||||
// Create observation_conflicts table for conflict detection
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS observation_conflicts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
newer_obs_id INTEGER NOT NULL,
|
||||
older_obs_id INTEGER NOT NULL,
|
||||
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
|
||||
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
|
||||
reason TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
detected_at_epoch INTEGER NOT NULL,
|
||||
resolved INTEGER DEFAULT 0,
|
||||
resolved_at TEXT,
|
||||
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create observation_conflicts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS session_summaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -150,6 +177,31 @@ func createBaseTables(t *testing.T, db *sql.DB) {
|
||||
t.Fatalf("create user_prompts: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS patterns (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
|
||||
description TEXT,
|
||||
signature TEXT,
|
||||
recommendation TEXT,
|
||||
frequency INTEGER DEFAULT 1,
|
||||
projects TEXT,
|
||||
observation_ids TEXT,
|
||||
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
|
||||
merged_into_id INTEGER,
|
||||
confidence REAL DEFAULT 0.5,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
last_seen_at_epoch INTEGER NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
created_at_epoch INTEGER NOT NULL,
|
||||
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("create patterns: %v", err)
|
||||
}
|
||||
|
||||
indexes := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id)`,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
bge-small-en-v1.5
|
||||
Binary file not shown.
@@ -1,21 +1,7 @@
 {
   "version": "1.0",
-  "truncation": {
-    "direction": "Right",
-    "max_length": 128,
-    "strategy": "LongestFirst",
-    "stride": 0
-  },
-  "padding": {
-    "strategy": {
-      "Fixed": 128
-    },
-    "direction": "Right",
-    "pad_to_multiple_of": null,
-    "pad_id": 0,
-    "pad_type_id": 0,
-    "pad_token": "[PAD]"
-  },
+  "truncation": null,
+  "padding": null,
   "added_tokens": [
     {
       "id": 0,
|
||||
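With `truncation` and `padding` both set to `null`, the tokenizer configuration no longer forces every encoding to a fixed 128 tokens, so the calling code presumably has to pad and truncate each batch itself. The following is a minimal sketch of that idea, not the commit's implementation. It assumes token IDs arrive as plain `[]int` slices and derives the attention mask from sequence length; `padBatch` and `maxLen` are illustrative names.

```go
package main

import "fmt"

// padBatch flattens variable-length token ID slices into one row-major
// buffer of shape [len(batch), seqLen], where seqLen is the longest
// sequence in the batch, capped at maxLen. Shorter rows keep trailing
// zeros as padding; longer rows are truncated. The returned mask holds
// 1 for real tokens and 0 for padding.
// padBatch is a hypothetical helper for illustration only.
func padBatch(batch [][]int, maxLen int) (ids, mask []int64, seqLen int) {
	for _, row := range batch {
		if len(row) > seqLen {
			seqLen = len(row)
		}
	}
	if seqLen > maxLen {
		seqLen = maxLen
	}

	// make() zero-fills both buffers, so padding positions need no extra work.
	ids = make([]int64, len(batch)*seqLen)
	mask = make([]int64, len(batch)*seqLen)
	for b, row := range batch {
		n := len(row)
		if n > seqLen {
			n = seqLen
		}
		for i := 0; i < n; i++ {
			ids[b*seqLen+i] = int64(row[i])
			mask[b*seqLen+i] = 1
		}
	}
	return ids, mask, seqLen
}

func main() {
	ids, mask, seqLen := padBatch([][]int{{101, 7592, 102}, {101, 102}}, 512)
	fmt.Println(seqLen, ids, mask) // prints 3 [101 7592 102 101 102 0] [1 1 1 1 1 0]
}
```

Padding to the longest sequence in the batch, rather than to a fixed length, keeps short batches small while the cap still protects the model from over-long inputs.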
@@ -0,0 +1,157 @@
|
||||
// Package embedding provides text embedding generation with swappable models.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolingStrategy defines how to pool token embeddings into sentence embeddings.
|
||||
type PoolingStrategy string
|
||||
|
||||
const (
|
||||
// PoolingNone means the model already outputs sentence embeddings directly.
|
||||
PoolingNone PoolingStrategy = "none"
|
||||
// PoolingMean averages all token embeddings (weighted by attention mask).
|
||||
PoolingMean PoolingStrategy = "mean"
|
||||
// PoolingCLS uses only the [CLS] token embedding.
|
||||
PoolingCLS PoolingStrategy = "cls"
|
||||
)
|
||||
|
||||
// ONNXConfig describes ONNX-specific model configuration.
|
||||
// This allows different models to specify their tensor names and pooling needs.
|
||||
type ONNXConfig struct {
|
||||
// InputNames are the ONNX input tensor names in order.
|
||||
InputNames []string
|
||||
// OutputNames are the ONNX output tensor names.
|
||||
OutputNames []string
|
||||
// Pooling specifies how to convert token embeddings to sentence embeddings.
|
||||
// If PoolingNone, the model outputs sentence embeddings directly.
|
||||
Pooling PoolingStrategy
|
||||
// HiddenSize is the embedding dimension (used for pooling calculations).
|
||||
HiddenSize int
|
||||
}
|
||||
|
||||
// EmbeddingModel represents a text embedding model.
|
||||
type EmbeddingModel interface {
|
||||
// Name returns the human-readable model name (e.g., "bge-small-en-v1.5").
|
||||
Name() string
|
||||
|
||||
// Version returns a short version string for storage (e.g., "bge-v1.5").
|
||||
Version() string
|
||||
|
||||
// Dimensions returns the embedding vector size.
|
||||
Dimensions() int
|
||||
|
||||
// Embed generates an embedding for a single text.
|
||||
Embed(text string) ([]float32, error)
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
EmbedBatch(texts []string) ([][]float32, error)
|
||||
|
||||
// Close releases model resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ONNXConfigurer is an optional interface that models can implement
|
||||
// to expose their ONNX configuration for introspection.
|
||||
type ONNXConfigurer interface {
|
||||
// ONNXConfig returns the model's ONNX configuration.
|
||||
ONNXConfig() ONNXConfig
|
||||
}
|
||||
|
||||
// ModelMetadata describes an embedding model for UI/config.
|
||||
type ModelMetadata struct {
|
||||
Name string `json:"name"` // Human-readable name
|
||||
Version string `json:"version"` // Short ID for DB storage
|
||||
Dimensions int `json:"dimensions"` // Vector size
|
||||
Description string `json:"description"` // Brief description
|
||||
Default bool `json:"default"` // Is this the default model?
|
||||
}
|
||||
|
||||
// ModelFactory creates a new instance of an embedding model.
|
||||
type ModelFactory func() (EmbeddingModel, error)
|
||||
|
||||
// ModelRegistry provides model lookup by version.
|
||||
type ModelRegistry struct {
|
||||
mu sync.RWMutex
|
||||
models map[string]ModelFactory
|
||||
metadata map[string]ModelMetadata
|
||||
defaultModel string
|
||||
}
|
||||
|
||||
// NewModelRegistry creates a new model registry.
|
||||
func NewModelRegistry() *ModelRegistry {
|
||||
return &ModelRegistry{
|
||||
models: make(map[string]ModelFactory),
|
||||
metadata: make(map[string]ModelMetadata),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a model factory to the registry.
|
||||
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.models[meta.Version] = factory
|
||||
r.metadata[meta.Version] = meta
|
||||
|
||||
if meta.Default {
|
||||
r.defaultModel = meta.Version
|
||||
}
|
||||
}
|
||||
|
||||
// Get creates a new instance of the model with the given version.
|
||||
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
|
||||
r.mu.RLock()
|
||||
factory, ok := r.models[version]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model version: %s", version)
|
||||
}
|
||||
|
||||
return factory()
|
||||
}
|
||||
|
||||
// Default returns the default model version.
|
||||
func (r *ModelRegistry) Default() string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.defaultModel
|
||||
}
|
||||
|
||||
// List returns metadata for all registered models.
|
||||
func (r *ModelRegistry) List() []ModelMetadata {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]ModelMetadata, 0, len(r.metadata))
|
||||
for _, meta := range r.metadata {
|
||||
result = append(result, meta)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DefaultRegistry is the global model registry with all available models.
|
||||
var DefaultRegistry = NewModelRegistry()
|
||||
|
||||
// RegisterModel adds a model to the default registry.
|
||||
func RegisterModel(meta ModelMetadata, factory ModelFactory) {
|
||||
DefaultRegistry.Register(meta, factory)
|
||||
}
|
||||
|
||||
// GetModel creates a model instance from the default registry.
|
||||
func GetModel(version string) (EmbeddingModel, error) {
|
||||
return DefaultRegistry.Get(version)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model version from the default registry.
|
||||
func GetDefaultModel() string {
|
||||
return DefaultRegistry.Default()
|
||||
}
|
||||
|
||||
// ListModels returns metadata for all models in the default registry.
|
||||
func ListModels() []ModelMetadata {
|
||||
return DefaultRegistry.List()
|
||||
}
|
||||
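The registry above maps a version string to a factory and remembers which version was registered as the default. A minimal usage sketch follows, assuming a trimmed copy of the interface and registry so the program compiles on its own. `stubModel`, `"stub-v1"`, and the shortened type names (`Model`, `Metadata`, `Factory`, `Registry`) are illustrative and not part of this commit.

```go
package main

import (
	"fmt"
	"sync"
)

// Model is a trimmed, hypothetical stand-in for the EmbeddingModel interface.
type Model interface {
	Version() string
	Dimensions() int
	Embed(text string) ([]float32, error)
}

// Metadata keeps only the ModelMetadata fields the lookup logic needs.
type Metadata struct {
	Name       string
	Version    string
	Dimensions int
	Default    bool
}

// Factory creates a new model instance, like ModelFactory in the diff.
type Factory func() (Model, error)

// Registry mirrors the Register/Get/Default behaviour shown in the diff.
type Registry struct {
	mu           sync.RWMutex
	models       map[string]Factory
	metadata     map[string]Metadata
	defaultModel string
}

func NewRegistry() *Registry {
	return &Registry{models: map[string]Factory{}, metadata: map[string]Metadata{}}
}

// Register stores the factory under meta.Version and records the default.
func (r *Registry) Register(meta Metadata, f Factory) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.models[meta.Version] = f
	r.metadata[meta.Version] = meta
	if meta.Default {
		r.defaultModel = meta.Version
	}
}

// Get looks up the factory under a read lock, then calls it outside the lock.
func (r *Registry) Get(version string) (Model, error) {
	r.mu.RLock()
	f, ok := r.models[version]
	r.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("unknown model version: %s", version)
	}
	return f()
}

func (r *Registry) Default() string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.defaultModel
}

// stubModel is a hypothetical model that returns zero vectors.
type stubModel struct {
	version string
	dim     int
}

func (m stubModel) Version() string                 { return m.version }
func (m stubModel) Dimensions() int                 { return m.dim }
func (m stubModel) Embed(string) ([]float32, error) { return make([]float32, m.dim), nil }

func main() {
	reg := NewRegistry()
	reg.Register(
		Metadata{Name: "stub-small", Version: "stub-v1", Dimensions: 384, Default: true},
		func() (Model, error) { return stubModel{version: "stub-v1", dim: 384}, nil },
	)

	m, err := reg.Get(reg.Default())
	if err != nil {
		panic(err)
	}
	vec, _ := m.Embed("hello")
	fmt.Println(m.Version(), len(vec)) // prints stub-v1 384

	_, err = reg.Get("missing")
	fmt.Println(err) // prints unknown model version: missing
}
```

Because `Get` calls the factory on every lookup, each call returns a fresh instance. A caller holding a model that wraps native resources would presumably cache the instance and call `Close` when finished.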
+277
-56
@@ -1,4 +1,4 @@
-// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
+// Package embedding provides text embedding generation with swappable models.
 package embedding

 import (
@@ -15,19 +15,50 @@ import (
 	ort "github.com/yalue/onnxruntime_go"
 )

-// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2.
+// EmbeddingDim is the dimension of embeddings produced by the current model.
+// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
 const EmbeddingDim = 384

-// Service provides thread-safe text embedding generation.
-type Service struct {
+// Model version constants
+const (
+	// BGEModelVersion is the version string for bge-small-en-v1.5
+	BGEModelVersion = "bge-v1.5"
+	// BGEModelName is the human-readable name for bge-small-en-v1.5
+	BGEModelName = "bge-small-en-v1.5"
+	// DefaultModelVersion is the default model to use
+	DefaultModelVersion = BGEModelVersion
+)
+
+// MaxSequenceLength is the maximum token sequence length for the model.
+const MaxSequenceLength = 512
+
+// bgeONNXConfig defines the ONNX configuration for BGE models.
+// BGE outputs last_hidden_state and requires mean pooling.
+var bgeONNXConfig = ONNXConfig{
+	InputNames:  []string{"input_ids", "attention_mask", "token_type_ids"},
+	OutputNames: []string{"last_hidden_state"},
+	Pooling:     PoolingMean,
+	HiddenSize:  EmbeddingDim,
+}
+
+// bgeModel is the ONNX-based embedding model implementation.
+// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
+type bgeModel struct {
 	tk      *tokenizer.Tokenizer
 	session *ort.DynamicAdvancedSession
 	mu      sync.Mutex
-	libDir  string // temp directory containing extracted libraries
+	libDir  string     // temp directory containing extracted libraries
+	config  ONNXConfig // ONNX configuration for this model
 }

-// NewService creates a new embedding service using bundled ONNX runtime and model.
-func NewService() (*Service, error) {
+// Compile-time check that bgeModel implements EmbeddingModel
+var _ EmbeddingModel = (*bgeModel)(nil)
+
+// Compile-time check that bgeModel implements ONNXConfigurer
+var _ ONNXConfigurer = (*bgeModel)(nil)
+
+// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
+func newBGEModel() (EmbeddingModel, error) {
 	// Extract ONNX runtime library to temp directory
 	libDir, err := extractONNXLibrary()
 	if err != nil {
@@ -49,22 +80,41 @@ func NewService() (*Service, error) {
 		return nil, fmt.Errorf("load tokenizer: %w", err)
 	}

-	// Create ONNX session with embedded model
-	inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
-	outputNames := []string{"sentence_embedding"}
-
-	session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
+	// Create ONNX session using model-specific configuration
+	config := bgeONNXConfig
+	session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
 	if err != nil {
 		return nil, fmt.Errorf("create ONNX session: %w", err)
 	}

-	return &Service{
+	return &bgeModel{
 		tk:      tk,
 		session: session,
 		libDir:  libDir,
+		config:  config,
 	}, nil
 }

+// ONNXConfig returns the model's ONNX configuration.
+func (m *bgeModel) ONNXConfig() ONNXConfig {
+	return m.config
+}
+
+// Name returns the human-readable model name.
+func (m *bgeModel) Name() string {
+	return BGEModelName
+}
+
+// Version returns the short version string for storage.
+func (m *bgeModel) Version() string {
+	return BGEModelVersion
+}
+
+// Dimensions returns the embedding vector size.
+func (m *bgeModel) Dimensions() int {
+	return EmbeddingDim
+}
+
 // extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
 // Uses content hash to avoid re-extracting if already present.
 func extractONNXLibrary() (string, error) {
@@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) {

 // Embed generates an embedding for a single text.
 // Returns a 384-dimensional float32 vector.
-func (s *Service) Embed(text string) ([]float32, error) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+func (m *bgeModel) Embed(text string) ([]float32, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()

 	if text == "" {
 		return make([]float32, EmbeddingDim), nil
 	}

-	results, err := s.computeBatch([]string{text})
+	results, err := m.computeBatch([]string{text})
 	if err != nil {
 		return nil, err
 	}
@@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) {

 // EmbedBatch generates embeddings for multiple texts.
 // Returns slice of 384-dimensional float32 vectors.
-func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
+func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
 	if len(texts) == 0 {
 		return nil, nil
 	}

-	s.mu.Lock()
-	defer s.mu.Unlock()
+	m.mu.Lock()
+	defer m.mu.Unlock()

 	// Filter out empty texts and track indices
 	nonEmpty := make([]string, 0, len(texts))
@@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
 	}

 	// Compute embeddings for non-empty texts
-	embeddings, err := s.computeBatch(nonEmpty)
+	embeddings, err := m.computeBatch(nonEmpty)
 	if err != nil {
 		return nil, fmt.Errorf("compute batch embeddings: %w", err)
 	}
@@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
 }

 // computeBatch runs inference on a batch of texts. Must be called with lock held.
-func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
+func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
 	if len(sentences) == 0 {
 		return nil, nil
 	}
@@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
 		inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
 	}

-	encodings, err := s.tk.EncodeBatch(inputBatch, true)
+	encodings, err := m.tk.EncodeBatch(inputBatch, true)
 	if err != nil {
 		return nil, fmt.Errorf("tokenize: %w", err)
 	}

 	batchSize := len(encodings)
-	seqLength := len(encodings[0].Ids)
-	hiddenSize := EmbeddingDim
+	hiddenSize := m.config.HiddenSize
+
+	// Find max sequence length across all encodings (tokenizer may not pad uniformly)
+	// Also enforce MaxSequenceLength to prevent model errors
+	seqLength := 0
+	for _, enc := range encodings {
+		if len(enc.Ids) > seqLength {
+			seqLength = len(enc.Ids)
+		}
+	}
+	// Truncate to max model sequence length
+	if seqLength > MaxSequenceLength {
+		seqLength = MaxSequenceLength
+	}

 	inputShape := ort.NewShape(int64(batchSize), int64(seqLength))

-	// Create input tensors
+	// Create input tensors (pre-filled with zeros for padding)
 	inputIdsData := make([]int64, batchSize*seqLength)
 	attentionMaskData := make([]int64, batchSize*seqLength)
 	tokenTypeIdsData := make([]int64, batchSize*seqLength)

 	for b := 0; b < batchSize; b++ {
-		for i, id := range encodings[b].Ids {
-			inputIdsData[b*seqLength+i] = int64(id)
+		// Copy actual token data (rest remains 0 as padding)
+		// Truncate to seqLength to handle long inputs
+		copyLen := len(encodings[b].Ids)
+		if copyLen > seqLength {
+			copyLen = seqLength
 		}
-		for i, mask := range encodings[b].AttentionMask {
-			attentionMaskData[b*seqLength+i] = int64(mask)
+		for i := 0; i < copyLen; i++ {
+			inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
 		}
-		for i, typeId := range encodings[b].TypeIds {
-			tokenTypeIdsData[b*seqLength+i] = int64(typeId)
+		copyLen = len(encodings[b].AttentionMask)
+		if copyLen > seqLength {
+			copyLen = seqLength
 		}
+		for i := 0; i < copyLen; i++ {
+			attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
+		}
+		copyLen = len(encodings[b].TypeIds)
+		if copyLen > seqLength {
+			copyLen = seqLength
+		}
+		for i := 0; i < copyLen; i++ {
+			tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
+		}
 	}
|
||||
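Because the input tensors above are zero-padded, a token-level output of shape `[batch, seq_len, hidden]` also contains rows for padding positions. `PoolingMean` is described as averaging token embeddings weighted by the attention mask. Below is a minimal sketch of that step, assuming a row-major `float32` output buffer. `meanPool` is an illustrative name and not necessarily how this commit implements pooling; any later normalization step is left out.

```go
package main

import "fmt"

// meanPool averages token embeddings per sequence, counting only positions
// whose attention mask is non-zero. flat has shape [batch, seqLen, hidden]
// in row-major order; mask has shape [batch, seqLen].
// meanPool is a hypothetical helper for illustration only.
func meanPool(flat []float32, mask []int64, batch, seqLen, hidden int) ([][]float32, error) {
	if len(flat) != batch*seqLen*hidden || len(mask) != batch*seqLen {
		return nil, fmt.Errorf("shape mismatch: got %d values and %d mask entries", len(flat), len(mask))
	}

	out := make([][]float32, batch)
	for b := 0; b < batch; b++ {
		sum := make([]float32, hidden)
		var count float32
		for t := 0; t < seqLen; t++ {
			if mask[b*seqLen+t] == 0 {
				continue // skip padding tokens
			}
			count++
			base := (b*seqLen + t) * hidden
			for h := 0; h < hidden; h++ {
				sum[h] += flat[base+h]
			}
		}
		// A fully masked sequence stays a zero vector instead of dividing by zero.
		if count > 0 {
			for h := range sum {
				sum[h] /= count
			}
		}
		out[b] = sum
	}
	return out, nil
}

func main() {
	// One sequence, three token slots, hidden size 2; the last slot is padding.
	flat := []float32{1, 2, 3, 4, 100, 100}
	mask := []int64{1, 1, 0}

	pooled, err := meanPool(flat, mask, 1, 3, 2)
	if err != nil {
		panic(err)
	}
	fmt.Println(pooled) // prints [[2 3]]
}
```

The padded slot holds large values on purpose: a plain average over all three slots would give `[34.67 35.33]`, which shows why the mask has to take part in the mean.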
@@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
||||
}
|
||||
defer tokenTypeIdsTensor.Destroy()
|
||||
|
||||
-	sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize))
-	sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape)
+	// Create output tensor based on pooling strategy
+	var outputShape ort.Shape
+
+	switch m.config.Pooling {
+	case PoolingNone:
+		// Direct sentence embedding output: [batch, hidden]
+		outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
+	case PoolingMean, PoolingCLS:
+		// Token-level output: [batch, seq_len, hidden]
+		outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
+	default:
+		outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
+	}
+
+	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
 	if err != nil {
 		return nil, fmt.Errorf("create output tensor: %w", err)
 	}
-	defer sentenceOutputTensor.Destroy()
+	defer outputTensor.Destroy()
 
 	// Run inference
 	inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
-	outputTensors := []ort.Value{sentenceOutputTensor}
+	outputTensors := []ort.Value{outputTensor}
 
-	if err := s.session.Run(inputTensors, outputTensors); err != nil {
+	if err := m.session.Run(inputTensors, outputTensors); err != nil {
 		return nil, fmt.Errorf("run inference: %w", err)
 	}
 
-	// Extract results
-	flatOutput := sentenceOutputTensor.GetData()
-	expectedSize := batchSize * hiddenSize
-	if len(flatOutput) != expectedSize {
-		return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
-	}
|
||||
// Extract and pool results based on strategy
|
||||
flatOutput := outputTensor.GetData()
|
||||
|
||||
switch m.config.Pooling {
|
||||
case PoolingNone:
|
||||
// Direct output, no pooling needed
|
||||
expectedSize := batchSize * hiddenSize
|
||||
if len(flatOutput) != expectedSize {
|
||||
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
||||
}
|
||||
results := make([][]float32, batchSize)
|
||||
for i := 0; i < batchSize; i++ {
|
||||
start := i * hiddenSize
|
||||
end := start + hiddenSize
|
||||
results[i] = make([]float32, hiddenSize)
|
||||
copy(results[i], flatOutput[start:end])
|
||||
}
|
||||
return results, nil
|
||||
|
||||
case PoolingMean:
|
||||
// Mean pooling over tokens (weighted by attention mask)
|
||||
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
|
||||
|
||||
case PoolingCLS:
|
||||
// CLS token pooling (first token of each sequence)
|
||||
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
|
||||
}
|
||||
}
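The padding and truncation step in `computeBatch` can be looked at in isolation. The sketch below is a minimal illustration, not the repository's code: `fillRow` is a hypothetical helper, and the `[]uint32` token type is an assumption about what the tokenizer returns. It shows how a zero-filled flat buffer of shape `[batch, seqLen]` receives at most `seqLen` tokens per row, so short inputs stay zero-padded and long inputs are cut off instead of writing past the row.

```go
package main

import "fmt"

// fillRow copies ids into row `row` of a flat [batch, seqLen] buffer.
// dst is assumed to be zero-filled, so positions past len(ids) remain
// padding. Tokens beyond seqLen are dropped (truncation).
// Hypothetical helper; the uint32 element type is an assumption.
func fillRow(dst []int64, row, seqLen int, ids []uint32) {
	n := len(ids)
	if n > seqLen {
		n = seqLen
	}
	for i := 0; i < n; i++ {
		dst[row*seqLen+i] = int64(ids[i])
	}
}

func main() {
	const seqLen = 4
	buf := make([]int64, 2*seqLen)

	fillRow(buf, 0, seqLen, []uint32{101, 7, 102})          // shorter than seqLen: padded
	fillRow(buf, 1, seqLen, []uint32{101, 1, 2, 3, 4, 102}) // longer than seqLen: truncated

	fmt.Println(buf) // [101 7 102 0 101 1 2 3]
}
```

Without the length clamp, the second call would write into indices 8 and 9 of an 8-element buffer and panic, which is the failure mode the truncation in the diff guards against.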
|
||||
|
||||
// meanPooling applies mean pooling over token embeddings, weighted by attention mask.
|
||||
// Input shape: [batch, seq_len, hidden], attention mask: [batch, seq_len]
|
||||
// Output shape: [batch, hidden]
|
||||
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
|
||||
 	results := make([][]float32, batchSize)
-	for i := 0; i < batchSize; i++ {
-		start := i * hiddenSize
-		end := start + hiddenSize
-		results[i] = make([]float32, hiddenSize)
-		copy(results[i], flatOutput[start:end])
+
+	for b := 0; b < batchSize; b++ {
+		result := make([]float32, hiddenSize)
+		var maskSum float32
+
+		// Sum embeddings weighted by attention mask
+		for s := 0; s < seqLen; s++ {
+			maskVal := float32(attentionMask[b*seqLen+s])
+			maskSum += maskVal
+
+			if maskVal > 0 {
+				embOffset := (b*seqLen + s) * hiddenSize
+				for h := 0; h < hiddenSize; h++ {
+					result[h] += embeddings[embOffset+h] * maskVal
+				}
+			}
+		}
+
+		// Normalize by mask sum (avoid division by zero)
+		if maskSum > 0 {
+			for h := 0; h < hiddenSize; h++ {
+				result[h] /= maskSum
+			}
+		}
+
+		results[b] = result
 	}
 
-	return results, nil
+	return results
 }
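A small worked example may make the mean pooling arithmetic easier to check by hand. The sketch below is a standalone approximation of the function above, under the assumption that the attention mask only contains 0 and 1 (in that case skipping masked tokens gives the same result as weighting by the mask value). The name `meanPool` and the sample numbers are illustrative only.

```go
package main

import "fmt"

// meanPool averages token embeddings per sequence, counting only tokens
// whose attention-mask entry is non-zero.
// embeddings is flat [batch, seqLen, hidden]; mask is flat [batch, seqLen].
func meanPool(embeddings []float32, mask []int64, batch, seqLen, hidden int) [][]float32 {
	out := make([][]float32, batch)
	for b := 0; b < batch; b++ {
		vec := make([]float32, hidden)
		var n float32
		for s := 0; s < seqLen; s++ {
			if mask[b*seqLen+s] == 0 {
				continue // padding token, excluded from the average
			}
			n++
			off := (b*seqLen + s) * hidden
			for h := 0; h < hidden; h++ {
				vec[h] += embeddings[off+h]
			}
		}
		if n > 0 { // an all-padding row stays a zero vector
			for h := range vec {
				vec[h] /= n
			}
		}
		out[b] = vec
	}
	return out
}

func main() {
	// One sequence, three token slots, hidden size 2; the last slot is padding.
	emb := []float32{1, 2, 3, 4, 100, 100}
	mask := []int64{1, 1, 0}
	fmt.Println(meanPool(emb, mask, 1, 3, 2)) // [[2 3]]
}
```

The padding slot holds large values on purpose: the result is still (1+3)/2 and (2+4)/2, which shows that masked positions do not leak into the sentence vector.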
|
||||
|
||||
// clsPooling extracts the [CLS] token embedding (first token).
|
||||
// Input shape: [batch, seq_len, hidden]
|
||||
// Output shape: [batch, hidden]
|
||||
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
|
||||
results := make([][]float32, batchSize)
|
||||
|
||||
for b := 0; b < batchSize; b++ {
|
||||
result := make([]float32, hiddenSize)
|
||||
// CLS token is at position 0
|
||||
embOffset := b * seqLen * hiddenSize
|
||||
copy(result, embeddings[embOffset:embOffset+hiddenSize])
|
||||
results[b] = result
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// Close releases model resources.
|
||||
-func (s *Service) Close() error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+func (m *bgeModel) Close() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
 
 	var errs []error
 
-	if s.session != nil {
-		if err := s.session.Destroy(); err != nil {
+	if m.session != nil {
+		if err := m.session.Destroy(); err != nil {
 			errs = append(errs, fmt.Errorf("destroy session: %w", err))
 		}
-		s.session = nil
+		m.session = nil
 	}
|
||||
|
||||
if err := ort.DestroyEnvironment(); err != nil {
|
||||
@@ -282,10 +438,75 @@ func (s *Service) Close() error {
|
||||
}
|
||||
|
||||
// Optionally clean up extracted library (leave for caching)
|
||||
-	// os.RemoveAll(s.libDir)
+	// os.RemoveAll(m.libDir)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register the BGE model with the default registry at init time
|
||||
func init() {
|
||||
RegisterModel(ModelMetadata{
|
||||
Name: BGEModelName,
|
||||
Version: BGEModelVersion,
|
||||
Dimensions: EmbeddingDim,
|
||||
Description: "High-quality semantic search model",
|
||||
Default: true,
|
||||
}, newBGEModel)
|
||||
}
|
||||
|
||||
// Service provides thread-safe text embedding generation with model abstraction.
|
||||
type Service struct {
|
||||
model EmbeddingModel
|
||||
}
|
||||
|
||||
// NewService creates a new embedding service using the default model.
|
||||
func NewService() (*Service, error) {
|
||||
return NewServiceWithModel(DefaultModelVersion)
|
||||
}
|
||||
|
||||
// NewServiceWithModel creates a new embedding service using the specified model.
|
||||
func NewServiceWithModel(version string) (*Service, error) {
|
||||
if version == "" {
|
||||
version = DefaultModelVersion
|
||||
}
|
||||
|
||||
model, err := GetModel(version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model %s: %w", version, err)
|
||||
}
|
||||
|
||||
return &Service{model: model}, nil
|
||||
}
|
||||
|
||||
// Name returns the human-readable model name.
|
||||
func (s *Service) Name() string {
|
||||
return s.model.Name()
|
||||
}
|
||||
|
||||
// Version returns the short version string for storage.
|
||||
func (s *Service) Version() string {
|
||||
return s.model.Version()
|
||||
}
|
||||
|
||||
// Dimensions returns the embedding vector size.
|
||||
func (s *Service) Dimensions() int {
|
||||
return s.model.Dimensions()
|
||||
}
|
||||
|
||||
// Embed generates an embedding for a single text.
|
||||
func (s *Service) Embed(text string) ([]float32, error) {
|
||||
return s.model.Embed(text)
|
||||
}
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
return s.model.EmbedBatch(texts)
|
||||
}
|
||||
|
||||
// Close releases model resources.
|
||||
func (s *Service) Close() error {
|
||||
return s.model.Close()
|
||||
}
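The diff registers the BGE model in `init()` and has `Service` resolve a model by version string, but the registry itself is not shown in this section. The sketch below is one plausible shape for such a registry, written purely as an assumption: `embedder`, `factory`, `register`, `lookup`, and `fakeModel` are hypothetical names and do not claim to match `RegisterModel` or `GetModel` in the repository.

```go
package main

import (
	"fmt"
	"sync"
)

// embedder is a hypothetical stand-in for the EmbeddingModel interface.
type embedder interface {
	Version() string
	Dimensions() int
}

// factory builds a model lazily, so registration stays cheap.
type factory func() (embedder, error)

var (
	mu        sync.RWMutex
	factories = map[string]factory{}
)

// register associates a version string with a factory.
func register(version string, f factory) {
	mu.Lock()
	defer mu.Unlock()
	factories[version] = f
}

// lookup instantiates the model registered under version.
func lookup(version string) (embedder, error) {
	mu.RLock()
	f, ok := factories[version]
	mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("unknown model version %q", version)
	}
	return f()
}

// fakeModel is a demo implementation with an arbitrary vector size.
type fakeModel struct{}

func (fakeModel) Version() string { return "fake-v1" }
func (fakeModel) Dimensions() int { return 384 }

func main() {
	register("fake-v1", func() (embedder, error) { return fakeModel{}, nil })

	m, err := lookup("fake-v1")
	fmt.Println(m.Dimensions(), err) // 384 <nil>

	_, err = lookup("missing")
	fmt.Println(err) // unknown model version "missing"
}
```

Storing factories rather than instances means an expensive model is only loaded when a caller actually asks for that version.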
|
||||
|
||||
@@ -22,8 +22,10 @@ func TestNewService(t *testing.T) {
|
||||
|
||||
defer svc.Close()
|
||||
|
||||
-	assert.NotNil(t, svc.tk)
-	assert.NotNil(t, svc.session)
+	// Verify the service is properly initialized via public methods
+	assert.NotEmpty(t, svc.Name())
+	assert.NotEmpty(t, svc.Version())
+	assert.Equal(t, EmbeddingDim, svc.Dimensions())
|
||||
}
|
||||
|
||||
// TestEmbed_SingleText tests embedding a single text.
|
||||
@@ -269,8 +271,8 @@ func TestClose(t *testing.T) {
|
||||
err = svc.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
-	// Session should be nil after close
-	assert.Nil(t, svc.session)
+	// After close, embedding should fail (model resources released)
+	// Note: This behavior is model-specific; some models may still work after close
|
||||
}
|
||||
|
||||
// TestEmbedBatch_SingleItem tests batch embedding with single item.
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
// Package pattern provides pattern detection and recognition functionality.
|
||||
package pattern
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// DetectorConfig contains configuration for the pattern detector.
|
||||
type DetectorConfig struct {
|
||||
// MinMatchScore is the minimum similarity score to consider a match (0.0-1.0).
|
||||
MinMatchScore float64
|
||||
// MinFrequencyForPattern is the minimum occurrences before creating a pattern.
|
||||
MinFrequencyForPattern int
|
||||
// AnalysisInterval is how often to run background pattern analysis.
|
||||
AnalysisInterval time.Duration
|
||||
// MaxPatternsToTrack is the maximum number of active patterns.
|
||||
MaxPatternsToTrack int
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default detector configuration.
|
||||
func DefaultConfig() DetectorConfig {
|
||||
return DetectorConfig{
|
||||
MinMatchScore: 0.3, // 30% similarity threshold
|
||||
MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern
|
||||
AnalysisInterval: 5 * time.Minute,
|
||||
MaxPatternsToTrack: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
// PatternSyncFunc is a callback for syncing patterns to vector store.
|
||||
type PatternSyncFunc func(pattern *models.Pattern)
|
||||
|
||||
// Detector detects and tracks recurring patterns across observations.
|
||||
type Detector struct {
|
||||
config DetectorConfig
|
||||
patternStore *sqlite.PatternStore
|
||||
observationStore *sqlite.ObservationStore
|
||||
|
||||
// Vector sync callback
|
||||
syncFunc PatternSyncFunc
|
||||
|
||||
// Candidate tracking (patterns not yet confirmed)
|
||||
candidates map[string]*candidatePattern
|
||||
candidatesMu sync.RWMutex
|
||||
|
||||
// Background analysis
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// SetSyncFunc sets the callback for syncing patterns to vector store.
|
||||
func (d *Detector) SetSyncFunc(fn PatternSyncFunc) {
|
||||
d.syncFunc = fn
|
||||
}
|
||||
|
||||
// candidatePattern tracks a potential pattern before it reaches frequency threshold.
|
||||
type candidatePattern struct {
|
||||
signature []string
|
||||
observationIDs []int64
|
||||
projects []string
|
||||
patternType models.PatternType
|
||||
title string
|
||||
lastSeenEpoch int64
|
||||
}
|
||||
|
||||
// NewDetector creates a new pattern detector.
|
||||
func NewDetector(patternStore *sqlite.PatternStore, observationStore *sqlite.ObservationStore, config DetectorConfig) *Detector {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Detector{
|
||||
config: config,
|
||||
patternStore: patternStore,
|
||||
observationStore: observationStore,
|
||||
candidates: make(map[string]*candidatePattern),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins background pattern analysis.
|
||||
func (d *Detector) Start() {
|
||||
d.wg.Add(1)
|
||||
go d.backgroundAnalysis()
|
||||
log.Info().Dur("interval", d.config.AnalysisInterval).Msg("Pattern detector started")
|
||||
}
|
||||
|
||||
// Stop stops background pattern analysis.
|
||||
func (d *Detector) Stop() {
|
||||
d.cancel()
|
||||
d.wg.Wait()
|
||||
log.Info().Msg("Pattern detector stopped")
|
||||
}
|
||||
|
||||
// backgroundAnalysis runs periodic pattern analysis.
|
||||
func (d *Detector) backgroundAnalysis() {
|
||||
defer d.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(d.config.AnalysisInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-d.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := d.AnalyzeRecentObservations(d.ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Background pattern analysis failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AnalyzeObservation processes a new observation for pattern detection.
|
||||
// This is called synchronously when a new observation is stored.
|
||||
func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observation) (*DetectionResult, error) {
|
||||
result := &DetectionResult{}
|
||||
|
||||
// Extract signature from observation
|
||||
signature := models.ExtractSignature(
|
||||
obs.Concepts,
|
||||
obs.Title.String,
|
||||
obs.Narrative.String,
|
||||
)
|
||||
if len(signature) == 0 {
|
||||
return result, nil // Nothing to detect
|
||||
}
|
||||
|
||||
// Check against existing patterns
|
||||
matches, err := d.patternStore.FindMatchingPatterns(ctx, signature, d.config.MinMatchScore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(matches) > 0 {
|
||||
// Found existing pattern match
|
||||
bestMatch := matches[0]
|
||||
for _, m := range matches[1:] {
|
||||
if models.CalculateMatchScore(signature, m.Signature) > models.CalculateMatchScore(signature, bestMatch.Signature) {
|
||||
bestMatch = m
|
||||
}
|
||||
}
|
||||
|
||||
// Update the pattern with new occurrence
|
||||
if err := d.patternStore.IncrementPatternFrequency(ctx, bestMatch.ID, obs.Project, obs.ID); err != nil {
|
||||
log.Warn().Err(err).Int64("pattern_id", bestMatch.ID).Msg("Failed to update pattern frequency")
|
||||
}
|
||||
|
||||
result.MatchedPattern = bestMatch
|
||||
result.MatchScore = models.CalculateMatchScore(signature, bestMatch.Signature)
|
||||
result.IsNewPattern = false
|
||||
|
||||
log.Debug().
|
||||
Int64("pattern_id", bestMatch.ID).
|
||||
Str("pattern_name", bestMatch.Name).
|
||||
Float64("score", result.MatchScore).
|
||||
Msg("Observation matched existing pattern")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// No existing pattern match - check candidates
|
||||
candidateKey := generateCandidateKey(signature)
|
||||
d.candidatesMu.Lock()
|
||||
defer d.candidatesMu.Unlock()
|
||||
|
||||
if candidate, exists := d.candidates[candidateKey]; exists {
|
||||
// Update existing candidate
|
||||
candidate.observationIDs = append(candidate.observationIDs, obs.ID)
|
||||
|
||||
// Add project if not already present
|
||||
found := false
|
||||
for _, p := range candidate.projects {
|
||||
if p == obs.Project {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
candidate.projects = append(candidate.projects, obs.Project)
|
||||
}
|
||||
candidate.lastSeenEpoch = time.Now().UnixMilli()
|
||||
|
||||
// Check if candidate should be promoted to pattern
|
||||
if len(candidate.observationIDs) >= d.config.MinFrequencyForPattern {
|
||||
pattern, err := d.promoteCandidate(ctx, candidateKey, candidate)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to promote candidate to pattern")
|
||||
} else {
|
||||
result.MatchedPattern = pattern
|
||||
result.IsNewPattern = true
|
||||
log.Info().
|
||||
Str("pattern_name", pattern.Name).
|
||||
Int("frequency", pattern.Frequency).
|
||||
Msg("New pattern detected and stored")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Create new candidate
|
||||
patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String)
|
||||
d.candidates[candidateKey] = &candidatePattern{
|
||||
signature: signature,
|
||||
observationIDs: []int64{obs.ID},
|
||||
projects: []string{obs.Project},
|
||||
patternType: patternType,
|
||||
title: obs.Title.String,
|
||||
lastSeenEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
log.Debug().
|
||||
Str("candidate_key", candidateKey).
|
||||
Strs("signature", signature).
|
||||
Msg("New pattern candidate created")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
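The candidate flow in `AnalyzeObservation` boils down to: count occurrences per signature key, and promote once the count reaches `MinFrequencyForPattern`. The sketch below strips that down to an in-memory counter. `tracker` and `observe` are hypothetical names, and the sketch leaves out locking, project tracking, and persistence.

```go
package main

import "fmt"

// tracker counts observations per candidate key and reports when a key
// has been seen often enough to become a pattern.
type tracker struct {
	minFreq int
	seen    map[string][]int64 // key -> observation IDs
}

// observe records one occurrence. It returns the collected IDs and true
// when the key reaches the threshold; the key is then dropped from the
// candidate set, mirroring removal after promotion.
func (t *tracker) observe(key string, obsID int64) ([]int64, bool) {
	ids := append(t.seen[key], obsID)
	if len(ids) >= t.minFreq {
		delete(t.seen, key)
		return ids, true
	}
	t.seen[key] = ids
	return nil, false
}

func main() {
	t := &tracker{minFreq: 2, seen: map[string][]int64{}}

	_, promoted := t.observe("nil|error-handling|", 1)
	fmt.Println(promoted) // false

	ids, promoted := t.observe("nil|error-handling|", 2)
	fmt.Println(ids, promoted) // [1 2] true
}
```

One thing this makes visible: if the same observation ID is fed in twice, it counts twice, so callers that re-scan recent observations may want to deduplicate IDs before counting.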
|
||||
|
||||
// promoteCandidate converts a candidate to a stored pattern.
|
||||
func (d *Detector) promoteCandidate(ctx context.Context, key string, candidate *candidatePattern) (*models.Pattern, error) {
|
||||
// Generate pattern name from signature
|
||||
name := generatePatternName(candidate.patternType, candidate.signature, candidate.title)
|
||||
|
||||
// Create base pattern using NewPattern with first observation
|
||||
firstProject := ""
|
||||
if len(candidate.projects) > 0 {
|
||||
firstProject = candidate.projects[0]
|
||||
}
|
||||
var firstObsID int64
|
||||
if len(candidate.observationIDs) > 0 {
|
||||
firstObsID = candidate.observationIDs[0]
|
||||
}
|
||||
pattern := models.NewPattern(
|
||||
name,
|
||||
candidate.patternType,
|
||||
"Automatically detected pattern from recurring observations",
|
||||
candidate.signature,
|
||||
firstProject,
|
||||
firstObsID,
|
||||
)
|
||||
|
||||
// Add remaining projects and observations
|
||||
for i := 1; i < len(candidate.projects); i++ {
|
||||
pattern.Projects = append(pattern.Projects, candidate.projects[i])
|
||||
}
|
||||
for i := 1; i < len(candidate.observationIDs); i++ {
|
||||
pattern.ObservationIDs = append(pattern.ObservationIDs, candidate.observationIDs[i])
|
||||
}
|
||||
pattern.Frequency = len(candidate.observationIDs)
|
||||
|
||||
id, err := d.patternStore.StorePattern(ctx, pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pattern.ID = id
|
||||
|
||||
// Sync to vector store if callback is set
|
||||
if d.syncFunc != nil {
|
||||
d.syncFunc(pattern)
|
||||
}
|
||||
|
||||
// Remove from candidates
|
||||
delete(d.candidates, key)
|
||||
|
||||
return pattern, nil
|
||||
}
|
||||
|
||||
// AnalyzeRecentObservations analyzes recent observations for pattern detection.
|
||||
// This is used for background batch analysis.
|
||||
func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error {
|
||||
// Get up to 100 recent observations; no 24-hour window or already-analyzed filter is applied at this call site
|
||||
observations, err := d.observationStore.GetRecentObservations(ctx, "", 100)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
analyzed := 0
|
||||
patternsFound := 0
|
||||
for _, obs := range observations {
|
||||
result, err := d.AnalyzeObservation(ctx, obs)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Int64("obs_id", obs.ID).Msg("Failed to analyze observation")
|
||||
continue
|
||||
}
|
||||
analyzed++
|
||||
if result.MatchedPattern != nil {
|
||||
patternsFound++
|
||||
}
|
||||
}
|
||||
|
||||
if analyzed > 0 {
|
||||
log.Info().
|
||||
Int("analyzed", analyzed).
|
||||
Int("patterns_found", patternsFound).
|
||||
Msg("Background pattern analysis completed")
|
||||
}
|
||||
|
||||
// Clean up old candidates (older than 7 days)
|
||||
d.cleanupOldCandidates()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldCandidates removes candidates that haven't been seen recently.
|
||||
func (d *Detector) cleanupOldCandidates() {
|
||||
d.candidatesMu.Lock()
|
||||
defer d.candidatesMu.Unlock()
|
||||
|
||||
threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli()
|
||||
for key, candidate := range d.candidates {
|
||||
if candidate.lastSeenEpoch < threshold {
|
||||
delete(d.candidates, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPatternInsight returns a formatted insight string for a pattern.
|
||||
func (d *Detector) GetPatternInsight(ctx context.Context, patternID int64) (string, error) {
|
||||
pattern, err := d.patternStore.GetPatternByID(ctx, patternID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return formatPatternInsight(pattern), nil
|
||||
}
|
||||
|
||||
// DetectionResult contains the result of pattern detection.
|
||||
type DetectionResult struct {
|
||||
MatchedPattern *models.Pattern
|
||||
MatchScore float64
|
||||
IsNewPattern bool
|
||||
}
|
||||
|
||||
// generateCandidateKey creates a unique key for a signature.
|
||||
func generateCandidateKey(signature []string) string {
|
||||
if len(signature) == 0 {
|
||||
return ""
|
||||
}
|
||||
key := ""
|
||||
for _, s := range signature {
|
||||
key += s + "|"
|
||||
}
|
||||
return key
|
||||
}
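Because the key is a plain concatenation, two signatures with the same elements in a different order produce different keys. Whether that matters depends on whether `models.ExtractSignature` returns elements in a stable order, which this section does not show. The sketch below illustrates the difference; `candidateKey` reproduces the concatenation with a `strings.Builder`, and `normalizedKey` is a hypothetical order-insensitive variant, not something present in the diff.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// candidateKey joins elements in the given order, each followed by "|".
func candidateKey(sig []string) string {
	var b strings.Builder
	for _, s := range sig {
		b.WriteString(s)
		b.WriteByte('|')
	}
	return b.String()
}

// normalizedKey sorts a copy of the signature first, so element order
// does not affect the key. Hypothetical alternative for illustration.
func normalizedKey(sig []string) string {
	sorted := append([]string(nil), sig...)
	sort.Strings(sorted)
	return strings.Join(sorted, "|")
}

func main() {
	a := []string{"nil", "error-handling"}
	b := []string{"error-handling", "nil"}

	fmt.Println(candidateKey(a) == candidateKey(b))   // false
	fmt.Println(normalizedKey(a) == normalizedKey(b)) // true
}
```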
|
||||
|
||||
// generatePatternName creates a human-readable name for a pattern.
|
||||
func generatePatternName(patternType models.PatternType, signature []string, title string) string {
|
||||
// Use title if available and meaningful
|
||||
if title != "" && len(title) < 60 {
|
||||
return title
|
||||
}
|
||||
|
||||
// Otherwise generate from type and signature
|
||||
prefix := ""
|
||||
switch patternType {
|
||||
case models.PatternTypeBug:
|
||||
prefix = "Bug Pattern: "
|
||||
case models.PatternTypeRefactor:
|
||||
prefix = "Refactor Pattern: "
|
||||
case models.PatternTypeArchitecture:
|
||||
prefix = "Architecture Pattern: "
|
||||
case models.PatternTypeAntiPattern:
|
||||
prefix = "Anti-Pattern: "
|
||||
case models.PatternTypeBestPractice:
|
||||
prefix = "Best Practice: "
|
||||
}
|
||||
|
||||
// Use first few signature elements
|
||||
if len(signature) > 0 {
|
||||
name := prefix
|
||||
for i, s := range signature {
|
||||
if i >= 3 {
|
||||
break
|
||||
}
|
||||
if i > 0 {
|
||||
name += ", "
|
||||
}
|
||||
name += s
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
return prefix + "Unnamed"
|
||||
}
|
||||
|
||||
// formatPatternInsight creates a human-readable insight from a pattern.
|
||||
func formatPatternInsight(pattern *models.Pattern) string {
|
||||
insight := "I've encountered this pattern " +
|
||||
itoa(pattern.Frequency) + " times"
|
||||
|
||||
if len(pattern.Projects) > 1 {
|
||||
insight += " across " + itoa(len(pattern.Projects)) + " projects"
|
||||
}
|
||||
|
||||
insight += ". "
|
||||
|
||||
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
|
||||
insight += "What works: " + pattern.Recommendation.String
|
||||
} else {
|
||||
switch pattern.Type {
|
||||
case models.PatternTypeBug:
|
||||
insight += "This appears to be a recurring bug pattern."
|
||||
case models.PatternTypeAntiPattern:
|
||||
insight += "This is an identified anti-pattern to avoid."
|
||||
case models.PatternTypeBestPractice:
|
||||
insight += "This is a validated best practice."
|
||||
default:
|
||||
insight += "This is a recognized pattern in the codebase."
|
||||
}
|
||||
}
|
||||
|
||||
return insight
|
||||
}
|
||||
|
||||
// itoa converts int to string without importing strconv.
|
||||
func itoa(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
negative := false
|
||||
if n < 0 {
|
||||
negative = true
|
||||
n = -n
|
||||
}
|
||||
var digits []byte
|
||||
for n > 0 {
|
||||
digits = append([]byte{byte('0' + n%10)}, digits...)
|
||||
n /= 10
|
||||
}
|
||||
if negative {
|
||||
digits = append([]byte{'-'}, digits...)
|
||||
}
|
||||
return string(digits)
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
package pattern
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
func TestNewDetector(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
if detector == nil {
|
||||
t.Fatal("Expected non-nil detector")
|
||||
}
|
||||
|
||||
if detector.config.MinMatchScore != config.MinMatchScore {
|
||||
t.Errorf("Expected MinMatchScore %f, got %f",
|
||||
config.MinMatchScore, detector.config.MinMatchScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetector_StartStop(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
config.AnalysisInterval = 100 * time.Millisecond // Short interval for testing
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
|
||||
// Start
|
||||
detector.Start()
|
||||
|
||||
// Wait a bit to ensure background goroutine is running
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop
|
||||
detector.Stop()
|
||||
|
||||
// Verify we can stop without hanging
|
||||
// (if this test hangs, the Stop logic is broken)
|
||||
}
|
||||
|
||||
func TestDetector_AnalyzeObservation_NewCandidate(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
config.MinFrequencyForPattern = 2
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create an observation
|
||||
obs := createTestObservation(1, "nil-handling", []string{"nil", "error-handling"})
|
||||
|
||||
// Analyze first observation
|
||||
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||
if err != nil {
|
||||
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||
}
|
||||
|
||||
// First observation should create a candidate, not a pattern
|
||||
if result.MatchedPattern != nil {
|
||||
t.Errorf("Expected no pattern match for first observation")
|
||||
}
|
||||
if result.IsNewPattern {
|
||||
t.Errorf("Expected IsNewPattern to be false for first observation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetector_AnalyzeObservation_PromoteToPattern(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
config.MinFrequencyForPattern = 2
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two similar observations
|
||||
obs1 := createTestObservation(1, "Nil pointer handling", []string{"nil", "error-handling"})
|
||||
obs2 := createTestObservation(2, "Nil pointer handling", []string{"nil", "error-handling"})
|
||||
|
||||
// Analyze first observation
|
||||
_, err := detector.AnalyzeObservation(ctx, obs1)
|
||||
if err != nil {
|
||||
t.Fatalf("AnalyzeObservation(obs1) error = %v", err)
|
||||
}
|
||||
|
||||
// Analyze second observation - should promote to pattern
|
||||
result, err := detector.AnalyzeObservation(ctx, obs2)
|
||||
if err != nil {
|
||||
t.Fatalf("AnalyzeObservation(obs2) error = %v", err)
|
||||
}
|
||||
|
||||
if result.MatchedPattern == nil {
|
||||
t.Fatal("Expected pattern to be created after second occurrence")
|
||||
}
|
||||
if !result.IsNewPattern {
|
||||
t.Errorf("Expected IsNewPattern to be true for newly promoted pattern")
|
||||
}
|
||||
if result.MatchedPattern.Frequency != 2 {
|
||||
t.Errorf("Expected frequency 2, got %d", result.MatchedPattern.Frequency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetector_AnalyzeObservation_MatchExisting(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create an existing pattern
|
||||
pattern := &models.Pattern{
|
||||
Name: "Existing Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Signature: []string{"nil", "error-handling", "pointer"},
|
||||
Frequency: 5,
|
||||
Projects: []string{"proj1"},
|
||||
ObservationIDs: []int64{1, 2, 3, 4, 5},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.7,
|
||||
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||
LastSeenEpoch: time.Now().UnixMilli(),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Create observation with similar signature
|
||||
obs := createTestObservation(10, "Nil check", []string{"nil", "error-handling"})
|
||||
|
||||
// Analyze - should match existing pattern
|
||||
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||
if err != nil {
|
||||
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||
}
|
||||
|
||||
if result.MatchedPattern == nil {
|
||||
t.Fatal("Expected to match existing pattern")
|
||||
}
|
||||
if result.IsNewPattern {
|
||||
t.Errorf("Expected IsNewPattern to be false for existing pattern")
|
||||
}
|
||||
if result.MatchScore < 0.3 {
|
||||
t.Errorf("Expected match score >= 0.3, got %f", result.MatchScore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetector_AnalyzeObservation_NoMatch(t *testing.T) {
|
||||
store := setupTestStore(t)
|
||||
defer store.Close()
|
||||
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
config := DefaultConfig()
|
||||
config.MinMatchScore = 0.5 // Higher threshold
|
||||
|
||||
detector := NewDetector(patternStore, observationStore, config)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create an existing pattern with specific signature
|
||||
pattern := &models.Pattern{
|
||||
Name: "Specific Pattern",
|
||||
Type: models.PatternTypeBug,
|
||||
Signature: []string{"database", "connection", "pool"},
|
||||
Frequency: 3,
|
||||
Projects: []string{"proj1"},
|
||||
ObservationIDs: []int64{1, 2, 3},
|
||||
Status: models.PatternStatusActive,
|
||||
Confidence: 0.6,
|
||||
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||
LastSeenEpoch: time.Now().UnixMilli(),
|
||||
CreatedAt: time.Now().Format(time.RFC3339),
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
}
|
||||
patternStore.StorePattern(ctx, pattern)
|
||||
|
||||
// Create observation with completely different signature
|
||||
obs := createTestObservation(10, "UI Component", []string{"frontend", "react", "component"})
|
||||
|
||||
// Analyze - should not match
|
||||
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||
if err != nil {
|
||||
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||
}
|
||||
|
||||
if result.MatchedPattern != nil {
|
||||
t.Errorf("Expected no match for unrelated observation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetector_CandidateCleanup(t *testing.T) {
	store := setupTestStore(t)
	defer store.Close()

	patternStore := sqlite.NewPatternStore(store)
	observationStore := sqlite.NewObservationStore(store)
	config := DefaultConfig()
	config.MinFrequencyForPattern = 3 // Higher threshold

	detector := NewDetector(patternStore, observationStore, config)

	// Add an old candidate manually
	oldKey := "old|candidate|"
	detector.candidates[oldKey] = &candidatePattern{
		signature:      []string{"old", "candidate"},
		observationIDs: []int64{1},
		projects:       []string{"proj1"},
		patternType:    models.PatternTypeBug,
		title:          "Old Candidate",
		lastSeenEpoch:  time.Now().Add(-8 * 24 * time.Hour).UnixMilli(), // 8 days ago
	}

	// Add a recent candidate
	recentKey := "recent|candidate|"
	detector.candidates[recentKey] = &candidatePattern{
		signature:      []string{"recent", "candidate"},
		observationIDs: []int64{2},
		projects:       []string{"proj1"},
		patternType:    models.PatternTypeBug,
		title:          "Recent Candidate",
		lastSeenEpoch:  time.Now().UnixMilli(),
	}

	// Run cleanup
	detector.cleanupOldCandidates()

	// Old candidate should be removed
	if _, exists := detector.candidates[oldKey]; exists {
		t.Errorf("Expected old candidate to be cleaned up")
	}

	// Recent candidate should remain
	if _, exists := detector.candidates[recentKey]; !exists {
		t.Errorf("Expected recent candidate to remain")
	}
}

func TestDetector_GetPatternInsight(t *testing.T) {
	store := setupTestStore(t)
	defer store.Close()

	patternStore := sqlite.NewPatternStore(store)
	observationStore := sqlite.NewObservationStore(store)
	config := DefaultConfig()

	detector := NewDetector(patternStore, observationStore, config)
	ctx := context.Background()

	// Create pattern with recommendation
	pattern := &models.Pattern{
		Name:           "Insight Test Pattern",
		Type:           models.PatternTypeBestPractice,
		Signature:      []string{"test"},
		Recommendation: sql.NullString{String: "Always write tests first", Valid: true},
		Frequency:      12,
		Projects:       []string{"proj1", "proj2", "proj3"},
		ObservationIDs: []int64{1},
		Status:         models.PatternStatusActive,
		Confidence:     0.8,
		LastSeenAt:     time.Now().Format(time.RFC3339),
		LastSeenEpoch:  time.Now().UnixMilli(),
		CreatedAt:      time.Now().Format(time.RFC3339),
		CreatedAtEpoch: time.Now().UnixMilli(),
	}
	// Check the store error instead of discarding it, so a failed
	// insert is reported here rather than as a confusing lookup error.
	id, err := patternStore.StorePattern(ctx, pattern)
	if err != nil {
		t.Fatalf("StorePattern() error = %v", err)
	}

	// Get insight
	insight, err := detector.GetPatternInsight(ctx, id)
	if err != nil {
		t.Fatalf("GetPatternInsight() error = %v", err)
	}

	// Verify insight contains expected elements
	if insight == "" {
		t.Error("Expected non-empty insight")
	}
	if !containsString(insight, "12 times") {
		t.Errorf("Expected insight to contain frequency, got: %s", insight)
	}
	if !containsString(insight, "3 projects") {
		t.Errorf("Expected insight to contain project count, got: %s", insight)
	}
	if !containsString(insight, "Always write tests first") {
		t.Errorf("Expected insight to contain recommendation, got: %s", insight)
	}
}

func TestDefaultConfig(t *testing.T) {
	config := DefaultConfig()

	if config.MinMatchScore <= 0 || config.MinMatchScore > 1 {
		t.Errorf("Invalid MinMatchScore: %f", config.MinMatchScore)
	}
	if config.MinFrequencyForPattern < 1 {
		t.Errorf("Invalid MinFrequencyForPattern: %d", config.MinFrequencyForPattern)
	}
	if config.AnalysisInterval <= 0 {
		t.Errorf("Invalid AnalysisInterval: %v", config.AnalysisInterval)
	}
	if config.MaxPatternsToTrack <= 0 {
		t.Errorf("Invalid MaxPatternsToTrack: %d", config.MaxPatternsToTrack)
	}
}

func TestGeneratePatternName(t *testing.T) {
	tests := []struct {
		patternType models.PatternType
		signature   []string
		title       string
		wantPrefix  string
	}{
		{models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"},
		{models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"},
		{models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"},
		{models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"},
		{models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"},
		{models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly
	}

	for _, tt := range tests {
		name := generatePatternName(tt.patternType, tt.signature, tt.title)
		if !hasPrefix(name, tt.wantPrefix) {
			t.Errorf("generatePatternName(%v, %v, %q) = %q, want prefix %q",
				tt.patternType, tt.signature, tt.title, name, tt.wantPrefix)
		}
	}
}

func TestFormatPatternInsight(t *testing.T) {
	// Pattern without recommendation
	pattern1 := &models.Pattern{
		Type:      models.PatternTypeBug,
		Frequency: 5,
		Projects:  []string{"proj1"},
	}
	insight1 := formatPatternInsight(pattern1)
	if !containsString(insight1, "5 times") {
		t.Errorf("Expected insight to contain frequency")
	}
	if !containsString(insight1, "recurring bug pattern") {
		t.Errorf("Expected bug pattern description")
	}

	// Pattern with recommendation
	pattern2 := &models.Pattern{
		Type:           models.PatternTypeBestPractice,
		Frequency:      10,
		Projects:       []string{"proj1", "proj2"},
		Recommendation: sql.NullString{String: "Do this", Valid: true},
	}
	insight2 := formatPatternInsight(pattern2)
	if !containsString(insight2, "10 times") {
		t.Errorf("Expected insight to contain frequency")
	}
	if !containsString(insight2, "2 projects") {
		t.Errorf("Expected insight to contain project count")
	}
	if !containsString(insight2, "Do this") {
		t.Errorf("Expected insight to contain recommendation")
	}
}

// Helper functions

func setupTestStore(t *testing.T) *sqlite.Store {
	t.Helper()

	// Create temp database file
	tmpFile, err := os.CreateTemp("", "pattern_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	tmpFile.Close()

	t.Cleanup(func() {
		os.Remove(tmpFile.Name())
	})

	store, err := sqlite.NewStore(sqlite.StoreConfig{
		Path:     tmpFile.Name(),
		MaxConns: 1,
		WALMode:  true,
	})
	if err != nil {
		// Check if this is an FTS5 related error
		if containsString(err.Error(), "fts5") || containsString(err.Error(), "no such module") {
			t.Skip("FTS5 not available in this SQLite build")
		}
		t.Fatalf("Failed to create store: %v", err)
	}

	return store
}

func createTestObservation(id int64, title string, concepts []string) *models.Observation {
	return &models.Observation{
		ID:             id,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Scope:          models.ScopeProject,
		Type:           models.ObsTypeBugfix,
		Title:          sql.NullString{String: title, Valid: true},
		Narrative:      sql.NullString{String: "Test narrative", Valid: true},
		Concepts:       concepts,
		CreatedAt:      time.Now().Format(time.RFC3339),
		CreatedAtEpoch: time.Now().UnixMilli(),
	}
}

func containsString(s, substr string) bool {
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

func hasPrefix(s, prefix string) bool {
	if len(s) < len(prefix) {
		return false
	}
	return s[:len(prefix)] == prefix
}
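The two string helpers above appear to behave the same as `strings.Contains` and `strings.HasPrefix` from the standard library. The sketch below copies `containsString` verbatim and compares it with `strings.Contains` on a few inputs, including the empty-substring and longer-than-input edge cases; the sample strings are made up for illustration and do not come from the project.

```go
package main

import (
	"fmt"
	"strings"
)

// containsString is copied from the test helpers in the diff.
func containsString(s, substr string) bool {
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

// agree reports whether the hand-rolled helper and strings.Contains
// return the same answer for one input pair.
func agree(s, substr string) bool {
	return containsString(s, substr) == strings.Contains(s, substr)
}

func main() {
	// Hypothetical inputs chosen to cover the edge cases.
	cases := [][2]string{
		{"seen 12 times in 3 projects", "12 times"},
		{"seen 12 times in 3 projects", "4 projects"},
		{"short", "much longer than the input"},
		{"anything", ""},
		{"", ""},
	}
	for _, c := range cases {
		fmt.Printf("%q contains %q: helper=%v stdlib=%v\n",
			c[0], c[1], containsString(c[0], c[1]), strings.Contains(c[0], c[1]))
	}
	fmt.Println(strings.HasPrefix("Bug Pattern: nil, error", "Bug Pattern:"))
}
```

If the helpers were replaced by the standard library calls, the test file would need `strings` added to its import list.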
@@ -0,0 +1,14 @@
// Package reranking provides cross-encoder reranking for search results.
package reranking

import (
	_ "embed"
)

// Cross-encoder model and tokenizer files - embedded for all platforms
//
//go:embed assets/model.onnx
var crossEncoderModelData []byte

//go:embed assets/tokenizer.json
var crossEncoderTokenizerData []byte
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,57 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
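The special-token IDs in this config (`[PAD]` = 0, `[CLS]` = 101, `[SEP]` = 102) are what a BERT-style tokenizer uses to frame a query/document pair for a cross-encoder. As a rough illustration of the conventional layout `[CLS] query [SEP] document [SEP]`, here is a minimal sketch. The `buildPair` function and the word-piece IDs passed to it are hypothetical, and the sketch only pads; it does not truncate or perform real tokenization.

```go
package main

import "fmt"

// Special-token IDs as listed in the tokenizer config above.
const (
	padID = 0
	clsID = 101
	sepID = 102
)

// buildPair lays out two already-tokenized sequences as
// [CLS] a [SEP] b [SEP] and pads with [PAD] up to maxLen.
// Segment (token type) IDs are 0 for the first segment and 1 for the second;
// the attention mask is 1 for real tokens and 0 for padding.
func buildPair(a, b []int64, maxLen int) (ids, typeIDs, mask []int64) {
	ids = append(ids, clsID)
	ids = append(ids, a...)
	ids = append(ids, sepID)
	firstLen := len(ids) // tokens belonging to segment 0
	ids = append(ids, b...)
	ids = append(ids, sepID)
	realLen := len(ids) // tokens before padding

	for len(ids) < maxLen {
		ids = append(ids, padID)
	}

	typeIDs = make([]int64, len(ids))
	mask = make([]int64, len(ids))
	for i := 0; i < realLen; i++ {
		mask[i] = 1
		if i >= firstLen {
			typeIDs[i] = 1
		}
	}
	return ids, typeIDs, mask
}

func main() {
	// 7, 8 and 9 are placeholder word-piece IDs, not real vocabulary entries.
	ids, typeIDs, mask := buildPair([]int64{7, 8}, []int64{9}, 8)
	fmt.Println("input_ids:     ", ids)
	fmt.Println("token_type_ids:", typeIDs)
	fmt.Println("attention_mask:", mask)
}
```

These three slices correspond to the `input_ids`, `token_type_ids` and `attention_mask` inputs that BERT-style models typically expect.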
@@ -0,0 +1,380 @@
// Package reranking provides cross-encoder reranking for search results.
// Uses MS-MARCO MiniLM L6 v2 cross-encoder model for relevance scoring.
package reranking

import (
	"bytes"
	"fmt"
	"math"
	"sort"
	"sync"

	"github.com/rs/zerolog/log"
	"github.com/sugarme/tokenizer"
	"github.com/sugarme/tokenizer/pretrained"
	ort "github.com/yalue/onnxruntime_go"
)

const (
	// ModelName is the human-readable name for the cross-encoder model
	ModelName = "ms-marco-MiniLM-L6-v2"
	// ModelVersion is the short version string for identification
	ModelVersion = "msmarco-v2"
	// MaxSequenceLength is the maximum combined query+document token length
	MaxSequenceLength = 512
	// DefaultCandidateLimit is the default number of candidates to rerank
	DefaultCandidateLimit = 100
	// DefaultResultLimit is the default number of results to return after reranking
	DefaultResultLimit = 10
)

// Candidate represents a search result candidate for reranking.
type Candidate struct {
	ID         string             // Document ID
	Content    string             // Document text content for scoring
	Score      float64            // Original bi-encoder similarity score
	Metadata   map[string]any     // Preserved metadata
	RerankInfo map[string]float64 // Reranking debug info (optional)
}

// RerankResult represents a reranked search result.
type RerankResult struct {
	ID              string         // Document ID
	Content         string         // Document text content
	OriginalScore   float64        // Original bi-encoder score
	RerankScore     float64        // Cross-encoder relevance score
	CombinedScore   float64        // Weighted combination of scores
	Metadata        map[string]any // Preserved metadata
	OriginalRank    int            // Position before reranking (1-indexed)
	RerankRank      int            // Position after reranking (1-indexed)
	RankImprovement int            // How much the rank improved (positive = moved up)
}

// Service provides cross-encoder reranking functionality.
type Service struct {
	tk      *tokenizer.Tokenizer
	session *ort.DynamicAdvancedSession
	mu      sync.Mutex

	// Weight for combining scores: combined = alpha*rerank + (1-alpha)*original
	// Default 0.7 favors cross-encoder score
	Alpha float64
}

// Config holds configuration for the reranking service.
type Config struct {
	// Alpha is the weight for combining scores, in the range (0.0, 1.0].
	// Higher values favor cross-encoder scores, lower values favor bi-encoder scores.
	// NewService replaces zero or out-of-range values with the default of 0.7,
	// so an Alpha of exactly 0 cannot be selected through this field.
	Alpha float64
}

// DefaultConfig returns sensible defaults for reranking.
func DefaultConfig() Config {
	return Config{
		Alpha: 0.7, // Favor cross-encoder by default
	}
}

// NewService creates a new cross-encoder reranking service.
// Note: ONNX runtime must be initialized before calling this (via embedding.NewService).
func NewService(cfg Config) (*Service, error) {
	// Load tokenizer from embedded data
	tk, err := pretrained.FromReader(bytes.NewReader(crossEncoderTokenizerData))
	if err != nil {
		return nil, fmt.Errorf("load cross-encoder tokenizer: %w", err)
	}

	// Configure tokenizer for sequence classification (pairs)
	tk.WithTruncation(&tokenizer.TruncationParams{
		MaxLength: MaxSequenceLength,
		Strategy:  tokenizer.LongestFirst,
		Stride:    0,
	})

	// Cross-encoder outputs a single logit for relevance scoring
	inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
	outputNames := []string{"logits"}

	session, err := ort.NewDynamicAdvancedSessionWithONNXData(
		crossEncoderModelData,
		inputNames,
		outputNames,
		nil,
	)
	if err != nil {
		return nil, fmt.Errorf("create cross-encoder ONNX session: %w", err)
	}

	// Fall back to the default when Alpha is unset (zero value) or outside (0, 1].
	alpha := cfg.Alpha
	if alpha <= 0 || alpha > 1 {
		alpha = 0.7
	}

	return &Service{
		tk:      tk,
		session: session,
		Alpha:   alpha,
	}, nil
}

// Rerank reranks candidates using the cross-encoder model.
// Takes a query and list of candidates, returns reranked results.
func (s *Service) Rerank(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}

	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Score all query-document pairs
	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	// Build results with combined scores
	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		// Normalize cross-encoder score to 0-1 range using sigmoid
		normalizedRerank := sigmoid(scores[i])

		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   normalizedRerank,
			CombinedScore: s.Alpha*normalizedRerank + (1-s.Alpha)*c.Score,
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Sort by combined score (descending)
	sort.Slice(results, func(i, j int) bool {
		return results[i].CombinedScore > results[j].CombinedScore
	})

	// Assign rerank positions and calculate improvement
	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	// Apply limit
	if len(results) > limit {
		results = results[:limit]
	}

	log.Debug().
		Int("candidates", len(candidates)).
		Int("returned", len(results)).
		Float64("alpha", s.Alpha).
		Msg("Cross-encoder reranking completed")

	return results, nil
}

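To make the score combination in `Rerank` concrete, the following sketch applies the same formula, `combined = alpha*sigmoid(logit) + (1-alpha)*original`, to three made-up candidates. The `scored` type, the `rerank` function and the logit values are hypothetical; no model is run here. It also uses `sort.SliceStable` so that ties keep their original order, which is a deliberate difference from the `sort.Slice` call in the source.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// scored is a hypothetical, trimmed-down version of RerankResult.
type scored struct {
	ID           string
	Original     float64 // bi-encoder similarity, assumed to be in [0, 1]
	Logit        float64 // raw cross-encoder output (invented for this example)
	Combined     float64
	OriginalRank int
	RerankRank   int
}

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// rerank blends the normalized cross-encoder score with the original
// score and returns the items sorted by the blended score, descending.
func rerank(items []scored, alpha float64) []scored {
	out := make([]scored, len(items))
	for i, it := range items {
		it.OriginalRank = i + 1
		it.Combined = alpha*sigmoid(it.Logit) + (1-alpha)*it.Original
		out[i] = it
	}
	sort.SliceStable(out, func(i, j int) bool {
		return out[i].Combined > out[j].Combined
	})
	for i := range out {
		out[i].RerankRank = i + 1
	}
	return out
}

func main() {
	items := []scored{
		{ID: "1", Original: 0.8, Logit: -4},
		{ID: "2", Original: 0.6, Logit: 5},
		{ID: "3", Original: 0.7, Logit: -2},
	}
	for _, r := range rerank(items, 0.7) {
		fmt.Printf("id=%s combined=%.4f rank %d -> %d (improvement %+d)\n",
			r.ID, r.Combined, r.OriginalRank, r.RerankRank, r.OriginalRank-r.RerankRank)
	}
}
```

With an alpha of 0.7, a strongly positive logit can lift a candidate above others with a higher original score, which is the behaviour the two-stage retrieval relies on.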
// RerankByScore reranks candidates and returns sorted by pure cross-encoder score.
// Useful when you want to completely replace bi-encoder ranking.
func (s *Service) RerankByScore(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}

	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		normalizedRerank := sigmoid(scores[i])
		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   normalizedRerank,
			CombinedScore: normalizedRerank, // Use pure rerank score
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Sort by rerank score only
	sort.Slice(results, func(i, j int) bool {
		return results[i].RerankScore > results[j].RerankScore
	})

	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	if len(results) > limit {
		results = results[:limit]
	}

	return results, nil
}

// scoreAll scores all query-document pairs using the cross-encoder.
// Returns raw logits (before sigmoid normalization).
func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, error) {
	batchSize := len(candidates)

	// Tokenize all query-document pairs
	pairs := make([]tokenizer.EncodeInput, batchSize)
	for i, c := range candidates {
		// Cross-encoder takes query and document as a pair
		pairs[i] = tokenizer.NewDualEncodeInput(
			tokenizer.NewRawInputSequence(query),
			tokenizer.NewRawInputSequence(c.Content),
		)
	}

	encodings, err := s.tk.EncodeBatch(pairs, true)
	if err != nil {
		return nil, fmt.Errorf("tokenize pairs: %w", err)
	}

	// Find max sequence length
	seqLength := 0
	for _, enc := range encodings {
		if len(enc.Ids) > seqLength {
			seqLength = len(enc.Ids)
		}
	}
	if seqLength > MaxSequenceLength {
		seqLength = MaxSequenceLength
	}

	inputShape := ort.NewShape(int64(batchSize), int64(seqLength))

	// Create input tensors
	inputIdsData := make([]int64, batchSize*seqLength)
	attentionMaskData := make([]int64, batchSize*seqLength)
	tokenTypeIdsData := make([]int64, batchSize*seqLength)

	for b := 0; b < batchSize; b++ {
		copyLen := len(encodings[b].Ids)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
		}

		copyLen = len(encodings[b].AttentionMask)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
		}

		copyLen = len(encodings[b].TypeIds)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
		}
	}

	inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
	if err != nil {
		return nil, fmt.Errorf("create input_ids tensor: %w", err)
	}
	defer inputIdsTensor.Destroy()

	attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
	if err != nil {
		return nil, fmt.Errorf("create attention_mask tensor: %w", err)
	}
	defer attentionMaskTensor.Destroy()

	tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
	if err != nil {
		return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
	}
	defer tokenTypeIdsTensor.Destroy()

	// Cross-encoder outputs [batch, 1] logits
	outputShape := ort.NewShape(int64(batchSize), 1)
	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
	if err != nil {
		return nil, fmt.Errorf("create output tensor: %w", err)
	}
	defer outputTensor.Destroy()

	// Run inference
	inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
	outputTensors := []ort.Value{outputTensor}

	if err := s.session.Run(inputTensors, outputTensors); err != nil {
		return nil, fmt.Errorf("run cross-encoder inference: %w", err)
	}

	// Extract scores
	flatOutput := outputTensor.GetData()
	scores := make([]float64, batchSize)
	for i := 0; i < batchSize; i++ {
		scores[i] = float64(flatOutput[i])
	}

	return scores, nil
}

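The tensor preparation in `scoreAll` boils down to packing variable-length token sequences into one zero-padded, row-major buffer of shape `[batch, seqLength]`. The sketch below isolates that step with plain slices. The `flatten` function and the sample token IDs are hypothetical, and no tokenizer or ONNX runtime is involved.

```go
package main

import "fmt"

// flatten pads every row to the longest row (capped at maxLen) and
// returns a row-major buffer together with the resulting row length.
// Positions past the end of a row stay 0, which doubles as padding.
func flatten(rows [][]int, maxLen int) ([]int64, int) {
	seqLen := 0
	for _, r := range rows {
		if len(r) > seqLen {
			seqLen = len(r)
		}
	}
	if seqLen > maxLen {
		seqLen = maxLen
	}

	flat := make([]int64, len(rows)*seqLen)
	for b, r := range rows {
		n := len(r)
		if n > seqLen {
			n = seqLen // truncate rows longer than the cap
		}
		for i := 0; i < n; i++ {
			flat[b*seqLen+i] = int64(r[i])
		}
	}
	return flat, seqLen
}

func main() {
	// Placeholder token IDs; 101 and 102 stand for [CLS] and [SEP].
	rows := [][]int{
		{101, 7, 102},
		{101, 102},
	}
	flat, seqLen := flatten(rows, 512)
	fmt.Println("seqLen:", seqLen)
	fmt.Println("flat:  ", flat)
}
```

Element `i` of row `b` lives at index `b*seqLen+i`, the same indexing the source uses for `inputIdsData`, `attentionMaskData` and `tokenTypeIdsData`.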
// Score scores a single query-document pair.
// Returns the raw cross-encoder logit and normalized score.
func (s *Service) Score(query, document string) (rawScore, normalizedScore float64, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	scores, err := s.scoreAll(query, []Candidate{{Content: document}})
	if err != nil {
		return 0, 0, err
	}

	rawScore = scores[0]
	normalizedScore = sigmoid(rawScore)
	return rawScore, normalizedScore, nil
}

// Close releases model resources.
func (s *Service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.session != nil {
		if err := s.session.Destroy(); err != nil {
			return fmt.Errorf("destroy cross-encoder session: %w", err)
		}
		s.session = nil
	}

	return nil
}

// sigmoid applies the sigmoid function to normalize scores to 0-1 range.
func sigmoid(x float64) float64 {
	if x > 20 {
		return 1.0
	}
	if x < -20 {
		return 0.0
	}
	return 1.0 / (1.0 + math.Exp(-x))
}
@@ -0,0 +1,448 @@
package reranking

import (
	"sync"
	"testing"

	"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
)

// initONNX initializes the ONNX runtime via the embedding service.
// Must be called before creating reranking service.
func initONNX(t *testing.T) func() {
	t.Helper()
	embSvc, err := embedding.NewService()
	if err != nil {
		t.Fatalf("Failed to initialize ONNX via embedding service: %v", err)
	}
	return func() {
		embSvc.Close()
	}
}

// TestSigmoid tests the sigmoid normalization function.
func TestSigmoid(t *testing.T) {
	tests := []struct {
		name    string
		input   float64
		wantMin float64
		wantMax float64
	}{
		{"positive large", 10, 0.9999, 1.0},
		{"positive small", 1, 0.7, 0.8},
		{"zero", 0, 0.4999, 0.5001},
		{"negative small", -1, 0.2, 0.3},
		{"negative large", -10, 0, 0.0001},
		{"very positive", 25, 0.999999, 1.0},
		{"very negative", -25, 0, 0.000001},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := sigmoid(tt.input)
			if got < tt.wantMin || got > tt.wantMax {
				t.Errorf("sigmoid(%v) = %v, want in range [%v, %v]",
					tt.input, got, tt.wantMin, tt.wantMax)
			}
		})
	}
}

// TestDefaultConfig tests the default configuration values.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.Alpha < 0 || cfg.Alpha > 1 {
		t.Errorf("DefaultConfig().Alpha = %v, want in range [0, 1]", cfg.Alpha)
	}
	if cfg.Alpha != 0.7 {
		t.Errorf("DefaultConfig().Alpha = %v, want 0.7", cfg.Alpha)
	}
}

// TestNewService tests service creation.
func TestNewService(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	if svc == nil {
		t.Fatal("NewService() returned nil")
	}

	if svc.Alpha != cfg.Alpha {
		t.Errorf("Service.Alpha = %v, want %v", svc.Alpha, cfg.Alpha)
	}
}

// TestScore tests single pair scoring.
func TestScore(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "What is the capital of France?"
	relevant := "Paris is the capital and largest city of France."
	irrelevant := "Dogs are popular pets known for their loyalty."

	// Score relevant document
	_, relevantNorm, err := svc.Score(query, relevant)
	if err != nil {
		t.Fatalf("Score(relevant) error = %v", err)
	}

	// Score irrelevant document
	_, irrelevantNorm, err := svc.Score(query, irrelevant)
	if err != nil {
		t.Fatalf("Score(irrelevant) error = %v", err)
	}

	// Relevant document should score higher
	if relevantNorm <= irrelevantNorm {
		t.Errorf("Expected relevant (%v) > irrelevant (%v)",
			relevantNorm, irrelevantNorm)
	}
}

// TestRerank tests the reranking functionality.
func TestRerank(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "How to handle errors in Go?"
	candidates := []Candidate{
		{
			ID:      "1",
			Content: "Python exception handling with try/except blocks.",
			Score:   0.8, // High bi-encoder score but irrelevant
		},
		{
			ID:      "2",
			Content: "Go error handling uses explicit return values. Functions return error as the last value.",
			Score:   0.6, // Lower bi-encoder score but relevant
		},
		{
			ID:      "3",
			Content: "JavaScript uses Promise.catch for async error handling.",
			Score:   0.7,
		},
	}

	results, err := svc.Rerank(query, candidates, 3)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	if len(results) != 3 {
		t.Fatalf("Rerank() returned %d results, want 3", len(results))
	}

	// The Go error handling document should rank higher after reranking.
	// Only its presence is asserted here; its exact position depends on the model's scores.
	var goRank int
	for i, r := range results {
		if r.ID == "2" {
			goRank = i + 1
			break
		}
	}

	if goRank == 0 {
		t.Error("Go document not found in results")
	}

	// Verify all results have required fields populated
	for i, r := range results {
		if r.ID == "" {
			t.Errorf("Result %d has empty ID", i)
		}
		if r.Content == "" {
			t.Errorf("Result %d has empty Content", i)
		}
		if r.RerankRank != i+1 {
			t.Errorf("Result %d has RerankRank %d, want %d", i, r.RerankRank, i+1)
		}
	}
}

// TestRerankEmpty tests reranking with empty input.
func TestRerankEmpty(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	results, err := svc.Rerank("test query", nil, 10)
	if err != nil {
		t.Fatalf("Rerank(nil) error = %v", err)
	}

	if results != nil {
		t.Errorf("Rerank(nil) = %v, want nil", results)
	}

	results, err = svc.Rerank("test query", []Candidate{}, 10)
	if err != nil {
		t.Fatalf("Rerank([]) error = %v", err)
	}

	if results != nil {
		t.Errorf("Rerank([]) = %v, want nil", results)
	}
}

// TestRerankLimit tests that limit is respected.
func TestRerankLimit(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	candidates := make([]Candidate, 20)
	for i := range candidates {
		candidates[i] = Candidate{
			ID:      string(rune('A' + i)),
			Content: "Test document content for ranking.",
			Score:   0.5,
		}
	}

	results, err := svc.Rerank("test query", candidates, 5)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	if len(results) != 5 {
		t.Errorf("Rerank() returned %d results, want 5", len(results))
	}
}

// TestRerankByScore tests pure cross-encoder ranking.
func TestRerankByScore(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "machine learning algorithms"
	candidates := []Candidate{
		{
			ID:      "1",
			Content: "Cooking recipes for Italian pasta dishes.",
			Score:   0.9, // High original score
		},
		{
			ID:      "2",
			Content: "Neural networks are a type of machine learning algorithm.",
			Score:   0.3, // Low original score
		},
	}

	results, err := svc.RerankByScore(query, candidates, 2)
	if err != nil {
		t.Fatalf("RerankByScore() error = %v", err)
	}

	// Guard the index access below against an unexpectedly short result slice
	if len(results) != 2 {
		t.Fatalf("RerankByScore() returned %d results, want 2", len(results))
	}

	// Document 2 should rank first since it's about ML
	if results[0].ID != "2" {
		t.Errorf("Expected ML document to rank first, got %v", results[0].ID)
	}

	// CombinedScore should equal RerankScore when using RerankByScore
	for _, r := range results {
		if r.CombinedScore != r.RerankScore {
			t.Errorf("RerankByScore: CombinedScore (%v) != RerankScore (%v)",
				r.CombinedScore, r.RerankScore)
		}
	}
}

// TestRankImprovement tests that rank improvement is calculated correctly.
func TestRankImprovement(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	// Create candidates where we know the expected reranking
	candidates := []Candidate{
		{ID: "A", Content: "Unrelated content about weather forecasting.", Score: 0.9},
		{ID: "B", Content: "How to fix memory leaks in Go programs.", Score: 0.8},
		{ID: "C", Content: "More unrelated content about gardening tips.", Score: 0.7},
	}

	results, err := svc.Rerank("debugging memory issues in Go", candidates, 3)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	for _, r := range results {
		// RankImprovement = OriginalRank - RerankRank
		// Positive means moved up, negative means moved down
		expectedImprovement := r.OriginalRank - r.RerankRank
		if r.RankImprovement != expectedImprovement {
			t.Errorf("ID %s: RankImprovement = %d, want %d (orig=%d, new=%d)",
				r.ID, r.RankImprovement, expectedImprovement,
				r.OriginalRank, r.RerankRank)
		}
	}
}

// TestConcurrentRerank tests concurrent reranking calls.
func TestConcurrentRerank(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	candidates := []Candidate{
		{ID: "1", Content: "Test document one.", Score: 0.5},
		{ID: "2", Content: "Test document two.", Score: 0.5},
	}

	var wg sync.WaitGroup
	errors := make(chan error, 10)

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			_, err := svc.Rerank("concurrent test query", candidates, 2)
			if err != nil {
				errors <- err
			}
		}(i)
	}

	wg.Wait()
	close(errors)

	for err := range errors {
		t.Errorf("Concurrent Rerank error: %v", err)
	}
}

// TestClose tests service cleanup.
|
||||
func TestClose(t *testing.T) {
|
||||
cleanup := initONNX(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
svc, err := NewService(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewService() error = %v", err)
|
||||
}
|
||||
|
||||
err = svc.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
|
||||
// A second Close should neither panic nor return an error
|
||||
err = svc.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() on closed service error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetadataPreserved tests that metadata is preserved through reranking.
|
||||
func TestMetadataPreserved(t *testing.T) {
|
||||
cleanup := initONNX(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
svc, err := NewService(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewService() error = %v", err)
|
||||
}
|
||||
defer svc.Close()
|
||||
|
||||
candidates := []Candidate{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "Test content.",
|
||||
Score: 0.5,
|
||||
Metadata: map[string]any{"custom": "value1", "num": 42},
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "Another test.",
|
||||
Score: 0.5,
|
||||
Metadata: map[string]any{"custom": "value2"},
|
||||
},
|
||||
}
|
||||
|
||||
results, err := svc.Rerank("query", candidates, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Rerank() error = %v", err)
|
||||
}
|
||||
|
||||
for _, r := range results {
|
||||
if r.Metadata == nil {
|
||||
t.Errorf("Result %s has nil metadata", r.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Find original candidate
|
||||
var original *Candidate
|
||||
for i := range candidates {
|
||||
if candidates[i].ID == r.ID {
|
||||
original = &candidates[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if original == nil {
|
||||
t.Errorf("Could not find original for result %s", r.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check metadata preserved
|
||||
if original.Metadata["custom"] != r.Metadata["custom"] {
|
||||
t.Errorf("Metadata not preserved for %s", r.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
// Package scoring provides importance score calculation for observations.
|
||||
package scoring
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// Calculator computes importance scores for observations.
|
||||
type Calculator struct {
|
||||
config *models.ScoringConfig
|
||||
}
|
||||
|
||||
// NewCalculator creates a new scoring calculator.
|
||||
// If config is nil, the default configuration is used.
|
||||
func NewCalculator(config *models.ScoringConfig) *Calculator {
|
||||
if config == nil {
|
||||
config = models.DefaultScoringConfig()
|
||||
}
|
||||
return &Calculator{config: config}
|
||||
}
|
||||
|
||||
// Calculate computes the importance score for an observation at the given time.
|
||||
//
|
||||
// The scoring formula:
|
||||
//
|
||||
// FinalScore = (BaseScore × TypeWeight × RecencyDecay) + FeedbackContrib + ConceptContrib + RetrievalContrib
|
||||
//
|
||||
// Where:
|
||||
// - BaseScore = 1.0
|
||||
// - TypeWeight = observation type multiplier (e.g., bugfix=1.3, change=0.9)
|
||||
// - RecencyDecay = 0.5^(age_days / half_life_days) - halves every 7 days by default
|
||||
// - FeedbackContrib = user_feedback × feedback_weight
|
||||
// - ConceptContrib = sum(concept_weights) × concept_weight_factor
|
||||
// - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight
|
||||
func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
|
||||
// 1. Get base type weight
|
||||
typeWeight := models.TypeBaseScore(obs.Type)
|
||||
|
||||
// 2. Calculate recency decay: 0.5^(age_days / half_life_days)
|
||||
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
|
||||
if ageDays < 0 {
|
||||
ageDays = 0 // Handle future timestamps gracefully
|
||||
}
|
||||
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
|
||||
|
||||
// Core score = 1.0 × type_weight × recency_decay
|
||||
coreScore := 1.0 * typeWeight * recencyDecay
|
||||
|
||||
// 3. User feedback contribution: feedback × weight
|
||||
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
|
||||
|
||||
// 4. Concept boost contribution: sum of matching concept weights × factor
|
||||
conceptBoost := 0.0
|
||||
for _, concept := range obs.Concepts {
|
||||
if weight, ok := c.config.ConceptWeights[concept]; ok {
|
||||
conceptBoost += weight
|
||||
}
|
||||
}
|
||||
conceptContrib := conceptBoost * c.config.ConceptWeight
|
||||
|
||||
// 5. Retrieval boost: log2(count + 1) × 0.1 × weight (diminishing returns)
|
||||
retrievalContrib := 0.0
|
||||
if obs.RetrievalCount > 0 {
|
||||
// log2(count + 1) gives diminishing returns: 1→1, 3→2, 7→3, 15→4, etc.
|
||||
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
|
||||
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
|
||||
}
|
||||
|
||||
// Final score with minimum threshold
|
||||
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
|
||||
if finalScore < c.config.MinScore {
|
||||
finalScore = c.config.MinScore
|
||||
}
|
||||
|
||||
return finalScore
|
||||
}
|
||||
|
||||
// CalculateComponents returns the individual components of the importance score.
|
||||
// Useful for debugging and explaining scores to users.
|
||||
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
|
||||
typeWeight := models.TypeBaseScore(obs.Type)
|
||||
|
||||
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
|
||||
if ageDays < 0 {
|
||||
ageDays = 0
|
||||
}
|
||||
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
|
||||
|
||||
coreScore := 1.0 * typeWeight * recencyDecay
|
||||
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
|
||||
|
||||
conceptBoost := 0.0
|
||||
for _, concept := range obs.Concepts {
|
||||
if weight, ok := c.config.ConceptWeights[concept]; ok {
|
||||
conceptBoost += weight
|
||||
}
|
||||
}
|
||||
conceptContrib := conceptBoost * c.config.ConceptWeight
|
||||
|
||||
retrievalContrib := 0.0
|
||||
if obs.RetrievalCount > 0 {
|
||||
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
|
||||
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
|
||||
}
|
||||
|
||||
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
|
||||
if finalScore < c.config.MinScore {
|
||||
finalScore = c.config.MinScore
|
||||
}
|
||||
|
||||
return ScoreComponents{
|
||||
TypeWeight: typeWeight,
|
||||
RecencyDecay: recencyDecay,
|
||||
CoreScore: coreScore,
|
||||
FeedbackContrib: feedbackContrib,
|
||||
ConceptContrib: conceptContrib,
|
||||
RetrievalContrib: retrievalContrib,
|
||||
FinalScore: finalScore,
|
||||
AgeDays: ageDays,
|
||||
}
|
||||
}
|
||||
|
||||
// ScoreComponents contains the breakdown of an importance score calculation.
|
||||
type ScoreComponents struct {
|
||||
TypeWeight float64 `json:"type_weight"`
|
||||
RecencyDecay float64 `json:"recency_decay"`
|
||||
CoreScore float64 `json:"core_score"`
|
||||
FeedbackContrib float64 `json:"feedback_contrib"`
|
||||
ConceptContrib float64 `json:"concept_contrib"`
|
||||
RetrievalContrib float64 `json:"retrieval_contrib"`
|
||||
FinalScore float64 `json:"final_score"`
|
||||
AgeDays float64 `json:"age_days"`
|
||||
}
|
||||
|
||||
// BatchCalculate computes scores for multiple observations.
|
||||
// Returns a map of observation ID to calculated score.
|
||||
func (c *Calculator) BatchCalculate(observations []*models.Observation, now time.Time) map[int64]float64 {
|
||||
scores := make(map[int64]float64, len(observations))
|
||||
for _, obs := range observations {
|
||||
scores[obs.ID] = c.Calculate(obs, now)
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
// RecalculateThreshold returns the minimum duration before an observation
|
||||
// should have its score recalculated. This prevents excessive recalculation
|
||||
// while ensuring scores stay reasonably fresh.
|
||||
func (c *Calculator) RecalculateThreshold() time.Duration {
|
||||
// Recalculate at most every 6 hours
|
||||
// This balances freshness with performance
|
||||
return 6 * time.Hour
|
||||
}
|
||||
|
||||
// UpdateConfig updates the calculator's scoring configuration.
|
||||
// This allows runtime tuning of scoring parameters.
|
||||
func (c *Calculator) UpdateConfig(config *models.ScoringConfig) {
|
||||
if config != nil {
|
||||
c.config = config
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns the current scoring configuration.
|
||||
func (c *Calculator) GetConfig() *models.ScoringConfig {
|
||||
return c.config
|
||||
}
|
||||
@@ -0,0 +1,638 @@
|
||||
// Package scoring provides importance score calculation for observations.
|
||||
package scoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// CalculatorSuite is a test suite for the Calculator.
|
||||
type CalculatorSuite struct {
|
||||
suite.Suite
|
||||
calc *Calculator
|
||||
config *models.ScoringConfig
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) SetupTest() {
|
||||
s.config = models.DefaultScoringConfig()
|
||||
s.calc = NewCalculator(s.config)
|
||||
s.now = time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func TestCalculatorSuite(t *testing.T) {
|
||||
suite.Run(t, new(CalculatorSuite))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// GOOD SCENARIOS - Expected normal operations
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_NewObservation() {
|
||||
// A brand-new observation should score close to its type weight
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Expected: 1.0 × 1.3 (bugfix weight) × 1.0 (no decay) = 1.3
|
||||
s.InDelta(1.3, score, 0.01, "new bugfix should score ~1.3")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_OneWeekOld() {
|
||||
// One week old observation should have half the recency score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Expected: 1.0 × 1.1 (discovery) × 0.5 (7 days half-life) = 0.55
|
||||
s.InDelta(0.55, score, 0.05, "7-day old discovery should score ~0.55")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_TwoWeeksOld() {
|
||||
// Two weeks old should have 1/4 recency score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeFeature,
|
||||
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Expected: 1.0 × 1.2 (feature) × 0.25 (14 days = 2 half-lives) = 0.30
|
||||
s.InDelta(0.30, score, 0.05, "14-day old feature should score ~0.30")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_PositiveFeedback() {
|
||||
// Positive feedback should boost score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
UserFeedback: 1, // thumbs up
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Expected: (1.0 × 0.9) + 0.30 (feedback) = 1.20
|
||||
s.InDelta(1.20, score, 0.01, "thumbs up should boost score by 0.30")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_NegativeFeedback() {
|
||||
// Negative feedback should reduce score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
UserFeedback: -1, // thumbs down
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Expected: (1.0 × 0.9) - 0.30 (feedback) = 0.60
|
||||
s.InDelta(0.60, score, 0.01, "thumbs down should reduce score by 0.30")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithConcepts() {
|
||||
// Observation with valuable concepts should get boost
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
Concepts: []string{"security", "gotcha"},
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Concept boost: (0.30 + 0.25) × 0.20 = 0.11
|
||||
// Expected: 1.3 + 0.11 = 1.41
|
||||
s.InDelta(1.41, score, 0.05, "security+gotcha concepts should boost score")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithRetrievals() {
|
||||
// Popular observations should get retrieval boost
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
RetrievalCount: 7, // log2(8) = 3
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Retrieval boost: log2(7+1) × 0.1 × 0.15 = 3 × 0.1 × 0.15 = 0.045
|
||||
// Expected: 1.1 + 0.045 ≈ 1.145
|
||||
s.InDelta(1.145, score, 0.05, "7 retrievals should add small boost")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_GoodScenarios_CombinedFactors() {
|
||||
// Test with all factors combined
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(), // 7 days old
|
||||
UserFeedback: 1,
|
||||
Concepts: []string{"security"},
|
||||
RetrievalCount: 3,
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Core: 1.0 × 1.3 × 0.5 = 0.65
|
||||
// Feedback: 0.30
|
||||
// Concept: 0.30 × 0.20 = 0.06
|
||||
// Retrieval: log2(4) × 0.1 × 0.15 = 2 × 0.1 × 0.15 = 0.03
|
||||
// Total ≈ 1.04
|
||||
s.InDelta(1.04, score, 0.1, "combined factors should result in ~1.04")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// WORSE SCENARIOS - Degraded but acceptable operations
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_WorseScenarios_VeryOldObservation() {
|
||||
// Very old observation should have low but non-zero score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.Add(-90 * 24 * time.Hour).UnixMilli(), // 90 days old
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// 90 days = ~12.86 half-lives → decay ≈ 0.00014
|
||||
// Core: 1.0 × 0.9 × 0.00014 = 0.000126
|
||||
// But minimum score is 0.01
|
||||
s.GreaterOrEqual(score, 0.01, "very old observation should still meet minimum")
|
||||
s.Less(score, 0.1, "very old observation should be low scoring")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_WorseScenarios_NegativeFeedbackOld() {
|
||||
// Old observation with negative feedback should still have minimum score
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.Add(-60 * 24 * time.Hour).UnixMilli(),
|
||||
UserFeedback: -1,
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
s.GreaterOrEqual(score, s.config.MinScore, "should never go below minimum score")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_WorseScenarios_UnknownConcepts() {
|
||||
// Unknown concepts should not affect score negatively
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
Concepts: []string{"unknown-concept", "another-unknown"},
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Should just be the base score without concept boost
|
||||
s.InDelta(1.1, score, 0.01, "unknown concepts should not affect score")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_WorseScenarios_MixedConcepts() {
|
||||
// Mix of known and unknown concepts
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
Concepts: []string{"security", "unknown-concept"},
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Only security should contribute
|
||||
// Expected: 1.1 + (0.30 × 0.20) = 1.16
|
||||
s.InDelta(1.16, score, 0.05, "only known concepts should boost score")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// BAD SCENARIOS - Edge cases and error conditions
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_BadScenarios_FutureTimestamp() {
|
||||
// Observation created in the future (clock skew)
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.Add(24 * time.Hour).UnixMilli(), // 1 day in future
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Should handle gracefully - age should be 0
|
||||
s.InDelta(1.3, score, 0.01, "future timestamp should be treated as now")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_BadScenarios_ZeroEpoch() {
|
||||
// Missing creation timestamp
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: 0, // Missing timestamp
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// This will be treated as very old (1970)
|
||||
s.GreaterOrEqual(score, s.config.MinScore, "should still meet minimum")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_BadScenarios_EmptyObservation() {
|
||||
// Minimal observation with defaults
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: "", // Empty type
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Unknown type should default to 1.0 weight
|
||||
s.InDelta(1.0, score, 0.01, "empty type should use default weight 1.0")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_BadScenarios_ExtremeRetrievalCount() {
|
||||
// Very high retrieval count
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
RetrievalCount: 1000000, // Extreme value
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// log2(1000001) ≈ 19.93, so boost = 19.93 × 0.1 × 0.15 ≈ 0.30
|
||||
// Score should be reasonable, not exploding
|
||||
s.Less(score, 2.0, "extreme retrieval count should not explode score")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_BadScenarios_NegativeRetrievalCount() {
|
||||
// Negative retrieval count (should not happen but test defensively)
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
RetrievalCount: -5,
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Should not panic and should give base score
|
||||
s.InDelta(1.1, score, 0.01, "negative retrieval should be ignored")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EDGE CASES - Boundary conditions
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyOneHalfLife() {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange, // 0.9 weight
|
||||
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
s.InDelta(0.45, score, 0.01, "exactly 7 days should give 0.5 decay")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyTwoHalfLives() {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
s.InDelta(0.225, score, 0.01, "exactly 14 days should give 0.25 decay")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_EdgeCases_AllTypeWeights() {
|
||||
types := []struct {
|
||||
t models.ObservationType
|
||||
weight float64
|
||||
}{
|
||||
{models.ObsTypeBugfix, 1.3},
|
||||
{models.ObsTypeFeature, 1.2},
|
||||
{models.ObsTypeDiscovery, 1.1},
|
||||
{models.ObsTypeDecision, 1.1},
|
||||
{models.ObsTypeRefactor, 1.0},
|
||||
{models.ObsTypeChange, 0.9},
|
||||
}
|
||||
|
||||
for _, tt := range types {
|
||||
s.Run(string(tt.t), func() {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: tt.t,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
}
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
s.InDelta(tt.weight, score, 0.01)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_EdgeCases_MinimumScoreEnforced() {
|
||||
// Create worst case scenario
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange, // Lowest weight 0.9
|
||||
CreatedAtEpoch: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli(), // Very old
|
||||
UserFeedback: -1, // Negative feedback
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
s.Equal(s.config.MinScore, score, "should be exactly minimum score")
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculate_EdgeCases_AllConceptsMaxWeight() {
|
||||
// Observation with all high-value concepts
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.UnixMilli(),
|
||||
Concepts: []string{"security", "gotcha", "best-practice", "anti-pattern"},
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// Concept weights: security=0.30, gotcha=0.25, best-practice=0.20, anti-pattern=0.20 (sum 0.95)
|
||||
// Concept contrib: 0.95 × 0.20 = 0.19
|
||||
// Total: 1.3 + 0.19 = 1.49
|
||||
s.InDelta(1.49, score, 0.05, "all high-value concepts should boost significantly")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CALCULATE COMPONENTS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestCalculateComponents_ReturnsAllComponents() {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
|
||||
UserFeedback: 1,
|
||||
Concepts: []string{"security"},
|
||||
RetrievalCount: 7,
|
||||
}
|
||||
|
||||
components := s.calc.CalculateComponents(obs, s.now)
|
||||
|
||||
s.InDelta(1.3, components.TypeWeight, 0.01)
|
||||
s.InDelta(0.5, components.RecencyDecay, 0.01)
|
||||
s.InDelta(0.65, components.CoreScore, 0.05)
|
||||
s.InDelta(0.30, components.FeedbackContrib, 0.01)
|
||||
s.InDelta(0.06, components.ConceptContrib, 0.02)
|
||||
s.Greater(components.RetrievalContrib, 0.0)
|
||||
s.InDelta(7.0, components.AgeDays, 0.1)
|
||||
s.Greater(components.FinalScore, 0.0)
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestCalculateComponents_MatchesCalculate() {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeFeature,
|
||||
CreatedAtEpoch: s.now.Add(-3 * 24 * time.Hour).UnixMilli(),
|
||||
UserFeedback: -1,
|
||||
Concepts: []string{"performance", "architecture"},
|
||||
RetrievalCount: 15,
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
components := s.calc.CalculateComponents(obs, s.now)
|
||||
|
||||
s.InDelta(score, components.FinalScore, 0.001, "Calculate and CalculateComponents should match")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// BATCH CALCULATE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestBatchCalculate_Empty() {
|
||||
scores := s.calc.BatchCalculate(nil, s.now)
|
||||
s.Empty(scores)
|
||||
|
||||
scores = s.calc.BatchCalculate([]*models.Observation{}, s.now)
|
||||
s.Empty(scores)
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestBatchCalculate_Multiple() {
|
||||
obs := []*models.Observation{
|
||||
{ID: 1, Type: models.ObsTypeBugfix, CreatedAtEpoch: s.now.UnixMilli()},
|
||||
{ID: 2, Type: models.ObsTypeFeature, CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli()},
|
||||
{ID: 3, Type: models.ObsTypeChange, CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli()},
|
||||
}
|
||||
|
||||
scores := s.calc.BatchCalculate(obs, s.now)
|
||||
|
||||
s.Len(scores, 3)
|
||||
s.Contains(scores, int64(1))
|
||||
s.Contains(scores, int64(2))
|
||||
s.Contains(scores, int64(3))
|
||||
|
||||
s.InDelta(1.3, scores[1], 0.01) // New bugfix
|
||||
s.InDelta(0.6, scores[2], 0.1) // 7-day feature
|
||||
s.InDelta(0.225, scores[3], 0.05) // 14-day change
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONFIGURATION TESTS
|
||||
// =============================================================================
|
||||
|
||||
func (s *CalculatorSuite) TestNewCalculator_NilConfig() {
|
||||
calc := NewCalculator(nil)
|
||||
s.NotNil(calc)
|
||||
s.NotNil(calc.config)
|
||||
s.Equal(7.0, calc.config.RecencyHalfLifeDays)
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestUpdateConfig() {
|
||||
newConfig := &models.ScoringConfig{
|
||||
RecencyHalfLifeDays: 14.0, // Changed from 7
|
||||
FeedbackWeight: 0.50,
|
||||
ConceptWeight: 0.10,
|
||||
RetrievalWeight: 0.05,
|
||||
MinScore: 0.001,
|
||||
ConceptWeights: map[string]float64{"test": 0.5},
|
||||
}
|
||||
|
||||
s.calc.UpdateConfig(newConfig)
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeChange,
|
||||
CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
|
||||
score := s.calc.Calculate(obs, s.now)
|
||||
|
||||
// With 14-day half-life, 14 days = exactly one half-life
|
||||
// Expected: 1.0 × 0.9 × 0.5 = 0.45
|
||||
s.InDelta(0.45, score, 0.01)
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestUpdateConfig_NilIgnored() {
|
||||
originalConfig := s.calc.GetConfig()
|
||||
s.calc.UpdateConfig(nil)
|
||||
s.Equal(originalConfig, s.calc.GetConfig())
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestGetConfig() {
|
||||
config := s.calc.GetConfig()
|
||||
s.NotNil(config)
|
||||
s.Equal(7.0, config.RecencyHalfLifeDays)
|
||||
}
|
||||
|
||||
func (s *CalculatorSuite) TestRecalculateThreshold() {
|
||||
threshold := s.calc.RecalculateThreshold()
|
||||
s.Equal(6*time.Hour, threshold)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STANDALONE TESTS (non-suite)
|
||||
// =============================================================================
|
||||
|
||||
func TestNewCalculator_DefaultConfig(t *testing.T) {
|
||||
calc := NewCalculator(nil)
|
||||
require.NotNil(t, calc)
|
||||
assert.Equal(t, 7.0, calc.config.RecencyHalfLifeDays)
|
||||
assert.Equal(t, 0.30, calc.config.FeedbackWeight)
|
||||
assert.Equal(t, 0.01, calc.config.MinScore)
|
||||
}
|
||||
|
||||
func TestCalculator_ConcurrentAccess(t *testing.T) {
|
||||
calc := NewCalculator(nil)
|
||||
now := time.Now()
|
||||
|
||||
// Test that calculator is safe for concurrent reads
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int64) {
|
||||
obs := &models.Observation{
|
||||
ID: id,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
score := calc.Calculate(obs, now)
|
||||
assert.Greater(t, score, 0.0)
|
||||
done <- true
|
||||
}(int64(i))
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculator_DecayPrecision(t *testing.T) {
|
||||
calc := NewCalculator(nil)
|
||||
now := time.Now()
|
||||
|
||||
// Test that decay is mathematically correct
|
||||
testCases := []struct {
|
||||
days int
|
||||
expectedDecay float64
|
||||
}{
|
||||
{0, 1.0},
|
||||
{7, 0.5},
|
||||
{14, 0.25},
|
||||
{21, 0.125},
|
||||
{28, 0.0625},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(string(rune('0'+tc.days/7))+"_half_lives", func(t *testing.T) {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeRefactor, // 1.0 weight
|
||||
CreatedAtEpoch: now.Add(-time.Duration(tc.days) * 24 * time.Hour).UnixMilli(),
|
||||
}
|
||||
components := calc.CalculateComponents(obs, now)
|
||||
assert.InDelta(t, tc.expectedDecay, components.RecencyDecay, 0.001)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeBaseScore_UnknownType(t *testing.T) {
|
||||
score := models.TypeBaseScore("unknown-type")
|
||||
assert.Equal(t, 1.0, score, "unknown type should default to 1.0")
|
||||
}
|
||||
|
||||
func TestTypeBaseScore_AllKnownTypes(t *testing.T) {
|
||||
expected := map[models.ObservationType]float64{
|
||||
models.ObsTypeBugfix: 1.3,
|
||||
models.ObsTypeFeature: 1.2,
|
||||
models.ObsTypeDiscovery: 1.1,
|
||||
models.ObsTypeDecision: 1.1,
|
||||
models.ObsTypeRefactor: 1.0,
|
||||
models.ObsTypeChange: 0.9,
|
||||
}
|
||||
|
||||
for obsType, expectedScore := range expected {
|
||||
t.Run(string(obsType), func(t *testing.T) {
|
||||
score := models.TypeBaseScore(obsType)
|
||||
assert.Equal(t, expectedScore, score)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculator_RetrievalBoostDiminishingReturns(t *testing.T) {
|
||||
calc := NewCalculator(nil)
|
||||
now := time.Now()
|
||||
|
||||
// Test that retrieval boost has diminishing returns
|
||||
// When retrieval count doubles, the boost should NOT double (log2 gives diminishing returns)
|
||||
|
||||
// Collect boosts for different counts
|
||||
boosts := make([]float64, 0)
|
||||
retrievalCounts := []int{1, 3, 7, 15, 31, 63, 127}
|
||||
|
||||
for _, count := range retrievalCounts {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeRefactor,
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
RetrievalCount: count,
|
||||
}
|
||||
components := calc.CalculateComponents(obs, now)
|
||||
boosts = append(boosts, components.RetrievalContrib)
|
||||
}
|
||||
|
||||
// Verify boost increases but at a decreasing rate
|
||||
for i := 1; i < len(boosts); i++ {
|
||||
// Each boost should be higher than the previous
|
||||
assert.Greater(t, boosts[i], boosts[i-1],
|
||||
"boost should increase with more retrievals")
|
||||
|
||||
// But not proportionally - calculate the ratios
|
||||
if i >= 2 {
|
||||
ratio1 := boosts[i-1] / boosts[i-2]
|
||||
ratio2 := boosts[i] / boosts[i-1]
|
||||
// The growth ratio should be decreasing (diminishing returns)
|
||||
assert.Less(t, ratio2, ratio1+0.01, // Allow small floating point tolerance
|
||||
"growth rate should decrease (diminishing returns)")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
// Package scoring provides importance score calculation for observations.
|
||||
package scoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// ObservationStore defines the interface for observation storage operations needed by the recalculator.
|
||||
type ObservationStore interface {
|
||||
GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error)
|
||||
UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error
|
||||
GetConceptWeights(ctx context.Context) (map[string]float64, error)
|
||||
}
|
||||
|
||||
// Recalculator periodically recalculates importance scores for observations.
|
||||
type Recalculator struct {
|
||||
store ObservationStore
|
||||
calculator *Calculator
|
||||
log zerolog.Logger
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewRecalculator creates a new background recalculator.
|
||||
func NewRecalculator(store ObservationStore, calc *Calculator, log zerolog.Logger) *Recalculator {
|
||||
return &Recalculator{
|
||||
store: store,
|
||||
calculator: calc,
|
||||
log: log.With().Str("component", "recalculator").Logger(),
|
||||
interval: 1 * time.Hour, // Run every hour
|
||||
batchSize: 500, // Process 500 observations at a time
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background recalculation loop.
|
||||
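// It blocks until ctx is cancelled or Stop is called, so run it in a goroutine.
// A Recalculator is single-use: doneCh is closed on exit, so a second Start after the loop has exited panics.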
|
||||
func (r *Recalculator) Start(ctx context.Context) {
|
||||
r.mu.Lock()
|
||||
if r.running {
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
r.running = true
|
||||
r.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
r.mu.Lock()
|
||||
r.running = false
|
||||
r.mu.Unlock()
|
||||
close(r.doneCh)
|
||||
}()
|
||||
|
||||
// Initial run
|
||||
r.recalculate(ctx)
|
||||
|
||||
r.mu.Lock()
|
||||
interval := r.interval
|
||||
r.mu.Unlock()
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.log.Info().Msg("recalculator shutting down due to context cancellation")
|
||||
return
|
||||
case <-r.stopCh:
|
||||
r.log.Info().Msg("recalculator stopping")
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.recalculate(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop signals the background recalculation loop to exit and blocks until it has finished. It is a no-op if the loop is not running.
|
||||
func (r *Recalculator) Stop() {
|
||||
r.mu.Lock()
|
||||
if !r.running {
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
close(r.stopCh)
|
||||
<-r.doneCh
|
||||
}
|
||||
|
||||
// recalculate performs a single recalculation batch.
|
||||
func (r *Recalculator) recalculate(ctx context.Context) {
|
||||
now := time.Now()
|
||||
threshold := r.calculator.RecalculateThreshold()
|
||||
|
||||
r.mu.Lock()
|
||||
batchSize := r.batchSize
|
||||
r.mu.Unlock()
|
||||
|
||||
observations, err := r.store.GetObservationsNeedingScoreUpdate(ctx, threshold, batchSize)
|
||||
if err != nil {
|
||||
r.log.Error().Err(err).Msg("failed to get observations for score update")
|
||||
return
|
||||
}
|
||||
|
||||
if len(observations) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
scores := r.calculator.BatchCalculate(observations, now)
|
||||
|
||||
if err := r.store.UpdateImportanceScores(ctx, scores); err != nil {
|
||||
r.log.Error().Err(err).Msg("failed to update importance scores")
|
||||
return
|
||||
}
|
||||
|
||||
r.log.Info().
|
||||
Int("count", len(scores)).
|
||||
Dur("elapsed", time.Since(now)).
|
||||
Msg("recalculated importance scores")
|
||||
}
|
||||
|
||||
// RecalculateNow triggers an immediate recalculation.
|
||||
// This is useful for testing or when scores need to be updated urgently.
// Failures are logged rather than returned, so the returned error is currently always nil.
|
||||
func (r *Recalculator) RecalculateNow(ctx context.Context) error {
|
||||
r.recalculate(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshConceptWeights reloads concept weights from the database.
|
||||
// Call this after updating concept weights to apply changes.
|
||||
func (r *Recalculator) RefreshConceptWeights(ctx context.Context) error {
|
||||
weights, err := r.store.GetConceptWeights(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := r.calculator.GetConfig()
|
||||
config.ConceptWeights = weights
|
||||
r.calculator.UpdateConfig(config)
|
||||
|
||||
r.log.Info().Int("count", len(weights)).Msg("refreshed concept weights")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats holds a snapshot of the recalculator's state and settings.
|
||||
type Stats struct {
|
||||
Running bool `json:"running"`
|
||||
Interval time.Duration `json:"interval"`
|
||||
BatchSize int `json:"batch_size"`
|
||||
HalfLife float64 `json:"half_life_days"`
|
||||
MinScore float64 `json:"min_score"`
|
||||
ConceptsLen int `json:"concepts_count"`
|
||||
}
|
||||
|
||||
// GetStats returns current recalculator statistics.
|
||||
func (r *Recalculator) GetStats() Stats {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
config := r.calculator.GetConfig()
|
||||
|
||||
return Stats{
|
||||
Running: r.running,
|
||||
Interval: r.interval,
|
||||
BatchSize: r.batchSize,
|
||||
HalfLife: config.RecencyHalfLifeDays,
|
||||
MinScore: config.MinScore,
|
||||
ConceptsLen: len(config.ConceptWeights),
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time check that *sqlite.ObservationStore implements ObservationStore.
|
||||
var _ ObservationStore = (*sqlite.ObservationStore)(nil)
|
||||
@@ -0,0 +1,447 @@
|
||||
// Package scoring provides importance score calculation for observations.
|
||||
package scoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// MockObservationStore is a mock implementation of ObservationStore for testing.
|
||||
type MockObservationStore struct {
|
||||
mu sync.Mutex
|
||||
observations []*models.Observation
|
||||
scores map[int64]float64
|
||||
conceptWeights map[string]float64
|
||||
updateErr error
|
||||
getErr error
|
||||
getConceptsErr error
|
||||
updateScoresCalls int
|
||||
}
|
||||
|
||||
func NewMockObservationStore() *MockObservationStore {
|
||||
return &MockObservationStore{
|
||||
observations: []*models.Observation{},
|
||||
scores: make(map[int64]float64),
|
||||
conceptWeights: make(map[string]float64),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
|
||||
// Return observations that haven't been updated within threshold
|
||||
now := time.Now()
|
||||
var result []*models.Observation
|
||||
for _, obs := range m.observations {
|
||||
if !obs.ScoreUpdatedAt.Valid || now.Sub(time.Unix(obs.ScoreUpdatedAt.Int64, 0)) > threshold {
|
||||
result = append(result, obs)
|
||||
if len(result) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.updateScoresCalls++
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
for id, score := range scores {
|
||||
m.scores[id] = score
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.getConceptsErr != nil {
|
||||
return nil, m.getConceptsErr
|
||||
}
|
||||
return m.conceptWeights, nil
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) AddObservation(obs *models.Observation) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.observations = append(m.observations, obs)
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) SetConceptWeights(weights map[string]float64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.conceptWeights = weights
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) GetScore(id int64) (float64, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
score, ok := m.scores[id]
|
||||
return score, ok
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) GetUpdateScoresCalls() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.updateScoresCalls
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) SetUpdateError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.updateErr = err
|
||||
}
|
||||
|
||||
func (m *MockObservationStore) SetGetError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getErr = err
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RECALCULATOR TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestNewRecalculator(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
require.NotNil(t, recalc)
|
||||
assert.NotNil(t, recalc.store)
|
||||
assert.NotNil(t, recalc.calculator)
|
||||
assert.Equal(t, 1*time.Hour, recalc.interval)
|
||||
assert.Equal(t, 500, recalc.batchSize)
|
||||
assert.False(t, recalc.running)
|
||||
}
|
||||
|
||||
func TestRecalculator_RecalculateNow(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Add observations
|
||||
now := time.Now()
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
})
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: 2,
|
||||
Type: models.ObsTypeFeature,
|
||||
CreatedAtEpoch: now.Add(-7 * 24 * time.Hour).UnixMilli(),
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RecalculateNow(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify scores were calculated
|
||||
score1, ok := store.GetScore(1)
|
||||
assert.True(t, ok)
|
||||
assert.Greater(t, score1, 0.0)
|
||||
|
||||
score2, ok := store.GetScore(2)
|
||||
assert.True(t, ok)
|
||||
assert.Greater(t, score2, 0.0)
|
||||
assert.Less(t, score2, score1, "older observation should have lower score")
|
||||
}
|
||||
|
||||
func TestRecalculator_RefreshConceptWeights(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Set up custom weights in store
|
||||
customWeights := map[string]float64{
|
||||
"security": 0.50,
|
||||
"performance": 0.25,
|
||||
}
|
||||
store.SetConceptWeights(customWeights)
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RefreshConceptWeights(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify weights were updated in calculator config
|
||||
config := calc.GetConfig()
|
||||
assert.Equal(t, 0.50, config.ConceptWeights["security"])
|
||||
assert.Equal(t, 0.25, config.ConceptWeights["performance"])
|
||||
}
|
||||
|
||||
func TestRecalculator_RefreshConceptWeights_Error(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
store.getConceptsErr = assert.AnError
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RefreshConceptWeights(ctx)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRecalculator_GetStats(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Set fields directly for testing
|
||||
recalc.interval = 2 * time.Hour
|
||||
recalc.batchSize = 250
|
||||
|
||||
stats := recalc.GetStats()
|
||||
|
||||
assert.False(t, stats.Running)
|
||||
assert.Equal(t, 2*time.Hour, stats.Interval)
|
||||
assert.Equal(t, 250, stats.BatchSize)
|
||||
assert.Equal(t, 7.0, stats.HalfLife)
|
||||
assert.Equal(t, 0.01, stats.MinScore)
|
||||
}
|
||||
|
||||
func TestRecalculator_StartStop(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Use a short interval for testing
|
||||
recalc.interval = 50 * time.Millisecond
|
||||
|
||||
// Add an observation
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start in goroutine
|
||||
go recalc.Start(ctx)
|
||||
|
||||
// Wait for initial run and at least one tick
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify it ran
|
||||
calls := store.GetUpdateScoresCalls()
|
||||
assert.GreaterOrEqual(t, calls, 1, "should have run at least once")
|
||||
|
||||
// Stop via context cancellation
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify stopped
|
||||
stats := recalc.GetStats()
|
||||
assert.False(t, stats.Running)
|
||||
}
|
||||
|
||||
func TestRecalculator_StartTwice(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
recalc.interval = 1 * time.Hour // Long interval so it doesn't tick during test
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start first time
|
||||
go recalc.Start(ctx)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to start second time (should return immediately)
|
||||
go recalc.Start(ctx)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Should still be running (only once)
|
||||
stats := recalc.GetStats()
|
||||
assert.True(t, stats.Running)
|
||||
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestRecalculator_StopWhenNotRunning(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Stop without starting - should not panic
|
||||
recalc.Stop()
|
||||
|
||||
stats := recalc.GetStats()
|
||||
assert.False(t, stats.Running)
|
||||
}
|
||||
|
||||
func TestRecalculator_EmptyStore(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RecalculateNow(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, store.GetUpdateScoresCalls(), "should not call update with no observations")
|
||||
}
|
||||
|
||||
func TestRecalculator_GetObservationsError(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
store.SetGetError(assert.AnError)
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RecalculateNow(ctx)
|
||||
|
||||
// Should not return error (logs it instead)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, store.GetUpdateScoresCalls())
|
||||
}
|
||||
|
||||
func TestRecalculator_UpdateScoresError(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: 1,
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
store.SetUpdateError(assert.AnError)
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RecalculateNow(ctx)
|
||||
|
||||
// Should not return error (logs it instead)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, store.GetUpdateScoresCalls())
|
||||
}
|
||||
|
||||
func TestRecalculator_BatchProcessing(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Set small batch size
|
||||
recalc.batchSize = 3
|
||||
|
||||
// Add 5 observations
|
||||
now := time.Now()
|
||||
for i := 1; i <= 5; i++ {
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: int64(i),
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := recalc.RecalculateNow(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should only process batch size (3)
|
||||
scores := 0
|
||||
for i := 1; i <= 5; i++ {
|
||||
if _, ok := store.GetScore(int64(i)); ok {
|
||||
scores++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 3, scores, "should only process batch size observations")
|
||||
}
|
||||
|
||||
func TestRecalculator_ConcurrentAccess(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
// Add observations
|
||||
now := time.Now()
|
||||
for i := 1; i <= 10; i++ {
|
||||
store.AddObservation(&models.Observation{
|
||||
ID: int64(i),
|
||||
Type: models.ObsTypeBugfix,
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Run multiple recalculations concurrently
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = recalc.RecalculateNow(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without race conditions
|
||||
// (use -race flag to verify)
|
||||
assert.GreaterOrEqual(t, store.GetUpdateScoresCalls(), 1)
|
||||
}
|
||||
|
||||
func TestRecalculator_StatsThreadSafe(t *testing.T) {
|
||||
store := NewMockObservationStore()
|
||||
calc := NewCalculator(nil)
|
||||
log := zerolog.Nop()
|
||||
recalc := NewRecalculator(store, calc, log)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent reads (use -race flag to verify)
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = recalc.GetStats()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,448 @@
|
||||
// Package expansion provides context-aware query expansion for improved search recall.
|
||||
package expansion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// QueryIntent represents the detected intent of a query.
|
||||
type QueryIntent string
|
||||
|
||||
const (
|
||||
// IntentQuestion indicates a question-type query (how, why, what, etc.)
|
||||
IntentQuestion QueryIntent = "question"
|
||||
// IntentError indicates an error/debugging query
|
||||
IntentError QueryIntent = "error"
|
||||
// IntentImplementation indicates an implementation/coding query
|
||||
IntentImplementation QueryIntent = "implementation"
|
||||
// IntentArchitecture indicates an architecture/design query
|
||||
IntentArchitecture QueryIntent = "architecture"
|
||||
// IntentGeneral indicates a general lookup query
|
||||
IntentGeneral QueryIntent = "general"
|
||||
)
|
||||
|
||||
// ExpandedQuery represents a query variant with metadata.
|
||||
type ExpandedQuery struct {
|
||||
Query string `json:"query"`
|
||||
Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0)
|
||||
Source string `json:"source"` // Where this expansion came from
|
||||
Intent QueryIntent `json:"intent"` // Detected intent
|
||||
}
|
||||
|
||||
// Expander provides context-aware query expansion.
|
||||
type Expander struct {
|
||||
embedSvc *embedding.Service
|
||||
vocabulary []VocabEntry // Known vocabulary from observations
|
||||
vocabVectors [][]float32 // Embeddings for vocabulary entries
|
||||
vocabMu sync.RWMutex // Protects vocabulary
|
||||
intentPatterns map[QueryIntent][]*regexp.Regexp
|
||||
}
|
||||
|
||||
// VocabEntry represents a vocabulary term from observations.
|
||||
type VocabEntry struct {
|
||||
Term string // The term itself
|
||||
Weight float64 // How common/important this term is (0.0-1.0)
|
||||
Source string // Where it came from (title, concept, narrative)
|
||||
}
|
||||
|
||||
// Config holds expander configuration.
|
||||
type Config struct {
|
||||
// MaxExpansions limits the number of expanded queries returned
|
||||
MaxExpansions int
|
||||
// MinSimilarity is the minimum similarity score for vocabulary expansion
|
||||
MinSimilarity float64
|
||||
// EnableVocabularyExpansion enables finding related terms from observations
|
||||
EnableVocabularyExpansion bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns sensible default configuration.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
MaxExpansions: 4,
|
||||
MinSimilarity: 0.5,
|
||||
EnableVocabularyExpansion: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewExpander creates a new query expander.
|
||||
func NewExpander(embedSvc *embedding.Service) *Expander {
|
||||
e := &Expander{
|
||||
embedSvc: embedSvc,
|
||||
intentPatterns: buildIntentPatterns(),
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// buildIntentPatterns creates regex patterns for intent detection.
|
||||
func buildIntentPatterns() map[QueryIntent][]*regexp.Regexp {
|
||||
patterns := make(map[QueryIntent][]*regexp.Regexp)
|
||||
|
||||
// Question patterns
|
||||
patterns[IntentQuestion] = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)^(how|why|what|when|where|which|who)\b`),
|
||||
regexp.MustCompile(`(?i)\?$`),
|
||||
regexp.MustCompile(`(?i)\b(explain|describe|understand)\b`),
|
||||
}
|
||||
|
||||
// Error/debugging patterns
|
||||
patterns[IntentError] = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\b(error|bug|issue|problem|fail|crash|exception|panic)\b`),
|
||||
regexp.MustCompile(`(?i)\b(fix|debug|troubleshoot|resolve)\b`),
|
||||
regexp.MustCompile(`(?i)\b(doesn't work|not working|broken)\b`),
|
||||
}
|
||||
|
||||
// Implementation patterns
|
||||
patterns[IntentImplementation] = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\b(implement|add|create|build|write|code)\b`),
|
||||
regexp.MustCompile(`(?i)\b(function|method|handler|endpoint|api)\b`),
|
||||
regexp.MustCompile(`(?i)\b(feature|functionality)\b`),
|
||||
}
|
||||
|
||||
// Architecture patterns
|
||||
patterns[IntentArchitecture] = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)\b(architecture|design|pattern|structure)\b`),
|
||||
regexp.MustCompile(`(?i)\b(component|module|layer|service)\b`),
|
||||
regexp.MustCompile(`(?i)\b(flow|pipeline|workflow)\b`),
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// DetectIntent analyzes a query to determine its intent.
|
||||
func (e *Expander) DetectIntent(query string) QueryIntent {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return IntentGeneral
|
||||
}
|
||||
|
||||
// Check patterns in priority order
|
||||
intentOrder := []QueryIntent{IntentError, IntentQuestion, IntentImplementation, IntentArchitecture}
|
||||
|
||||
for _, intent := range intentOrder {
|
||||
patterns := e.intentPatterns[intent]
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(query) {
|
||||
return intent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return IntentGeneral
|
||||
}
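Because DetectIntent checks intents in a fixed priority order, a query that matches several intents gets the first one in that order. A minimal sketch of that behaviour, using two of the patterns from buildIntentPatterns; the `rule` type, the `detect` function and the use of an ordered slice instead of a map are illustrative choices, not the package's API.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

type intent string

const (
	intentError    intent = "error"
	intentQuestion intent = "question"
	intentGeneral  intent = "general"
)

// rule pairs an intent with one pattern; rules are checked in slice order,
// so earlier entries take priority.
type rule struct {
	intent  intent
	pattern *regexp.Regexp
}

var rules = []rule{
	{intentError, regexp.MustCompile(`(?i)\b(error|bug|issue|problem|fail|crash|exception|panic)\b`)},
	{intentQuestion, regexp.MustCompile(`(?i)^(how|why|what|when|where|which|who)\b`)},
}

// detect returns the first matching intent, or intentGeneral if none match.
func detect(query string) intent {
	query = strings.TrimSpace(query)
	for _, r := range rules {
		if r.pattern.MatchString(query) {
			return r.intent
		}
	}
	return intentGeneral
}

func main() {
	for _, q := range []string{"why does login fail", "how to add caching", "release notes"} {
		fmt.Printf("%q -> %s\n", q, detect(q))
	}
}
```

Here "why does login fail" is classified as an error query even though it starts with a question word, because the error rule is checked first.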
|
||||
|
||||
// Expand generates expanded query variants based on the original query.
|
||||
func (e *Expander) Expand(ctx context.Context, query string, cfg Config) []ExpandedQuery {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
intent := e.DetectIntent(query)
|
||||
expansions := make([]ExpandedQuery, 0, cfg.MaxExpansions)
|
||||
|
||||
// Always include the original query with highest weight
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: query,
|
||||
Weight: 1.0,
|
||||
Source: "original",
|
||||
Intent: intent,
|
||||
})
|
||||
|
||||
// Generate intent-based expansions
|
||||
intentExpansions := e.expandByIntent(query, intent)
|
||||
expansions = append(expansions, intentExpansions...)
|
||||
|
||||
// Generate vocabulary-based expansions if enabled and we have vocabulary
|
||||
if cfg.EnableVocabularyExpansion && e.embedSvc != nil { // emptiness of the vocabulary is checked under vocabMu inside expandByVocabulary
|
||||
vocabExpansions := e.expandByVocabulary(ctx, query, cfg.MinSimilarity)
|
||||
expansions = append(expansions, vocabExpansions...)
|
||||
}
|
||||
|
||||
// Deduplicate and limit
|
||||
expansions = deduplicateExpansions(expansions)
|
||||
if len(expansions) > cfg.MaxExpansions {
|
||||
expansions = expansions[:cfg.MaxExpansions]
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("query", truncate(query, 50)).
|
||||
Str("intent", string(intent)).
|
||||
Int("expansions", len(expansions)).
|
||||
Msg("Query expanded")
|
||||
|
||||
return expansions
|
||||
}
|
||||
|
||||
// expandByIntent generates expansions based on detected query intent.
|
||||
func (e *Expander) expandByIntent(query string, intent QueryIntent) []ExpandedQuery {
|
||||
var expansions []ExpandedQuery
|
||||
|
||||
// Extract key terms from query for context-aware expansion
|
||||
keyTerms := extractKeyTerms(query)
|
||||
|
||||
switch intent {
|
||||
case IntentQuestion:
|
||||
// For questions, create a declarative variant
|
||||
declarative := makeDeclarative(query)
|
||||
if declarative != query {
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: declarative,
|
||||
Weight: 0.85,
|
||||
Source: "intent:declarative",
|
||||
Intent: intent,
|
||||
})
|
||||
}
|
||||
|
||||
case IntentError:
|
||||
// For errors, expand with solution-oriented terms
|
||||
if len(keyTerms) > 0 {
|
||||
solutionQuery := strings.Join(keyTerms, " ") + " solution fix"
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: solutionQuery,
|
||||
Weight: 0.8,
|
||||
Source: "intent:solution",
|
||||
Intent: intent,
|
||||
})
|
||||
}
|
||||
|
||||
case IntentImplementation:
|
||||
// For implementation queries, focus on the what/how
|
||||
if len(keyTerms) > 0 {
|
||||
howQuery := "how " + strings.Join(keyTerms, " ")
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: howQuery,
|
||||
Weight: 0.75,
|
||||
Source: "intent:how",
|
||||
Intent: intent,
|
||||
})
|
||||
}
|
||||
|
||||
case IntentArchitecture:
|
||||
// For architecture queries, expand with design context
|
||||
if len(keyTerms) > 0 {
|
||||
designQuery := strings.Join(keyTerms, " ") + " design structure"
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: designQuery,
|
||||
Weight: 0.75,
|
||||
Source: "intent:design",
|
||||
Intent: intent,
|
||||
})
|
||||
}
|
||||
|
||||
case IntentGeneral:
|
||||
// For general queries there is no intent-specific rewrite.
|
||||
// No additional expansion - rely on vocabulary expansion
|
||||
}
|
||||
|
||||
return expansions
|
||||
}
|
||||
|
||||
// expandByVocabulary finds similar terms from the observation vocabulary.
|
||||
func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimilarity float64) []ExpandedQuery {
|
||||
e.vocabMu.RLock()
|
||||
defer e.vocabMu.RUnlock()
|
||||
|
||||
if len(e.vocabulary) == 0 || e.embedSvc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Embed the query
|
||||
queryEmb, err := e.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to embed query for vocabulary expansion")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find similar vocabulary terms
|
||||
type scoredTerm struct {
|
||||
entry VocabEntry
|
||||
score float64
|
||||
}
|
||||
|
||||
var similar []scoredTerm
|
||||
for i, entry := range e.vocabulary {
|
||||
if i >= len(e.vocabVectors) {
|
||||
break
|
||||
}
|
||||
|
||||
score := cosineSimilarity(queryEmb, e.vocabVectors[i])
|
||||
if score >= minSimilarity {
|
||||
similar = append(similar, scoredTerm{entry: entry, score: score})
|
||||
}
|
||||
}
|
||||
|
||||
if len(similar) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by score (descending) with a simple in-place exchange sort (O(n^2))
|
||||
for i := 0; i < len(similar)-1; i++ {
|
||||
for j := i + 1; j < len(similar); j++ {
|
||||
if similar[j].score > similar[i].score {
|
||||
similar[i], similar[j] = similar[j], similar[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create expansion by combining top similar terms with query
|
||||
var expansions []ExpandedQuery
|
||||
if len(similar) > 0 {
|
||||
// Take top 2 similar terms and combine with original key terms
|
||||
keyTerms := extractKeyTerms(query)
|
||||
for i := 0; i < min(2, len(similar)); i++ {
|
||||
term := similar[i].entry.Term
|
||||
// Don't add if term is already in query
|
||||
if strings.Contains(strings.ToLower(query), strings.ToLower(term)) {
|
||||
continue
|
||||
}
|
||||
|
||||
combinedQuery := strings.Join(keyTerms, " ") + " " + term
|
||||
expansions = append(expansions, ExpandedQuery{
|
||||
Query: combinedQuery,
|
||||
Weight: 0.7 * similar[i].score * similar[i].entry.Weight,
|
||||
Source: "vocabulary:" + term,
|
||||
Intent: IntentGeneral,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return expansions
|
||||
}
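The hand-written sort and top-2 selection in expandByVocabulary could also be expressed with the standard library. A minimal sketch using `sort.SliceStable`; the simplified `scoredTerm` type (a plain term string instead of a VocabEntry) and the `topN` helper are hypothetical.

```go
package main

import (
	"fmt"
	"sort"
)

// scoredTerm is a simplified, hypothetical version of the local type used above.
type scoredTerm struct {
	term  string
	score float64
}

// topN returns up to n terms ordered by descending score.
// It sorts a copy, so the caller's slice is left untouched.
func topN(terms []scoredTerm, n int) []scoredTerm {
	sorted := make([]scoredTerm, len(terms))
	copy(sorted, terms)
	sort.SliceStable(sorted, func(i, j int) bool {
		return sorted[i].score > sorted[j].score
	})
	if n < 0 {
		n = 0
	}
	if n > len(sorted) {
		n = len(sorted)
	}
	return sorted[:n]
}

func main() {
	terms := []scoredTerm{{"cache", 0.6}, {"redis", 0.9}, {"ttl", 0.75}}
	for _, t := range topN(terms, 2) {
		fmt.Printf("%s %.2f\n", t.term, t.score)
	}
}
```

A stable sort keeps vocabulary order for equal scores, which makes the selected terms deterministic.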
|
||||
|
||||
// Helper functions
|
||||
|
||||
// extractKeyTerms extracts meaningful terms from a query.
|
||||
func extractKeyTerms(query string) []string {
|
||||
// Common stop words to filter out
|
||||
stopWords := map[string]bool{
|
||||
"a": true, "an": true, "the": true, "is": true, "are": true,
|
||||
"was": true, "were": true, "be": true, "been": true, "being": true,
|
||||
"have": true, "has": true, "had": true, "do": true, "does": true,
|
||||
"did": true, "will": true, "would": true, "could": true, "should": true,
|
||||
"may": true, "might": true, "must": true, "can": true,
|
||||
"i": true, "me": true, "my": true, "we": true, "our": true,
|
||||
"you": true, "your": true, "it": true, "its": true,
|
||||
"this": true, "that": true, "these": true, "those": true,
|
||||
"what": true, "which": true, "who": true, "whom": true,
|
||||
"how": true, "why": true, "when": true, "where": true,
|
||||
"to": true, "for": true, "with": true, "about": true, "from": true,
|
||||
"in": true, "on": true, "at": true, "by": true, "of": true,
|
||||
"and": true, "or": true, "but": true, "if": true, "then": true,
|
||||
}
|
||||
|
||||
// Split and filter
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
var terms []string
|
||||
|
||||
for _, word := range words {
|
||||
// Remove punctuation
|
||||
word = strings.Trim(word, ".,?!;:'\"()[]{}")
|
||||
if len(word) < 2 {
|
||||
continue
|
||||
}
|
||||
if stopWords[word] {
|
||||
continue
|
||||
}
|
||||
terms = append(terms, word)
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
// makeDeclarative strips a trailing "?" and a leading question phrase (e.g. "how do i ") from a query.
|
||||
func makeDeclarative(query string) string {
|
||||
query = strings.TrimSpace(query)
|
||||
|
||||
// Remove question mark
|
||||
query = strings.TrimSuffix(query, "?")
|
||||
|
||||
// Handle common question patterns
|
||||
patterns := []struct {
|
||||
prefix string
|
||||
replacement string
|
||||
}{
|
||||
{"how do i ", ""},
|
||||
{"how to ", ""},
|
||||
{"how does ", ""},
|
||||
{"how is ", ""},
|
||||
{"what is ", ""},
|
||||
{"what are ", ""},
|
||||
{"why does ", ""},
|
||||
{"why is ", ""},
|
||||
{"where is ", ""},
|
||||
{"where are ", ""},
|
||||
{"when does ", ""},
|
||||
{"when is ", ""},
|
||||
}
|
||||
|
||||
lower := strings.ToLower(query)
|
||||
for _, p := range patterns {
|
||||
if strings.HasPrefix(lower, p.prefix) {
|
||||
return strings.TrimSpace(query[len(p.prefix):])
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// deduplicateExpansions removes duplicate queries while preserving order.
|
||||
func deduplicateExpansions(expansions []ExpandedQuery) []ExpandedQuery {
|
||||
seen := make(map[string]bool)
|
||||
result := make([]ExpandedQuery, 0, len(expansions))
|
||||
|
||||
for _, exp := range expansions {
|
||||
normalized := strings.ToLower(strings.TrimSpace(exp.Query))
|
||||
if !seen[normalized] {
|
||||
seen[normalized] = true
|
||||
result = append(result, exp)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// cosineSimilarity computes cosine similarity between two vectors.
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) || len(a) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var dot, normA, normB float64
|
||||
for i := range a {
|
||||
dot += float64(a[i]) * float64(b[i])
|
||||
normA += float64(a[i]) * float64(a[i])
|
||||
normB += float64(b[i]) * float64(b[i])
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return dot / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
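// sqrt approximates the square root of x with a fixed 10 Newton-Raphson
// iterations starting from z = x. It returns 0 for x <= 0.
// Ten iterations are enough for inputs near 1 (such as norms of normalized
// embeddings) but do not converge for inputs far from 1, e.g. x = 1e6.
func sqrt(x float64) float64 {
	if x <= 0 {
		return 0
	}
	z := x
	for i := 0; i < 10; i++ {
		z = (z + x/z) / 2
	}
	return z
}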
|
||||
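The fixed-iteration helper above can be contrasted with two alternatives. The following is a minimal sketch, not part of this commit: `cosine` and `newtonSqrt` are hypothetical names, and the sketch assumes the standard library `math` package is acceptable in this package. It shows the same cosine computation backed by `math.Sqrt`, plus a Newton iteration that stops on convergence instead of after a fixed step count.

```go
package main

import (
	"fmt"
	"math"
)

// cosine computes cosine similarity using math.Sqrt. Mismatched or empty
// inputs and zero-norm vectors yield 0, mirroring the guards in the
// original helper. The name is hypothetical.
func cosine(a, b []float32) float64 {
	if len(a) != len(b) || len(a) == 0 {
		return 0
	}
	var dot, normA, normB float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		normA += float64(a[i]) * float64(a[i])
		normB += float64(b[i]) * float64(b[i])
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}

// newtonSqrt repeats the Newton step until the estimate stops changing
// (relative to 1e-12) or maxIter steps have run. The name and the
// tolerance are assumptions made for this sketch.
func newtonSqrt(x float64, maxIter int) float64 {
	if x <= 0 {
		return 0
	}
	z := x
	for i := 0; i < maxIter; i++ {
		next := (z + x/z) / 2
		if math.Abs(next-z) <= 1e-12*next {
			return next
		}
		z = next
	}
	return z
}

func main() {
	fmt.Println(cosine([]float32{1, 1, 0}, []float32{1, 0, 0}))
	fmt.Println(newtonSqrt(1e6, 100))
}
```

For embedding vectors whose squared norms are close to 1, the original 10-step loop and `math.Sqrt` should agree closely; the difference only matters for unnormalized inputs.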
|
||||
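// truncate returns s unchanged when it is at most maxLen bytes long.
// Otherwise it keeps the first maxLen bytes and appends "...", so the result
// can be up to maxLen+3 bytes. Slicing is byte-based, so a multi-byte UTF-8
// character at the boundary may be split.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen] + "..."
}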
|
||||
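Because the helper above slices by bytes, a rune-aware variant may be preferable when queries contain non-ASCII text. The following is a minimal sketch under that assumption; `truncateRunes` is a hypothetical name and is not part of this commit.

```go
package main

import "fmt"

// truncateRunes is a hypothetical rune-aware variant of truncate: it counts
// Unicode code points instead of bytes, so it never cuts a multi-byte
// character in half. Like the original, it appends "..." only when it cuts.
func truncateRunes(s string, maxRunes int) string {
	if maxRunes < 0 {
		maxRunes = 0
	}
	runes := []rune(s)
	if len(runes) <= maxRunes {
		return s
	}
	return string(runes[:maxRunes]) + "..."
}

func main() {
	fmt.Println(truncateRunes("héllo wörld", 5))
}
```

The conversion to `[]rune` allocates, which is usually acceptable for short log or preview strings but may matter on hot paths.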
@@ -0,0 +1,518 @@
|
||||
// Package expansion provides context-aware query expansion for improved search recall.
|
||||
package expansion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ExpanderSuite tests the Expander functionality.
|
||||
type ExpanderSuite struct {
|
||||
suite.Suite
|
||||
expander *Expander
|
||||
}
|
||||
|
||||
func TestExpanderSuite(t *testing.T) {
|
||||
suite.Run(t, new(ExpanderSuite))
|
||||
}
|
||||
|
||||
func (s *ExpanderSuite) SetupTest() {
|
||||
// Create expander without embedding service for basic tests
|
||||
s.expander = NewExpander(nil)
|
||||
}
|
||||
|
||||
// TestNewExpander tests expander creation.
|
||||
func (s *ExpanderSuite) TestNewExpander() {
|
||||
e := NewExpander(nil)
|
||||
s.NotNil(e)
|
||||
s.NotNil(e.intentPatterns)
|
||||
s.Nil(e.embedSvc)
|
||||
}
|
||||
|
||||
// TestDetectIntent tests intent detection.
|
||||
func (s *ExpanderSuite) TestDetectIntent() {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expected QueryIntent
|
||||
}{
|
||||
// Question intents
|
||||
{"how_question", "how do I implement auth?", IntentQuestion},
|
||||
{"why_question", "why does this fail?", IntentError}, // "fail" triggers error first
|
||||
{"what_question", "what is the purpose of this function?", IntentQuestion},
|
||||
{"question_mark", "the handler for auth?", IntentQuestion},
|
||||
{"explain", "explain the architecture", IntentQuestion},
|
||||
|
||||
// Error intents
|
||||
{"error_word", "authentication error in login", IntentError},
|
||||
{"bug_word", "bug in user registration", IntentError},
|
||||
{"fix_word", "fix the memory leak", IntentError},
|
||||
{"not_working", "login not working", IntentError},
|
||||
{"crash", "application crash on startup", IntentError},
|
||||
|
||||
// Implementation intents
|
||||
{"implement", "implement user authentication", IntentImplementation},
|
||||
{"add_feature", "add new endpoint for users", IntentImplementation},
|
||||
{"create", "create a handler for uploads", IntentImplementation},
|
||||
{"function", "function to validate input", IntentImplementation},
|
||||
|
||||
// Architecture intents
|
||||
{"architecture", "architecture of the system", IntentArchitecture},
|
||||
{"design", "design pattern for observers", IntentArchitecture},
|
||||
{"component", "component structure", IntentArchitecture},
|
||||
{"flow", "data flow in the pipeline", IntentArchitecture},
|
||||
|
||||
// General intents
|
||||
{"general", "user authentication", IntentGeneral},
|
||||
{"empty", "", IntentGeneral},
|
||||
{"simple", "database", IntentGeneral},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := s.expander.DetectIntent(tt.query)
|
||||
s.Equal(tt.expected, result, "Query: %s", tt.query)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExpand tests basic query expansion.
|
||||
func (s *ExpanderSuite) TestExpand() {
|
||||
ctx := context.Background()
|
||||
cfg := DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false // Disable for unit test
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
minExpansions int
|
||||
hasOriginal bool
|
||||
expectedIntent QueryIntent
|
||||
}{
|
||||
{"question", "how do I implement auth", 1, true, IntentQuestion},
|
||||
{"error", "fix the bug in login", 1, true, IntentError},
|
||||
{"implementation", "implement user handler", 1, true, IntentImplementation},
|
||||
{"architecture", "architecture design", 1, true, IntentArchitecture},
|
||||
{"general", "database connection", 1, true, IntentGeneral},
|
||||
{"empty", "", 0, false, IntentGeneral},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
expansions := s.expander.Expand(ctx, tt.query, cfg)
|
||||
|
||||
if tt.minExpansions == 0 {
|
||||
s.Empty(expansions)
|
||||
return
|
||||
}
|
||||
|
||||
s.GreaterOrEqual(len(expansions), tt.minExpansions)
|
||||
|
||||
if tt.hasOriginal {
|
||||
// First expansion should be the original
|
||||
s.Equal(tt.query, expansions[0].Query)
|
||||
s.Equal(1.0, expansions[0].Weight)
|
||||
s.Equal("original", expansions[0].Source)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExpandWithConfig tests expansion with custom config.
|
||||
func (s *ExpanderSuite) TestExpandWithConfig() {
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := Config{
|
||||
MaxExpansions: 2,
|
||||
MinSimilarity: 0.7,
|
||||
EnableVocabularyExpansion: false,
|
||||
}
|
||||
|
||||
expansions := s.expander.Expand(ctx, "how to implement authentication", cfg)
|
||||
s.LessOrEqual(len(expansions), cfg.MaxExpansions)
|
||||
}
|
||||
|
||||
// TestExpandDeduplication tests that duplicates are removed.
|
||||
func (s *ExpanderSuite) TestExpandDeduplication() {
|
||||
ctx := context.Background()
|
||||
cfg := DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false
|
||||
|
||||
// Query that might generate duplicate expansions
|
||||
query := "how to fix authentication"
|
||||
expansions := s.expander.Expand(ctx, query, cfg)
|
||||
|
||||
// Check for duplicates
|
||||
seen := make(map[string]bool)
|
||||
for _, exp := range expansions {
|
||||
normalized := exp.Query
|
||||
s.False(seen[normalized], "Duplicate expansion found: %s", exp.Query)
|
||||
seen[normalized] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractKeyTerms tests key term extraction.
|
||||
func TestExtractKeyTerms(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
query: "user authentication handler",
|
||||
expected: []string{"user", "authentication", "handler"},
|
||||
},
|
||||
{
|
||||
name: "with_stop_words",
|
||||
query: "how to implement the user login",
|
||||
expected: []string{"implement", "user", "login"},
|
||||
},
|
||||
{
|
||||
name: "with_punctuation",
|
||||
query: "fix the bug, please!",
|
||||
expected: []string{"fix", "bug", "please"},
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
query: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "only_stop_words",
|
||||
query: "the a an is are",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "short_words_filtered",
|
||||
query: "a b c auth",
|
||||
expected: []string{"auth"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractKeyTerms(tt.query)
|
||||
if tt.expected == nil {
|
||||
assert.Empty(t, result)
|
||||
} else {
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeDeclarative tests question to declarative conversion.
|
||||
func TestMakeDeclarative(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "how_do_i",
|
||||
query: "how do I implement auth?",
|
||||
expected: "implement auth",
|
||||
},
|
||||
{
|
||||
name: "how_to",
|
||||
query: "how to fix the bug",
|
||||
expected: "fix the bug",
|
||||
},
|
||||
{
|
||||
name: "what_is",
|
||||
query: "what is the purpose of this?",
|
||||
expected: "the purpose of this",
|
||||
},
|
||||
{
|
||||
name: "why_does",
|
||||
query: "why does this fail?",
|
||||
expected: "this fail",
|
||||
},
|
||||
{
|
||||
name: "already_declarative",
|
||||
query: "user authentication",
|
||||
expected: "user authentication",
|
||||
},
|
||||
{
|
||||
name: "question_mark_only",
|
||||
query: "authentication?",
|
||||
expected: "authentication",
|
||||
},
|
||||
{
|
||||
name: "case_insensitive",
|
||||
query: "How To Fix Auth?",
|
||||
expected: "Fix Auth",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := makeDeclarative(tt.query)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeduplicateExpansions tests deduplication.
|
||||
func TestDeduplicateExpansions(t *testing.T) {
|
||||
expansions := []ExpandedQuery{
|
||||
{Query: "auth handler", Weight: 1.0},
|
||||
{Query: "AUTH HANDLER", Weight: 0.8}, // Duplicate (case insensitive)
|
||||
{Query: "auth handler ", Weight: 0.7}, // Duplicate (whitespace)
|
||||
{Query: "user auth", Weight: 0.6},
|
||||
}
|
||||
|
||||
result := deduplicateExpansions(expansions)
|
||||
assert.Len(t, result, 2) // "auth handler" and "user auth"
|
||||
assert.Equal(t, "auth handler", result[0].Query)
|
||||
assert.Equal(t, 1.0, result[0].Weight) // First one preserved
|
||||
assert.Equal(t, "user auth", result[1].Query)
|
||||
}
|
||||
|
||||
// TestCosineSimilarity tests cosine similarity calculation.
|
||||
func TestCosineSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a []float32
|
||||
b []float32
|
||||
expected float64
|
||||
delta float64
|
||||
}{
|
||||
{
|
||||
name: "identical_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{1, 0, 0},
|
||||
expected: 1.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
{
|
||||
name: "orthogonal_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{0, 1, 0},
|
||||
expected: 0.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
{
|
||||
name: "opposite_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{-1, 0, 0},
|
||||
expected: -1.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
{
|
||||
name: "similar_vectors",
|
||||
a: []float32{1, 1, 0},
|
||||
b: []float32{1, 0, 0},
|
||||
expected: 0.707,
|
||||
delta: 0.01,
|
||||
},
|
||||
{
|
||||
name: "empty_vectors",
|
||||
a: []float32{},
|
||||
b: []float32{},
|
||||
expected: 0.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
{
|
||||
name: "different_lengths",
|
||||
a: []float32{1, 0},
|
||||
b: []float32{1, 0, 0},
|
||||
expected: 0.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
{
|
||||
name: "zero_vector",
|
||||
a: []float32{0, 0, 0},
|
||||
b: []float32{1, 1, 1},
|
||||
expected: 0.0,
|
||||
delta: 0.001,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := cosineSimilarity(tt.a, tt.b)
|
||||
assert.InDelta(t, tt.expected, result, tt.delta)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig tests default configuration.
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
assert.Equal(t, 4, cfg.MaxExpansions)
|
||||
assert.Equal(t, 0.5, cfg.MinSimilarity)
|
||||
assert.True(t, cfg.EnableVocabularyExpansion)
|
||||
}
|
||||
|
||||
// TestExpandedQueryStruct tests ExpandedQuery struct.
|
||||
func TestExpandedQueryStruct(t *testing.T) {
|
||||
eq := ExpandedQuery{
|
||||
Query: "test query",
|
||||
Weight: 0.85,
|
||||
Source: "vocabulary:auth",
|
||||
Intent: IntentQuestion,
|
||||
}
|
||||
|
||||
assert.Equal(t, "test query", eq.Query)
|
||||
assert.Equal(t, 0.85, eq.Weight)
|
||||
assert.Equal(t, "vocabulary:auth", eq.Source)
|
||||
assert.Equal(t, IntentQuestion, eq.Intent)
|
||||
}
|
||||
|
||||
// TestVocabEntry tests VocabEntry struct.
|
||||
func TestVocabEntry(t *testing.T) {
|
||||
ve := VocabEntry{
|
||||
Term: "authentication",
|
||||
Weight: 0.9,
|
||||
Source: "concept",
|
||||
}
|
||||
|
||||
assert.Equal(t, "authentication", ve.Term)
|
||||
assert.Equal(t, 0.9, ve.Weight)
|
||||
assert.Equal(t, "concept", ve.Source)
|
||||
}
|
||||
|
||||
// TestIntentConstants tests intent constant values.
|
||||
func TestIntentConstants(t *testing.T) {
|
||||
assert.Equal(t, QueryIntent("question"), IntentQuestion)
|
||||
assert.Equal(t, QueryIntent("error"), IntentError)
|
||||
assert.Equal(t, QueryIntent("implementation"), IntentImplementation)
|
||||
assert.Equal(t, QueryIntent("architecture"), IntentArchitecture)
|
||||
assert.Equal(t, QueryIntent("general"), IntentGeneral)
|
||||
}
|
||||
|
||||
// TestTruncate tests the truncate helper.
|
||||
func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{"short", "hello", 10, "hello"},
|
||||
{"exact", "hello", 5, "hello"},
|
||||
{"long", "hello world", 5, "hello..."},
|
||||
{"empty", "", 10, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncate(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqrt tests the sqrt helper.
|
||||
func TestSqrt(t *testing.T) {
|
||||
tests := []struct {
|
||||
input float64
|
||||
expected float64
|
||||
delta float64
|
||||
}{
|
||||
{4.0, 2.0, 0.001},
|
||||
{9.0, 3.0, 0.001},
|
||||
{16.0, 4.0, 0.001},
|
||||
{2.0, 1.414, 0.01},
|
||||
{0.0, 0.0, 0.001},
|
||||
{-1.0, 0.0, 0.001},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
result := sqrt(tt.input)
|
||||
assert.InDelta(t, tt.expected, result, tt.delta)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExpandByIntentError tests error intent expansion.
|
||||
func (s *ExpanderSuite) TestExpandByIntentError() {
|
||||
expansions := s.expander.expandByIntent("fix authentication bug", IntentError)
|
||||
s.NotEmpty(expansions)
|
||||
|
||||
// Should have solution-oriented expansion
|
||||
hasSolution := false
|
||||
for _, exp := range expansions {
|
||||
if exp.Source == "intent:solution" {
|
||||
hasSolution = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(hasSolution)
|
||||
}
|
||||
|
||||
// TestExpandByIntentQuestion tests question intent expansion.
|
||||
func (s *ExpanderSuite) TestExpandByIntentQuestion() {
|
||||
expansions := s.expander.expandByIntent("how do I implement auth", IntentQuestion)
|
||||
s.NotEmpty(expansions)
|
||||
|
||||
// Should have declarative expansion
|
||||
hasDeclarative := false
|
||||
for _, exp := range expansions {
|
||||
if exp.Source == "intent:declarative" {
|
||||
hasDeclarative = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(hasDeclarative)
|
||||
}
|
||||
|
||||
// TestExpandByIntentImplementation tests implementation intent expansion.
|
||||
func (s *ExpanderSuite) TestExpandByIntentImplementation() {
|
||||
expansions := s.expander.expandByIntent("implement user handler", IntentImplementation)
|
||||
s.NotEmpty(expansions)
|
||||
|
||||
// Should have how expansion
|
||||
hasHow := false
|
||||
for _, exp := range expansions {
|
||||
if exp.Source == "intent:how" {
|
||||
hasHow = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(hasHow)
|
||||
}
|
||||
|
||||
// TestExpandByIntentArchitecture tests architecture intent expansion.
|
||||
func (s *ExpanderSuite) TestExpandByIntentArchitecture() {
|
||||
expansions := s.expander.expandByIntent("system architecture design", IntentArchitecture)
|
||||
s.NotEmpty(expansions)
|
||||
|
||||
// Should have design expansion
|
||||
hasDesign := false
|
||||
for _, exp := range expansions {
|
||||
if exp.Source == "intent:design" {
|
||||
hasDesign = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(hasDesign)
|
||||
}
|
||||
|
||||
// TestExpandByIntentGeneral tests general intent returns no expansions.
|
||||
func (s *ExpanderSuite) TestExpandByIntentGeneral() {
|
||||
expansions := s.expander.expandByIntent("database", IntentGeneral)
|
||||
s.Empty(expansions) // General intent doesn't add intent-based expansions
|
||||
}
|
||||
|
||||
// TestEmptyVocabulary tests expansion with empty vocabulary.
|
||||
func (s *ExpanderSuite) TestEmptyVocabulary() {
|
||||
ctx := context.Background()
|
||||
expansions := s.expander.expandByVocabulary(ctx, "test query", 0.5)
|
||||
s.Empty(expansions)
|
||||
}
|
||||
|
||||
// TestIntentPatternsExist tests that all intents have patterns.
|
||||
func (s *ExpanderSuite) TestIntentPatternsExist() {
|
||||
s.NotEmpty(s.expander.intentPatterns[IntentQuestion])
|
||||
s.NotEmpty(s.expander.intentPatterns[IntentError])
|
||||
s.NotEmpty(s.expander.intentPatterns[IntentImplementation])
|
||||
s.NotEmpty(s.expander.intentPatterns[IntentArchitecture])
|
||||
}
|
||||
+29
-15
@@ -35,20 +35,21 @@ func NewManager(
|
||||
|
||||
// SearchParams contains parameters for unified search.
|
||||
type SearchParams struct {
|
||||
Query string
|
||||
Type string // "observations", "sessions", "prompts", or empty for all
|
||||
Project string
|
||||
ObsType string // Observation type filter
|
||||
Concepts string
|
||||
Files string
|
||||
DateStart int64
|
||||
DateEnd int64
|
||||
OrderBy string // "relevance", "date_desc", "date_asc"
|
||||
Limit int
|
||||
Offset int
|
||||
Format string // "index" or "full"
|
||||
Scope string // "project", "global", or empty for project+global
|
||||
IncludeGlobal bool // If true, include global observations along with project-scoped
|
||||
Query string
|
||||
Type string // "observations", "sessions", "prompts", or empty for all
|
||||
Project string
|
||||
ObsType string // Observation type filter
|
||||
Concepts string
|
||||
Files string
|
||||
DateStart int64
|
||||
DateEnd int64
|
||||
OrderBy string // "relevance", "date_desc", "date_asc"
|
||||
Limit int
|
||||
Offset int
|
||||
Format string // "index" or "full"
|
||||
Scope string // "project", "global", or empty for project+global
|
||||
IncludeGlobal bool // If true, include global observations along with project-scoped
|
||||
ExcludeSuperseded bool // If true, exclude observations that have been superseded
|
||||
}
|
||||
|
||||
// SearchResult represents a unified search result.
|
||||
@@ -126,6 +127,10 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
|
||||
if err == nil {
|
||||
for _, o := range obs {
|
||||
// Skip superseded observations when requested
|
||||
if params.ExcludeSuperseded && o.IsSuperseded {
|
||||
continue
|
||||
}
|
||||
results = append(results, m.observationToResult(o, params.Format))
|
||||
}
|
||||
}
|
||||
@@ -167,7 +172,16 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
|
||||
// Search observations
|
||||
if params.Type == "" || params.Type == "observations" {
|
||||
obs, err := m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
|
||||
var obs []*models.Observation
|
||||
var err error
|
||||
|
||||
// Use active observations (excluding superseded) when requested
|
||||
if params.ExcludeSuperseded {
|
||||
obs, err = m.observationStore.GetActiveObservations(ctx, params.Project, params.Limit)
|
||||
} else {
|
||||
obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
for _, o := range obs {
|
||||
results = append(results, m.observationToResult(o, params.Format))
|
||||
|
||||
@@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("generate embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Insert into vectors table
|
||||
// Insert into vectors table with model version tracking
|
||||
const insertQuery = `
|
||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
// Get current model version for tracking
|
||||
modelVersion := c.embedSvc.Version()
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
@@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
fieldType,
|
||||
project,
|
||||
scope,
|
||||
modelVersion,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert document %s: %w", doc.ID, err)
|
||||
@@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
|
||||
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -212,6 +216,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
r.Similarity = DistanceToSimilarity(r.Distance)
|
||||
r.Metadata = map[string]any{
|
||||
"sqlite_id": float64(sqliteID), // Keep as float64 for compatibility
|
||||
"doc_type": docType.String,
|
||||
@@ -252,3 +257,148 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var count int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count vectors: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ModelVersion returns the current embedding model version.
|
||||
func (c *Client) ModelVersion() string {
|
||||
return c.embedSvc.Version()
|
||||
}
|
||||
|
||||
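// NeedsRebuild checks if vectors need to be rebuilt due to model version change.
// It returns a flag and a reason string:
// - (true, "empty") when the vectors table is empty
// - (true, "model_mismatch:N") when N vectors have a NULL model_version or one
//   that differs from the current model
// - (false, "") otherwise, including when a count query fails (the failure is logged)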
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
// Check total count
|
||||
var totalCount int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to count vectors for rebuild check")
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if totalCount == 0 {
|
||||
return true, "empty"
|
||||
}
|
||||
|
||||
// Check for vectors with different model version
|
||||
var staleCount int64
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
currentModel,
|
||||
).Scan(&staleCount)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to count stale vectors")
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if staleCount > 0 {
|
||||
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// StaleVectorInfo contains information about a vector that needs rebuilding.
|
||||
type StaleVectorInfo struct {
|
||||
DocID string
|
||||
SQLiteID int64
|
||||
DocType string
|
||||
FieldType string
|
||||
Project string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// GetStaleVectors returns metadata (doc_id, sqlite_id, doc_type, field_type,
// project, scope) for vectors whose model_version is NULL or differs from the
// current model. This enables granular rebuild: only documents that need
// updating are re-embedded.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
query := `
|
||||
SELECT doc_id, sqlite_id, doc_type, field_type, project, scope
|
||||
FROM vectors
|
||||
WHERE model_version != ? OR model_version IS NULL
|
||||
`
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, currentModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query stale vectors: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []StaleVectorInfo
|
||||
for rows.Next() {
|
||||
var info StaleVectorInfo
|
||||
var sqliteID sql.NullInt64
|
||||
var docType, fieldType, project, scope sql.NullString
|
||||
|
||||
if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil {
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
info.SQLiteID = sqliteID.Int64
|
||||
info.DocType = docType.String
|
||||
info.FieldType = fieldType.String
|
||||
info.Project = project.String
|
||||
info.Scope = scope.String
|
||||
|
||||
results = append(results, info)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
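The granular rebuild described above (find stale vectors, delete them, re-add them under the current model) can be sketched as follows. This is an illustration only: `memStore`, `staleDocIDs`, and `rebuildStale` are hypothetical names, the in-memory map stands in for the `vectors` table, and the re-embedding step is simulated rather than calling any embedding service.

```go
package main

import "fmt"

// memStore is a hypothetical in-memory stand-in for the vectors table:
// it maps doc_id to the model_version recorded when the vector was written.
// An empty string plays the role of a NULL model_version.
type memStore struct {
	current  string
	versions map[string]string
}

// staleDocIDs returns the IDs whose recorded version differs from current.
func (m *memStore) staleDocIDs() []string {
	var ids []string
	for id, v := range m.versions {
		if v != m.current {
			ids = append(ids, id)
		}
	}
	return ids
}

// rebuildStale deletes each stale entry and re-adds it under the current
// model version, leaving up-to-date entries untouched. It returns the
// number of entries rebuilt.
func rebuildStale(m *memStore) int {
	stale := m.staleDocIDs()
	for _, id := range stale {
		delete(m.versions, id)     // drop the stale vector
		m.versions[id] = m.current // re-embed and store (simulated)
	}
	return len(stale)
}

func main() {
	store := &memStore{
		current: "v2",
		versions: map[string]string{
			"obs_1_title": "v1", // written by an older model
			"obs_2_title": "",   // no version recorded
			"obs_3_title": "v2", // already current
		},
	}
	fmt.Println("rebuilt:", rebuildStale(store))
}
```

Collecting the stale IDs before mutating keeps the loop independent of map iteration order, and a second pass over an up-to-date store does nothing.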
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
|
||||
// Used for granular rebuild - delete stale vectors before re-adding.
|
||||
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||
if len(docIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(docIDs))
|
||||
args := make([]interface{}, len(docIDs))
|
||||
for i, id := range docIDs {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
|
||||
query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
|
||||
strings.Join(placeholders, ","))
|
||||
|
||||
_, err := c.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete vectors by doc_ids: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
|
||||
return nil
|
||||
}
|
||||
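The placeholder construction used above can be isolated into a small pure function, which makes it easy to test without a database. This is a minimal sketch; `buildDeleteQuery` is a hypothetical helper and is not part of this commit.

```go
package main

import (
	"fmt"
	"strings"
)

// buildDeleteQuery is a hypothetical helper that builds a parameterized
// DELETE statement for the given doc IDs. Only "?" placeholders are
// interpolated into the SQL text; the IDs themselves are returned as
// arguments to be bound by the driver.
func buildDeleteQuery(docIDs []string) (string, []any) {
	if len(docIDs) == 0 {
		return "", nil
	}
	placeholders := make([]string, len(docIDs))
	args := make([]any, len(docIDs))
	for i, id := range docIDs {
		placeholders[i] = "?"
		args[i] = id
	}
	query := "DELETE FROM vectors WHERE doc_id IN (" + strings.Join(placeholders, ",") + ")"
	return query, args
}

func main() {
	query, args := buildDeleteQuery([]string{"obs_1_title", "obs_1_narrative"})
	fmt.Println(query)
	fmt.Println(len(args))
}
```

SQLite limits the number of bound parameters in a single statement, so very large ID lists may need to be split into chunks before calling a helper like this.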
|
||||
@@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) {
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -19,9 +19,32 @@ type Document struct {
|
||||
|
||||
// QueryResult represents a search result from vector search.
|
||||
type QueryResult struct {
|
||||
ID string
|
||||
Distance float64
|
||||
Metadata map[string]any
|
||||
ID string
|
||||
Distance float64
|
||||
Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance)
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// DistanceToSimilarity converts sqlite-vec cosine distance to similarity score.
|
||||
// Cosine distance: 0 = identical, 2 = opposite
|
||||
// Similarity: 1.0 = identical, 0.0 = opposite
|
||||
func DistanceToSimilarity(distance float64) float64 {
|
||||
return 1.0 - (distance / 2.0)
|
||||
}
|
||||
|
||||
// FilterByThreshold keeps only results whose similarity is at or above the threshold,
// preserving input order. If maxResults > 0, it stops once that many results are kept.
|
||||
func FilterByThreshold(results []QueryResult, threshold float64, maxResults int) []QueryResult {
|
||||
var filtered []QueryResult
|
||||
for _, r := range results {
|
||||
if r.Similarity >= threshold {
|
||||
filtered = append(filtered, r)
|
||||
if maxResults > 0 && len(filtered) >= maxResults {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
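The distance-to-similarity conversion and the threshold filter above can be exercised together. The following is a minimal sketch with hypothetical, simplified names (`result`, `toSimilarity`, `filter`) that mirror the shapes in this diff; the distances are made-up sample values, not measurements.

```go
package main

import "fmt"

// result is a simplified, hypothetical stand-in for QueryResult.
type result struct {
	ID         string
	Distance   float64
	Similarity float64
}

// toSimilarity maps cosine distance in [0, 2] onto similarity in [0, 1]:
// distance 0 becomes 1.0 and distance 2 becomes 0.0.
func toSimilarity(distance float64) float64 {
	return 1.0 - distance/2.0
}

// filter keeps results at or above the threshold, in input order, and
// stops early once maxResults entries are kept (when maxResults > 0).
func filter(results []result, threshold float64, maxResults int) []result {
	var kept []result
	for _, r := range results {
		if r.Similarity < threshold {
			continue
		}
		kept = append(kept, r)
		if maxResults > 0 && len(kept) >= maxResults {
			break
		}
	}
	return kept
}

func main() {
	rs := []result{
		{ID: "a", Distance: 0.2},
		{ID: "b", Distance: 0.6},
		{ID: "c", Distance: 1.4},
	}
	for i := range rs {
		rs[i].Similarity = toSimilarity(rs[i].Distance)
	}
	for _, r := range filter(rs, 0.6, 0) {
		fmt.Printf("%s %.2f\n", r.ID, r.Similarity)
	}
}
```

Because the cap stops at the first `maxResults` matches, the filter returns the best matches only if the input is already sorted by descending similarity, which is an assumption about the caller.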
|
||||
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
|
||||
|
||||
@@ -240,3 +240,101 @@ func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncPattern syncs a single pattern to the vector store.
|
||||
func (s *Sync) SyncPattern(ctx context.Context, pattern *models.Pattern) error {
|
||||
docs := s.formatPatternDocs(pattern)
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
return fmt.Errorf("add pattern docs: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int64("patternId", pattern.ID).
|
||||
Int("docCount", len(docs)).
|
||||
Msg("Synced pattern to sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatPatternDocs formats a pattern into vector documents.
|
||||
func (s *Sync) formatPatternDocs(pattern *models.Pattern) []Document {
|
||||
docs := make([]Document, 0, 3)
|
||||
|
||||
baseMetadata := map[string]any{
|
||||
"sqlite_id": pattern.ID,
|
||||
"doc_type": "pattern",
|
||||
"pattern_type": string(pattern.Type),
|
||||
"status": string(pattern.Status),
|
||||
"scope": "global", // Patterns are always global
|
||||
"frequency": pattern.Frequency,
|
||||
"confidence": pattern.Confidence,
|
||||
"created_at_epoch": pattern.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
if len(pattern.Signature) > 0 {
|
||||
baseMetadata["signature"] = joinStrings(pattern.Signature, ",")
|
||||
}
|
||||
if len(pattern.Projects) > 0 {
|
||||
baseMetadata["projects"] = joinStrings(pattern.Projects, ",")
|
||||
}
|
||||
|
||||
// Pattern name as document
|
||||
if pattern.Name != "" {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("pattern_%d_name", pattern.ID),
|
||||
Content: pattern.Name,
|
||||
Metadata: copyMetadata(baseMetadata, "field_type", "name"),
|
||||
})
|
||||
}
|
||||
|
||||
// Pattern description as document
|
||||
if pattern.Description.Valid && pattern.Description.String != "" {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("pattern_%d_description", pattern.ID),
|
||||
Content: pattern.Description.String,
|
||||
Metadata: copyMetadata(baseMetadata, "field_type", "description"),
|
||||
})
|
||||
}
|
||||
|
||||
// Pattern recommendation as document
|
||||
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("pattern_%d_recommendation", pattern.ID),
|
||||
Content: pattern.Recommendation.String,
|
||||
Metadata: copyMetadata(baseMetadata, "field_type", "recommendation"),
|
||||
})
|
||||
}
|
||||
|
||||
return docs
|
||||
}
|
||||
|
||||
// DeletePatterns removes pattern documents from the vector store.
func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
	if len(patternIDs) == 0 {
		return nil
	}

	// Generate all possible document IDs for these patterns
	// Pattern: pattern_{id}_name, pattern_{id}_description, pattern_{id}_recommendation
	ids := make([]string, 0, len(patternIDs)*3)

	for _, patternID := range patternIDs {
		ids = append(ids, fmt.Sprintf("pattern_%d_name", patternID))
		ids = append(ids, fmt.Sprintf("pattern_%d_description", patternID))
		ids = append(ids, fmt.Sprintf("pattern_%d_recommendation", patternID))
	}

	if err := s.client.DeleteDocuments(ctx, ids); err != nil {
		return fmt.Errorf("delete pattern docs: %w", err)
	}

	log.Debug().
		Int("patternCount", len(patternIDs)).
		Msg("Deleted patterns from sqlite-vec")

	return nil
}
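The document ID scheme shared by formatPatternDocs and DeletePatterns can be sketched in isolation. This is a minimal illustration, not code from the commit: the helper name `patternDocIDs` and the `patternFields` slice are assumptions introduced here to show that deletion has to enumerate the same three suffixes that formatting may have written.

```go
package main

import "fmt"

// patternFields lists the per-pattern document suffixes seen in the diff.
// Keeping them in one slice is an assumption made for this sketch.
var patternFields = []string{"name", "description", "recommendation"}

// patternDocIDs (hypothetical helper) returns every document ID that could
// exist for the given patterns, in the form pattern_{id}_{field}.
func patternDocIDs(patternIDs []int64) []string {
	ids := make([]string, 0, len(patternIDs)*len(patternFields))
	for _, id := range patternIDs {
		for _, field := range patternFields {
			ids = append(ids, fmt.Sprintf("pattern_%d_%s", id, field))
		}
	}
	return ids
}

func main() {
	// Prints the six candidate IDs for patterns 7 and 42.
	fmt.Println(patternDocIDs([]int64{7, 42}))
}
```

Because a pattern may lack a description or recommendation, some of these IDs may not exist in the store, so a delete call built this way has to tolerate missing documents.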
+238
-4
@@ -4,13 +4,18 @@ package worker
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"sort"
	"strconv"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
	"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
	"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
	"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
	"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
	"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
	"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
	"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
@@ -641,6 +646,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
	})
}

// handleGetModels returns available embedding models.
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
	models := embedding.ListModels()
	defaultModel := embedding.GetDefaultModel()

	writeJSON(w, map[string]interface{}{
		"models":  models,
		"default": defaultModel,
		"current": s.embedSvc.Version(),
	})
}

// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
	project := r.URL.Query().Get("project")
@@ -658,6 +675,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
		"ready": s.ready.Load(),
	}

	// Add embedding model info
	if s.embedSvc != nil {
		response["embeddingModel"] = map[string]interface{}{
			"name":       s.embedSvc.Name(),
			"version":    s.embedSvc.Version(),
			"dimensions": s.embedSvc.Dimensions(),
		}
	}

	// Add vector count
	if s.vectorClient != nil {
		if count, err := s.vectorClient.Count(r.Context()); err == nil {
			response["vectorCount"] = count
		}
	}

	// Include project-specific observation count if project is specified
	if project != "" {
		count, err := s.observationStore.GetObservationCount(r.Context(), project)
@@ -715,15 +748,69 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
	var observations []*models.Observation
	var err error
	var usedVector bool
	similarityScores := make(map[int64]float64) // Track similarity per observation

	// Get threshold settings from config
	threshold := s.config.ContextRelevanceThreshold
	maxResults := s.config.ContextMaxPromptResults

	// Generate expanded queries if query expander is available
	var expandedQueries []expansion.ExpandedQuery
	var detectedIntent string
	if s.queryExpander != nil {
		cfg := expansion.DefaultConfig()
		cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
		expandedQueries = s.queryExpander.Expand(r.Context(), query, cfg)
		if len(expandedQueries) > 0 {
			detectedIntent = string(expandedQueries[0].Intent)
		}
	}
	if len(expandedQueries) == 0 {
		// Fallback to just the original query
		expandedQueries = []expansion.ExpandedQuery{
			{Query: query, Weight: 1.0, Source: "original"},
		}
	}

	// Try vector search first if available
	if s.vectorClient != nil && s.vectorClient.IsConnected() {
		where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")

-		vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
-		if vecErr == nil && len(vectorResults) > 0 {
		// Search with each expanded query and merge results
		allVectorResults := make([]sqlitevec.QueryResult, 0)
		queryWeights := make(map[string]float64) // Track weights for score merging

		for _, eq := range expandedQueries {
			vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
			if vecErr == nil && len(vectorResults) > 0 {
				// Apply weight to similarity scores before merging
				for i := range vectorResults {
					vectorResults[i].Similarity *= eq.Weight
				}
				allVectorResults = append(allVectorResults, vectorResults...)
				queryWeights[eq.Query] = eq.Weight
			}
		}

		if len(allVectorResults) > 0 {
			// Filter by relevance threshold before extracting IDs
			// Use a slightly lower threshold for expanded queries
			effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
			filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)

			// Build similarity map for filtered results (keeping highest weighted score per observation)
			for _, vr := range filteredResults {
				if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
					id := int64(sqliteID)
					// Keep the highest score for each observation
					if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
						similarityScores[id] = vr.Similarity
					}
				}
			}

			// Extract observation IDs with project/scope filtering using shared helper
-			obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
			obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)

			if len(obsIDs) > 0 {
				// Fetch full observations from SQLite
@@ -773,23 +860,132 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
		freshObservations = append(freshObservations, obs)
	}

	// Apply cross-encoder reranking if available
	var reranked bool
	if s.reranker != nil && len(freshObservations) > 0 && usedVector {
		// Build candidates from observations with their bi-encoder scores
		candidates := make([]reranking.Candidate, len(freshObservations))
		for i, obs := range freshObservations {
			content := obs.Title.String
			if obs.Narrative.Valid && obs.Narrative.String != "" {
				content = content + " " + obs.Narrative.String
			}
			candidates[i] = reranking.Candidate{
				ID:       fmt.Sprintf("%d", obs.ID),
				Content:  content,
				Score:    similarityScores[obs.ID],
				Metadata: map[string]any{"obs_idx": i},
			}
		}

		// Rerank using cross-encoder - use pure mode or combined scores
		var rerankResults []reranking.RerankResult
		var rerankErr error
		if s.config.RerankingPureMode {
			rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
		} else {
			rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
		}
		if rerankErr != nil {
			log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
		} else if len(rerankResults) > 0 {
			// Update similarity scores with reranked scores
			for _, rr := range rerankResults {
				if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
					similarityScores[id] = rr.CombinedScore
				}
			}

			// Reorder observations based on rerank results
			reorderedObs := make([]*models.Observation, 0, len(rerankResults))
			obsMap := make(map[int64]*models.Observation)
			for _, obs := range freshObservations {
				obsMap[obs.ID] = obs
			}
			for _, rr := range rerankResults {
				if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
					if obs, ok := obsMap[id]; ok {
						reorderedObs = append(reorderedObs, obs)
					}
				}
			}
			freshObservations = reorderedObs
			reranked = true

			log.Debug().
				Int("candidates", len(candidates)).
				Int("returned", len(rerankResults)).
				Msg("Cross-encoder reranking complete")
		}
	}

	// Cluster similar observations to remove duplicates
	clusteredObservations := clusterObservations(freshObservations, 0.4)

	// Sort by similarity score (highest first) if we have scores and didn't rerank
	if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
		sort.Slice(clusteredObservations, func(i, j int) bool {
			scoreI := similarityScores[clusteredObservations[i].ID]
			scoreJ := similarityScores[clusteredObservations[j].ID]
			return scoreI > scoreJ
		})
	}

	// Apply max results cap if configured
	if maxResults > 0 && len(clusteredObservations) > maxResults {
		clusteredObservations = clusteredObservations[:maxResults]
	}

	// Record retrieval stats (no verification done, so verified=0, deleted=0)
	s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, true)

	// Increment retrieval counts for scoring (async, non-blocking)
	if len(clusteredObservations) > 0 {
		ids := make([]int64, len(clusteredObservations))
		for i, obs := range clusteredObservations {
			ids[i] = obs.ID
		}
		s.incrementRetrievalCounts(ids)
	}

	log.Info().
		Str("project", project).
		Str("query", query).
		Str("intent", detectedIntent).
		Int("expansions", len(expandedQueries)).
		Int("found", len(clusteredObservations)).
		Int("stale_excluded", staleCount).
		Float64("threshold", threshold).
		Msg("Prompt-based observation search")

	// Build response with similarity scores
	obsWithScores := make([]map[string]interface{}, len(clusteredObservations))
	for i, obs := range clusteredObservations {
		obsMap := obs.ToMap()
		if score, ok := similarityScores[obs.ID]; ok {
			obsMap["similarity"] = score
		}
		obsWithScores[i] = obsMap
	}

	// Build expansion info for response
	expansionInfo := make([]map[string]interface{}, len(expandedQueries))
	for i, eq := range expandedQueries {
		expansionInfo[i] = map[string]interface{}{
			"query":  eq.Query,
			"weight": eq.Weight,
			"source": eq.Source,
		}
	}

	writeJSON(w, map[string]interface{}{
		"project":      project,
		"query":        query,
-		"observations": clusteredObservations,
		"intent":       detectedIntent,
		"expansions":   expansionInfo,
		"observations": obsWithScores,
		"threshold":    threshold,
		"max_results":  maxResults,
	})
}
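The score merging that handleSearchByPrompt performs across expanded queries can be sketched without the vector store. This is a simplified illustration under stated assumptions: `result` and `expandedBatch` are stand-in types invented here (the real code uses `sqlitevec.QueryResult` and `expansion.ExpandedQuery`), and the sketch assumes scores below the effective threshold are dropped, since the exact comparison inside `sqlitevec.FilterByThreshold` is not shown in this diff.

```go
package main

import "fmt"

// result is a simplified stand-in for a vector search hit (assumption).
type result struct {
	ID         int64
	Similarity float64
}

// expandedBatch pairs one expanded query's weight with its hits (assumption).
type expandedBatch struct {
	Weight  float64
	Results []result
}

// mergeWeighted scales each hit by its query weight, drops hits scoring
// below threshold*0.9, and keeps the highest surviving score per ID.
func mergeWeighted(batches []expandedBatch, threshold float64) map[int64]float64 {
	effective := threshold * 0.9 // slightly relaxed, as in the handler
	best := make(map[int64]float64)
	for _, b := range batches {
		for _, r := range b.Results {
			score := r.Similarity * b.Weight
			if score < effective {
				continue
			}
			if prev, ok := best[r.ID]; !ok || score > prev {
				best[r.ID] = score
			}
		}
	}
	return best
}

func main() {
	batches := []expandedBatch{
		{Weight: 1.0, Results: []result{{ID: 1, Similarity: 0.5}, {ID: 2, Similarity: 0.3}}},
		{Weight: 0.8, Results: []result{{ID: 1, Similarity: 0.75}, {ID: 3, Similarity: 0.5}}},
	}
	fmt.Println(mergeWeighted(batches, 0.4))
}
```

Taking the maximum rather than the sum means an observation matched by several expansions is not boosted for appearing more than once. It is ranked by its single best weighted match.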
@@ -857,6 +1053,15 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
	// Record retrieval stats (no verification done)
	s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, false)

	// Increment retrieval counts for scoring (async, non-blocking)
	if len(clusteredObservations) > 0 {
		ids := make([]int64, len(clusteredObservations))
		for i, obs := range clusteredObservations {
			ids[i] = obs.ID
		}
		s.incrementRetrievalCounts(ids)
	}

	log.Info().
		Str("project", project).
		Int("total", len(observations)).
@@ -1015,6 +1220,35 @@ func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
	}
	components = append(components, sseStatus)

	// Check Cross-Encoder Reranker
	rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
	if !s.config.RerankingEnabled {
		rerankerStatus.Status = "degraded"
		rerankerStatus.Message = "Disabled in config"
		if overall == "healthy" {
			overall = "degraded"
		}
	} else if s.reranker == nil {
		rerankerStatus.Status = "degraded"
		rerankerStatus.Message = "Not initialized"
		if overall == "healthy" {
			overall = "degraded"
		}
	} else {
		// Verify reranker is functional using Score
		_, normalizedScore, err := s.reranker.Score("test query", "test document")
		if err != nil {
			rerankerStatus.Status = "unhealthy"
			rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
			if overall == "healthy" {
				overall = "degraded"
			}
		} else {
			rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
		}
	}
	components = append(components, rerankerStatus)

	// Calculate uptime
	uptime := time.Since(s.startTime).Round(time.Second).String()

@@ -0,0 +1,292 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker

import (
	"encoding/json"
	"net/http"
	"strconv"

	"github.com/go-chi/chi/v5"
	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)

// DefaultPatternsLimit is the default number of patterns to return.
const DefaultPatternsLimit = 100

// handleGetPatterns returns all active patterns, optionally filtered by type or project.
func (s *Service) handleGetPatterns(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	// Parse query parameters
	limit := DefaultPatternsLimit
	if l := r.URL.Query().Get("limit"); l != "" {
		if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
			limit = parsed
		}
	}

	patternType := r.URL.Query().Get("type")
	project := r.URL.Query().Get("project")

	var patterns []*models.Pattern
	var err error

	if patternType != "" {
		// Filter by type
		patterns, err = store.GetPatternsByType(r.Context(), models.PatternType(patternType), limit)
	} else if project != "" {
		// Filter by project
		patterns, err = store.GetPatternsByProject(r.Context(), project, limit)
	} else {
		// Get all active patterns
		patterns, err = store.GetActivePatterns(r.Context(), limit)
	}

	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, patterns)
}

// handleGetPatternStats returns aggregate statistics about patterns.
func (s *Service) handleGetPatternStats(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	stats, err := store.GetPatternStats(r.Context())
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, stats)
}

// handleGetPatternByID returns a single pattern by ID.
func (s *Service) handleGetPatternByID(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
		return
	}

	pattern, err := store.GetPatternByID(r.Context(), id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if pattern == nil {
		http.Error(w, "pattern not found", http.StatusNotFound)
		return
	}

	writeJSON(w, pattern)
}

// handleGetPatternInsight returns a formatted insight string for a pattern.
func (s *Service) handleGetPatternInsight(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	detector := s.patternDetector
	s.initMu.RUnlock()

	if detector == nil {
		http.Error(w, "pattern detector not initialized", http.StatusServiceUnavailable)
		return
	}

	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
		return
	}

	insight, err := detector.GetPatternInsight(r.Context(), id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, map[string]string{"insight": insight})
}

// handleDeletePattern deletes a pattern by ID.
func (s *Service) handleDeletePattern(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
		return
	}

	if err := store.DeletePattern(r.Context(), id); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, map[string]string{"status": "deleted"})
}

// handleDeprecatePattern marks a pattern as deprecated.
func (s *Service) handleDeprecatePattern(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
		return
	}

	if err := store.MarkPatternDeprecated(r.Context(), id); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, map[string]string{"status": "deprecated"})
}

// MergePatternsRequest is the request body for merging patterns.
type MergePatternsRequest struct {
	SourceID int64 `json:"source_id"`
	TargetID int64 `json:"target_id"`
}

// handleSearchPatterns performs full-text search on patterns.
func (s *Service) handleSearchPatterns(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	query := r.URL.Query().Get("q")
	if query == "" {
		http.Error(w, "query parameter 'q' is required", http.StatusBadRequest)
		return
	}

	limit := DefaultPatternsLimit
	if l := r.URL.Query().Get("limit"); l != "" {
		if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
			limit = parsed
		}
	}

	patterns, err := store.SearchPatternsFTS(r.Context(), query, limit)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, patterns)
}

// handleGetPatternByName returns a pattern by its name.
func (s *Service) handleGetPatternByName(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	name := r.URL.Query().Get("name")
	if name == "" {
		http.Error(w, "query parameter 'name' is required", http.StatusBadRequest)
		return
	}

	pattern, err := store.GetPatternByName(r.Context(), name)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if pattern == nil {
		http.Error(w, "pattern not found", http.StatusNotFound)
		return
	}

	writeJSON(w, pattern)
}

// handleMergePatterns merges a source pattern into a target pattern.
func (s *Service) handleMergePatterns(w http.ResponseWriter, r *http.Request) {
	s.initMu.RLock()
	store := s.patternStore
	s.initMu.RUnlock()

	if store == nil {
		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
		return
	}

	var req MergePatternsRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if req.SourceID == 0 || req.TargetID == 0 {
		http.Error(w, "source_id and target_id are required", http.StatusBadRequest)
		return
	}

	if req.SourceID == req.TargetID {
		http.Error(w, "source_id and target_id cannot be the same", http.StatusBadRequest)
		return
	}

	if err := store.MergePatterns(r.Context(), req.SourceID, req.TargetID); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, map[string]string{"status": "merged"})
}
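Several pattern handlers above repeat the same `limit` parsing: take the query value, fall back to the default when it is empty, non-numeric, or not positive. A minimal sketch of that rule as a standalone function follows. The name `parseLimit` is an assumption made for illustration and does not appear in the commit.

```go
package main

import (
	"fmt"
	"strconv"
)

// parseLimit (hypothetical helper) returns the parsed value of raw when it
// is a positive integer, and def otherwise. This mirrors the inline checks
// in handleGetPatterns and handleSearchPatterns.
func parseLimit(raw string, def int) int {
	if raw == "" {
		return def
	}
	if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 {
		return parsed
	}
	return def
}

func main() {
	for _, raw := range []string{"25", "", "-3", "abc"} {
		fmt.Printf("%q -> %d\n", raw, parseLimit(raw, 100))
	}
}
```

Note that this rule has no upper bound, so a caller can request an arbitrarily large page. Whether that matters depends on how the store handles large limits, which this diff does not show.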
@@ -0,0 +1,174 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker

import (
	"net/http"
	"strconv"

	"github.com/go-chi/chi/v5"
	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)

// DefaultRelationsLimit is the default number of relations to return.
const DefaultRelationsLimit = 50

// handleGetRelations returns relations for an observation.
func (s *Service) handleGetRelations(w http.ResponseWriter, r *http.Request) {
	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid observation id", http.StatusBadRequest)
		return
	}

	relations, err := s.relationStore.GetRelationsWithDetails(r.Context(), id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if relations == nil {
		relations = []*models.RelationWithDetails{}
	}

	writeJSON(w, relations)
}

// handleGetRelationGraph returns the relation graph for an observation.
func (s *Service) handleGetRelationGraph(w http.ResponseWriter, r *http.Request) {
	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid observation id", http.StatusBadRequest)
		return
	}

	// Get depth parameter (default 2)
	depth := 2
	if depthStr := r.URL.Query().Get("depth"); depthStr != "" {
		if d, err := strconv.Atoi(depthStr); err == nil && d > 0 && d <= 5 {
			depth = d
		}
	}

	graph, err := s.relationStore.GetRelationGraph(r.Context(), id, depth)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	writeJSON(w, graph)
}

// handleGetRelatedObservations returns observations related to a given one.
func (s *Service) handleGetRelatedObservations(w http.ResponseWriter, r *http.Request) {
	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid observation id", http.StatusBadRequest)
		return
	}

	// Get minimum confidence parameter (default 0.4)
	minConfidence := 0.4
	if confStr := r.URL.Query().Get("min_confidence"); confStr != "" {
		if c, err := strconv.ParseFloat(confStr, 64); err == nil && c >= 0 && c <= 1 {
			minConfidence = c
		}
	}

	// Get related observation IDs
	relatedIDs, err := s.relationStore.GetRelatedObservationIDs(r.Context(), id, minConfidence)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if len(relatedIDs) == 0 {
		writeJSON(w, []*models.Observation{})
		return
	}

	// Fetch full observations
	observations, err := s.observationStore.GetObservationsByIDs(r.Context(), relatedIDs, "importance", 50)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if observations == nil {
		observations = []*models.Observation{}
	}

	writeJSON(w, observations)
}

// handleGetRelationsByType returns all relations of a specific type.
func (s *Service) handleGetRelationsByType(w http.ResponseWriter, r *http.Request) {
	relType := chi.URLParam(r, "type")

	// Validate relation type
	validType := false
	for _, t := range models.AllRelationTypes {
		if string(t) == relType {
			validType = true
			break
		}
	}
	if !validType {
		http.Error(w, "invalid relation type", http.StatusBadRequest)
		return
	}

	limit := DefaultRelationsLimit
	if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
		if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 100 {
			limit = l
		}
	}

	relations, err := s.relationStore.GetRelationsByType(r.Context(), models.RelationType(relType), limit)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if relations == nil {
		relations = []*models.ObservationRelation{}
	}

	writeJSON(w, relations)
}

// handleGetRelationStats returns statistics about relations.
func (s *Service) handleGetRelationStats(w http.ResponseWriter, r *http.Request) {
	// Get total relation count
	totalCount, err := s.relationStore.GetTotalRelationCount(r.Context())
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Get high confidence relations count
	highConfRelations, err := s.relationStore.GetHighConfidenceRelations(r.Context(), 0.7, 1000)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Count by relation type
	typeCounts := make(map[string]int)
	for _, t := range models.AllRelationTypes {
		relations, err := s.relationStore.GetRelationsByType(r.Context(), t, 1000)
		if err == nil {
			typeCounts[string(t)] = len(relations)
		}
	}

	writeJSON(w, map[string]interface{}{
		"total_count":         totalCount,
		"high_confidence":     len(highConfRelations),
		"by_type":             typeCounts,
		"min_confidence_used": 0.4,
	})
}
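handleGetRelationGraph accepts a `depth` between 1 and 5 and passes it to the relation store. The store's traversal is not part of this diff, so the following is only a sketch of one plausible reading, in which depth is a hop count in a breadth-first walk. The function `relatedWithinDepth` and the adjacency-map representation are assumptions made for illustration.

```go
package main

import "fmt"

// relatedWithinDepth (hypothetical) walks an adjacency map breadth-first and
// returns each node reachable from start in at most maxDepth hops, mapped to
// its hop distance. The start node is included at distance 0.
func relatedWithinDepth(adj map[int64][]int64, start int64, maxDepth int) map[int64]int {
	dist := map[int64]int{start: 0}
	frontier := []int64{start}
	for depth := 1; depth <= maxDepth && len(frontier) > 0; depth++ {
		var next []int64
		for _, node := range frontier {
			for _, neighbour := range adj[node] {
				if _, seen := dist[neighbour]; !seen {
					dist[neighbour] = depth
					next = append(next, neighbour)
				}
			}
		}
		frontier = next
	}
	return dist
}

func main() {
	// A chain of observations: 1 - 2 - 3 - 4.
	adj := map[int64][]int64{1: {2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}
	fmt.Println(relatedWithinDepth(adj, 1, 2))
}
```

Under this reading, capping depth at 5 bounds how far the graph can fan out from one observation, which keeps the response size predictable for the dashboard visualisation.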
@@ -0,0 +1,354 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker

import (
	"context"
	"encoding/json"
	"net/http"
	"strconv"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)

// FeedbackRequest represents a user feedback submission.
type FeedbackRequest struct {
	Feedback int `json:"feedback"` // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
}

// handleObservationFeedback handles user feedback (thumbs up/down) for an observation.
// POST /api/observations/{id}/feedback
func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Request) {
	// Parse observation ID
	idStr := chi.URLParam(r, "id")
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		http.Error(w, "invalid observation id", http.StatusBadRequest)
		return
	}

	// Parse request body
	var req FeedbackRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}

	// Validate feedback value
	if req.Feedback < -1 || req.Feedback > 1 {
		http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
		return
	}

	// Get required components
	s.initMu.RLock()
	observationStore := s.observationStore
	scoreCalculator := s.scoreCalculator
	s.initMu.RUnlock()

	if observationStore == nil {
		http.Error(w, "service not ready", http.StatusServiceUnavailable)
		return
	}

	// Update feedback in database
	if err := observationStore.UpdateObservationFeedback(r.Context(), id, req.Feedback); err != nil {
		http.Error(w, "failed to update feedback", http.StatusInternalServerError)
		return
	}

	// Recalculate score immediately if calculator is available
	var newScore float64
	if scoreCalculator != nil {
		obs, err := observationStore.GetObservationByID(r.Context(), id)
		if err == nil && obs != nil {
			obs.UserFeedback = req.Feedback // Apply the new feedback
			newScore = scoreCalculator.Calculate(obs, time.Now())
			if err := observationStore.UpdateImportanceScore(r.Context(), id, newScore); err != nil {
				// Log but don't fail - feedback was recorded
				// Score will be updated on next recalculation cycle
			}
		}
	}

	// Broadcast update via SSE
	s.sseBroadcaster.Broadcast(map[string]interface{}{
		"type":     "observation_feedback",
		"id":       id,
		"feedback": req.Feedback,
		"score":    newScore,
	})

	writeJSON(w, map[string]interface{}{
		"status":   "ok",
		"id":       id,
		"feedback": req.Feedback,
		"score":    newScore,
	})
}
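The decode-then-validate step at the top of handleObservationFeedback can be exercised on its own. The sketch below copies the `FeedbackRequest` shape from the diff. The `decodeFeedback` function and its error messages are assumptions introduced for illustration and are not part of the commit.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// FeedbackRequest mirrors the request body defined in the diff above.
type FeedbackRequest struct {
	Feedback int `json:"feedback"`
}

// decodeFeedback (hypothetical helper) parses a JSON body and applies the
// same range check as the handler: only -1, 0, and 1 are accepted.
func decodeFeedback(body string) (FeedbackRequest, error) {
	var req FeedbackRequest
	if err := json.NewDecoder(strings.NewReader(body)).Decode(&req); err != nil {
		return req, fmt.Errorf("invalid request body: %w", err)
	}
	if req.Feedback < -1 || req.Feedback > 1 {
		return req, errors.New("feedback must be -1, 0, or 1")
	}
	return req, nil
}

func main() {
	for _, body := range []string{`{"feedback":1}`, `{"feedback":5}`, `oops`, `{}`} {
		req, err := decodeFeedback(body)
		fmt.Println(req.Feedback, err)
	}
}
```

One consequence of using a plain `int` field: a body of `{}` decodes to 0, so an omitted `feedback` key is indistinguishable from an explicit neutral vote.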
// handleGetScoringStats returns scoring statistics and configuration.
// GET /api/scoring/stats
func (s *Service) handleGetScoringStats(w http.ResponseWriter, r *http.Request) {
	project := r.URL.Query().Get("project")

	s.initMu.RLock()
	observationStore := s.observationStore
	recalculator := s.recalculator
	s.initMu.RUnlock()

	if observationStore == nil {
		http.Error(w, "service not ready", http.StatusServiceUnavailable)
		return
	}

	// Get feedback statistics
	feedbackStats, err := observationStore.GetObservationFeedbackStats(r.Context(), project)
	if err != nil {
		http.Error(w, "failed to get feedback stats", http.StatusInternalServerError)
		return
	}

	response := map[string]interface{}{
		"feedback": feedbackStats,
	}

	// Add recalculator stats if available
	if recalculator != nil {
		response["recalculator"] = recalculator.GetStats()
	}

	writeJSON(w, response)
}

// handleGetTopObservations returns the highest-scoring observations.
|
||||
// GET /api/observations/top
|
||||
func (s *Service) handleGetTopObservations(w http.ResponseWriter, r *http.Request) {
|
||||
limit := parseIntParam(r, "limit", 10)
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
s.initMu.RLock()
|
||||
observationStore := s.observationStore
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if observationStore == nil {
|
||||
http.Error(w, "service not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
observations, err := observationStore.GetTopScoringObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get top observations", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
writeJSON(w, observations)
|
||||
}
|
||||
|
||||
// handleGetMostRetrieved returns the most frequently retrieved observations.
|
||||
// GET /api/observations/most-retrieved
|
||||
func (s *Service) handleGetMostRetrieved(w http.ResponseWriter, r *http.Request) {
|
||||
limit := parseIntParam(r, "limit", 10)
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
s.initMu.RLock()
|
||||
observationStore := s.observationStore
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if observationStore == nil {
|
||||
http.Error(w, "service not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
observations, err := observationStore.GetMostRetrievedObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get most retrieved observations", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
writeJSON(w, observations)
|
||||
}
|
||||
|
||||
// handleExplainScore returns a breakdown of how an observation's score was calculated.
|
||||
// GET /api/observations/{id}/score
|
||||
func (s *Service) handleExplainScore(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse observation ID
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid observation id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.initMu.RLock()
|
||||
observationStore := s.observationStore
|
||||
scoreCalculator := s.scoreCalculator
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if observationStore == nil || scoreCalculator == nil {
|
||||
http.Error(w, "service not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Get observation
|
||||
obs, err := observationStore.GetObservationByID(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get observation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if obs == nil {
|
||||
http.Error(w, "observation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate score components
|
||||
components := scoreCalculator.CalculateComponents(obs, time.Now())
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"id": id,
|
||||
"components": components,
|
||||
"config": scoreCalculator.GetConfig(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateConceptWeight updates a concept weight.
|
||||
// PUT /api/scoring/concepts/{concept}
|
||||
func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Request) {
|
||||
concept := chi.URLParam(r, "concept")
|
||||
if concept == "" {
|
||||
http.Error(w, "concept is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Weight float64 `json:"weight"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate weight
|
||||
if req.Weight < 0 || req.Weight > 1 {
|
||||
http.Error(w, "weight must be between 0 and 1", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.initMu.RLock()
|
||||
observationStore := s.observationStore
|
||||
recalculator := s.recalculator
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if observationStore == nil {
|
||||
http.Error(w, "service not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Update in database
|
||||
if err := observationStore.UpdateConceptWeight(r.Context(), concept, req.Weight); err != nil {
|
||||
http.Error(w, "failed to update concept weight", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh concept weights in recalculator
|
||||
if recalculator != nil {
|
||||
if err := recalculator.RefreshConceptWeights(r.Context()); err != nil {
|
||||
// Log but don't fail - weight was saved
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"status": "ok",
|
||||
"concept": concept,
|
||||
"weight": req.Weight,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetConceptWeights returns all concept weights.
|
||||
// GET /api/scoring/concepts
|
||||
func (s *Service) handleGetConceptWeights(w http.ResponseWriter, r *http.Request) {
|
||||
s.initMu.RLock()
|
||||
observationStore := s.observationStore
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if observationStore == nil {
|
||||
http.Error(w, "service not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
weights, err := observationStore.GetConceptWeights(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get concept weights", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, weights)
|
||||
}
|
||||
|
||||
// handleTriggerRecalculation triggers an immediate score recalculation.
|
||||
// POST /api/scoring/recalculate
|
||||
func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Request) {
|
||||
s.initMu.RLock()
|
||||
recalculator := s.recalculator
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if recalculator == nil {
|
||||
http.Error(w, "recalculator not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Run recalculation in background
|
||||
go func() {
|
||||
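// Use the service context here: r.Context() is canceled once this handler returns,
// which would abort the background recalculation almost immediately.
if err := recalculator.RecalculateNow(s.ctx); err != nil {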
|
||||
// Log error but don't block response
|
||||
}
|
||||
}()
|
||||
|
||||
writeJSON(w, map[string]string{"status": "recalculation triggered"})
|
||||
}
|
||||
|
||||
// parseIntParam parses an integer query parameter with a default value.
|
||||
func parseIntParam(r *http.Request, name string, defaultVal int) int {
|
||||
if val := r.URL.Query().Get(name); val != "" {
|
||||
if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// incrementRetrievalCounts increments retrieval counts for observations.
|
||||
// Called after search results are returned to track popularity.
|
||||
func (s *Service) incrementRetrievalCounts(ids []int64) {
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.initMu.RLock()
|
||||
store := s.observationStore
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Increment in background to not block response
|
||||
go func() {
|
||||
// Create a new context with timeout for the background operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
|
||||
// Log but don't fail - this is a background operation
|
||||
}
|
||||
}()
|
||||
}
|
||||
+490
-1
@@ -15,6 +15,10 @@ import (
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/pattern"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
@@ -64,6 +68,12 @@ type Service struct {
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
promptStore *sqlite.PromptStore
|
||||
conflictStore *sqlite.ConflictStore
|
||||
patternStore *sqlite.PatternStore
|
||||
relationStore *sqlite.RelationStore
|
||||
|
||||
// Pattern detection
|
||||
patternDetector *pattern.Detector
|
||||
|
||||
// Domain services
|
||||
sessionManager *session.Manager
|
||||
@@ -75,6 +85,16 @@ type Service struct {
|
||||
vectorClient *sqlitevec.Client
|
||||
vectorSync *sqlitevec.Sync
|
||||
|
||||
// Cross-encoder reranking (for improved search relevance)
|
||||
reranker *reranking.Service
|
||||
|
||||
// Query expansion (for improved search recall)
|
||||
queryExpander *expansion.Expander
|
||||
|
||||
// Importance scoring
|
||||
scoreCalculator *scoring.Calculator
|
||||
recalculator *scoring.Recalculator
|
||||
|
||||
// HTTP server
|
||||
router *chi.Mux
|
||||
server *http.Server
|
||||
@@ -177,6 +197,15 @@ func (s *Service) initializeAsync() {
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
conflictStore := sqlite.NewConflictStore(store)
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
relationStore := sqlite.NewRelationStore(store)
|
||||
|
||||
// Enable conflict detection by linking stores
|
||||
observationStore.SetConflictStore(conflictStore)
|
||||
|
||||
// Enable relation detection by linking stores
|
||||
observationStore.SetRelationStore(relationStore)
|
||||
|
||||
// Create session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
@@ -186,6 +215,8 @@ func (s *Service) initializeAsync() {
|
||||
var vectorClient *sqlitevec.Client
|
||||
var vectorSync *sqlitevec.Sync
|
||||
|
||||
var reranker *reranking.Service
|
||||
|
||||
emb, err := embedding.NewService()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
|
||||
@@ -200,8 +231,32 @@ func (s *Service) initializeAsync() {
|
||||
} else {
|
||||
vectorClient = client
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
log.Info().Msg("sqlite-vec vector search enabled")
|
||||
log.Info().
|
||||
Str("model", embedSvc.Version()).
|
||||
Msg("sqlite-vec vector search enabled")
|
||||
}
|
||||
|
||||
// Create cross-encoder reranking service if enabled
|
||||
if s.config.RerankingEnabled {
|
||||
rerankCfg := reranking.DefaultConfig()
|
||||
if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
|
||||
rerankCfg.Alpha = s.config.RerankingAlpha
|
||||
}
|
||||
|
||||
ranker, err := reranking.NewService(rerankCfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed - reranking disabled")
|
||||
} else {
|
||||
reranker = ranker
|
||||
log.Info().
|
||||
Float64("alpha", rerankCfg.Alpha).
|
||||
Msg("Cross-encoder reranking enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// Create query expander for improved search recall
|
||||
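// Assign under initMu, consistent with the other service fields set during init.
s.initMu.Lock()
s.queryExpander = expansion.NewExpander(embedSvc)
s.initMu.Unlock()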
|
||||
log.Info().Msg("Query expansion enabled")
|
||||
}
|
||||
|
||||
// Create SDK processor (optional - will be nil if Claude CLI not available)
|
||||
@@ -225,11 +280,38 @@ func (s *Service) initializeAsync() {
|
||||
s.observationStore = observationStore
|
||||
s.summaryStore = summaryStore
|
||||
s.promptStore = promptStore
|
||||
s.conflictStore = conflictStore
|
||||
s.patternStore = patternStore
|
||||
s.relationStore = relationStore
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.embedSvc = embedSvc
|
||||
s.vectorClient = vectorClient
|
||||
s.vectorSync = vectorSync
|
||||
s.reranker = reranker
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Initialize pattern detector
|
||||
patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
|
||||
|
||||
// Set pattern sync callback if vector sync is available
|
||||
if vectorSync != nil {
|
||||
patternDetector.SetSyncFunc(func(p *models.Pattern) {
|
||||
if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
|
||||
log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
|
||||
}
|
||||
})
|
||||
|
||||
// Set cleanup callback for pattern deletions
|
||||
patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
s.initMu.Lock()
|
||||
s.patternDetector = patternDetector
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
@@ -238,6 +320,22 @@ func (s *Service) initializeAsync() {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
|
||||
}
|
||||
// Trigger pattern detection for the new observation
|
||||
if patternDetector != nil {
|
||||
go func(observation *models.Observation) {
|
||||
detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
|
||||
log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
|
||||
} else if result.MatchedPattern != nil {
|
||||
log.Debug().
|
||||
Int64("pattern_id", result.MatchedPattern.ID).
|
||||
Str("pattern_name", result.MatchedPattern.Name).
|
||||
Bool("is_new", result.IsNewPattern).
|
||||
Msg("Pattern matched for observation")
|
||||
}
|
||||
}(obs)
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
@@ -282,6 +380,37 @@ func (s *Service) initializeAsync() {
|
||||
})
|
||||
})
|
||||
|
||||
// Initialize importance scoring system
|
||||
scoringConfig := models.DefaultScoringConfig()
|
||||
|
||||
// Load concept weights from database if available
|
||||
if weights, err := observationStore.GetConceptWeights(s.ctx); err == nil && len(weights) > 0 {
|
||||
scoringConfig.ConceptWeights = weights
|
||||
log.Info().Int("count", len(weights)).Msg("Loaded concept weights from database")
|
||||
}
|
||||
|
||||
scoreCalculator := scoring.NewCalculator(scoringConfig)
|
||||
recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger)
|
||||
|
||||
s.initMu.Lock()
|
||||
s.scoreCalculator = scoreCalculator
|
||||
s.recalculator = recalculator
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Start background recalculator
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
recalculator.Start(s.ctx)
|
||||
}()
|
||||
log.Info().Msg("Importance scoring system initialized")
|
||||
|
||||
// Start pattern detector background analysis
|
||||
if patternDetector != nil {
|
||||
patternDetector.Start()
|
||||
log.Info().Msg("Pattern recognition engine started")
|
||||
}
|
||||
|
||||
// Mark as ready
|
||||
s.ready.Store(true)
|
||||
log.Info().Msg("Async initialization complete - service ready")
|
||||
@@ -294,6 +423,27 @@ func (s *Service) initializeAsync() {
|
||||
|
||||
// Start file watchers for auto-recreation on deletion
|
||||
s.startWatchers()
|
||||
|
||||
// Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild
|
||||
if vectorClient != nil && vectorSync != nil {
|
||||
needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx)
|
||||
if needsRebuild {
|
||||
log.Info().
|
||||
Str("reason", reason).
|
||||
Str("model", vectorClient.ModelVersion()).
|
||||
Msg("Vector rebuild required")
|
||||
|
||||
if reason == "empty" {
|
||||
// Full rebuild - vectors table is empty
|
||||
s.wg.Add(1)
|
||||
go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync)
|
||||
} else {
|
||||
// Granular rebuild - only rebuild vectors with mismatched model versions
|
||||
s.wg.Add(1)
|
||||
go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startWatchers initializes and starts file watchers for database and config.
|
||||
@@ -384,6 +534,15 @@ func (s *Service) reinitializeDatabase() {
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
conflictStore := sqlite.NewConflictStore(store)
|
||||
patternStore := sqlite.NewPatternStore(store)
|
||||
relationStore := sqlite.NewRelationStore(store)
|
||||
|
||||
// Enable conflict detection by linking stores
|
||||
observationStore.SetConflictStore(conflictStore)
|
||||
|
||||
// Enable relation detection by linking stores
|
||||
observationStore.SetRelationStore(relationStore)
|
||||
|
||||
// Create new session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
@@ -393,6 +552,8 @@ func (s *Service) reinitializeDatabase() {
|
||||
var vectorClient *sqlitevec.Client
|
||||
var vectorSync *sqlitevec.Sync
|
||||
|
||||
var reranker *reranking.Service
|
||||
|
||||
emb, err := embedding.NewService()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Embedding service creation failed after reinit")
|
||||
@@ -408,6 +569,34 @@ func (s *Service) reinitializeDatabase() {
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
log.Info().Msg("sqlite-vec reconnected after reinit")
|
||||
}
|
||||
|
||||
// Recreate cross-encoder reranking service if enabled
|
||||
if s.config.RerankingEnabled {
|
||||
rerankCfg := reranking.DefaultConfig()
|
||||
if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
|
||||
rerankCfg.Alpha = s.config.RerankingAlpha
|
||||
}
|
||||
|
||||
ranker, err := reranking.NewService(rerankCfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed after reinit")
|
||||
} else {
|
||||
reranker = ranker
|
||||
log.Info().Msg("Cross-encoder reranking reconnected after reinit")
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate query expander
|
||||
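// Assign under initMu like the other swapped components; this runs while the service is live.
s.initMu.Lock()
s.queryExpander = expansion.NewExpander(embedSvc)
s.initMu.Unlock()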
|
||||
log.Info().Msg("Query expansion reconnected after reinit")
|
||||
}
|
||||
|
||||
// Close old reranker if exists
|
||||
s.initMu.RLock()
|
||||
oldReranker := s.reranker
|
||||
s.initMu.RUnlock()
|
||||
if oldReranker != nil {
|
||||
_ = oldReranker.Close()
|
||||
}
|
||||
|
||||
// Recreate SDK processor with new stores
|
||||
@@ -422,6 +611,30 @@ func (s *Service) reinitializeDatabase() {
|
||||
})
|
||||
}
|
||||
|
||||
// Stop old pattern detector if it exists
|
||||
if s.patternDetector != nil {
|
||||
s.patternDetector.Stop()
|
||||
}
|
||||
|
||||
// Create new pattern detector
|
||||
patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
|
||||
|
||||
// Set pattern sync callback if vector sync is available
|
||||
if vectorSync != nil {
|
||||
patternDetector.SetSyncFunc(func(p *models.Pattern) {
|
||||
if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
|
||||
log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
|
||||
}
|
||||
})
|
||||
|
||||
// Set cleanup callback for pattern deletions
|
||||
patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Atomically swap all components
|
||||
s.initMu.Lock()
|
||||
s.store = store
|
||||
@@ -429,20 +642,44 @@ func (s *Service) reinitializeDatabase() {
|
||||
s.observationStore = observationStore
|
||||
s.summaryStore = summaryStore
|
||||
s.promptStore = promptStore
|
||||
s.conflictStore = conflictStore
|
||||
s.patternStore = patternStore
|
||||
s.relationStore = relationStore
|
||||
s.patternDetector = patternDetector
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.embedSvc = embedSvc
|
||||
s.vectorClient = vectorClient
|
||||
s.vectorSync = vectorSync
|
||||
s.reranker = reranker
|
||||
s.initError = nil
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Start pattern detector
|
||||
patternDetector.Start()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
if processor != nil && vectorSync != nil {
|
||||
processor.SetSyncObservationFunc(func(obs *models.Observation) {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
|
||||
}
|
||||
// Trigger pattern detection for the new observation
|
||||
if patternDetector != nil {
|
||||
go func(observation *models.Observation) {
|
||||
detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
|
||||
log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
|
||||
} else if result.MatchedPattern != nil {
|
||||
log.Debug().
|
||||
Int64("pattern_id", result.MatchedPattern.ID).
|
||||
Str("pattern_name", result.MatchedPattern.Name).
|
||||
Bool("is_new", result.IsNewPattern).
|
||||
Msg("Pattern matched for observation")
|
||||
}
|
||||
}(obs)
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
@@ -565,6 +802,210 @@ func (s *Service) processStaleQueue() {
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts.
|
||||
// Called when the vectors table is empty (e.g., after migration 20 drops all vectors).
|
||||
func (s *Service) rebuildAllVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info().Msg("Starting full vector rebuild...")
|
||||
start := time.Now()
|
||||
|
||||
var totalSynced int
|
||||
var syncErrors int
|
||||
|
||||
// Rebuild observations
|
||||
observations, err := observationStore.GetAllObservations(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild")
|
||||
} else {
|
||||
for _, obs := range observations {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors")
|
||||
}
|
||||
|
||||
// Rebuild summaries
|
||||
summaries, err := summaryStore.GetAllSummaries(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild")
|
||||
} else {
|
||||
for _, summary := range summaries {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors")
|
||||
}
|
||||
|
||||
// Rebuild user prompts
|
||||
prompts, err := promptStore.GetAllPrompts(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild")
|
||||
} else {
|
||||
for _, prompt := range prompts {
|
||||
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
|
||||
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Info().
|
||||
Int("total_synced", totalSynced).
|
||||
Int("errors", syncErrors).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Full vector rebuild complete")
|
||||
}
|
||||
|
||||
// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions.
|
||||
// This is more efficient than rebuilding all vectors when only some need updating.
|
||||
func (s *Service) rebuildStaleVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info().Msg("Starting granular vector rebuild for stale vectors...")
|
||||
start := time.Now()
|
||||
|
||||
// Get all stale vectors
|
||||
staleVectors, err := vectorClient.GetStaleVectors(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get stale vectors")
|
||||
return
|
||||
}
|
||||
|
||||
if len(staleVectors) == 0 {
|
||||
log.Info().Msg("No stale vectors found")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild")
|
||||
|
||||
// Group stale vectors by doc_type and sqlite_id for efficient lookup
|
||||
staleObsIDs := make(map[int64]bool)
|
||||
staleSummaryIDs := make(map[int64]bool)
|
||||
stalePromptIDs := make(map[int64]bool)
|
||||
staleDocIDs := make([]string, 0, len(staleVectors))
|
||||
|
||||
for _, sv := range staleVectors {
|
||||
staleDocIDs = append(staleDocIDs, sv.DocID)
|
||||
switch sv.DocType {
|
||||
case "observation":
|
||||
staleObsIDs[sv.SQLiteID] = true
|
||||
case "summary":
|
||||
staleSummaryIDs[sv.SQLiteID] = true
|
||||
case "prompt":
|
||||
stalePromptIDs[sv.SQLiteID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete stale vectors before re-syncing
|
||||
if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete stale vectors")
|
||||
return
|
||||
}
|
||||
|
||||
var totalSynced int
|
||||
var syncErrors int
|
||||
|
||||
// Rebuild stale observations
|
||||
if len(staleObsIDs) > 0 {
|
||||
ids := make([]int64, 0, len(staleObsIDs))
|
||||
for id := range staleObsIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for rebuild")
|
||||
} else {
|
||||
for _, obs := range observations {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors")
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild stale summaries
|
||||
if len(staleSummaryIDs) > 0 {
|
||||
ids := make([]int64, 0, len(staleSummaryIDs))
|
||||
for id := range staleSummaryIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch summaries for rebuild")
|
||||
} else {
|
||||
for _, summary := range summaries {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors")
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild stale prompts
|
||||
if len(stalePromptIDs) > 0 {
|
||||
ids := make([]int64, 0, len(stalePromptIDs))
|
||||
for id := range stalePromptIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch prompts for rebuild")
|
||||
} else {
|
||||
for _, prompt := range prompts {
|
||||
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
|
||||
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors")
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Info().
|
||||
Int("total_synced", totalSynced).
|
||||
Int("errors", syncErrors).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Granular vector rebuild complete")
|
||||
}
|
||||
|
||||
// verifyStaleObservation verifies a single stale observation in the background.
|
||||
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
|
||||
// Wait for service to be ready
|
||||
@@ -667,11 +1108,42 @@ func (s *Service) setupRoutes() {
		r.Get("/api/stats", s.handleGetStats)
		r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
		r.Get("/api/types", s.handleGetTypes)
		r.Get("/api/models", s.handleGetModels)

		// Observation scoring and feedback routes
		r.Post("/api/observations/{id}/feedback", s.handleObservationFeedback)
		r.Get("/api/observations/{id}/score", s.handleExplainScore)
		r.Get("/api/observations/top", s.handleGetTopObservations)
		r.Get("/api/observations/most-retrieved", s.handleGetMostRetrieved)

		// Scoring configuration routes
		r.Get("/api/scoring/stats", s.handleGetScoringStats)
		r.Get("/api/scoring/concepts", s.handleGetConceptWeights)
		r.Put("/api/scoring/concepts/{concept}", s.handleUpdateConceptWeight)
		r.Post("/api/scoring/recalculate", s.handleTriggerRecalculation)

		// Context injection
		r.Get("/api/context/count", s.handleContextCount)
		r.Get("/api/context/inject", s.handleContextInject)
		r.Get("/api/context/search", s.handleSearchByPrompt)

		// Pattern routes
		r.Get("/api/patterns", s.handleGetPatterns)
		r.Get("/api/patterns/stats", s.handleGetPatternStats)
		r.Get("/api/patterns/search", s.handleSearchPatterns)
		r.Get("/api/patterns/by-name", s.handleGetPatternByName)
		r.Get("/api/patterns/{id}", s.handleGetPatternByID)
		r.Get("/api/patterns/{id}/insight", s.handleGetPatternInsight)
		r.Delete("/api/patterns/{id}", s.handleDeletePattern)
		r.Post("/api/patterns/{id}/deprecate", s.handleDeprecatePattern)
		r.Post("/api/patterns/merge", s.handleMergePatterns)

		// Relation routes (knowledge graph)
		r.Get("/api/relations/stats", s.handleGetRelationStats)
		r.Get("/api/relations/type/{type}", s.handleGetRelationsByType)
		r.Get("/api/observations/{id}/relations", s.handleGetRelations)
		r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph)
		r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations)
	})
}

@@ -894,6 +1366,16 @@ func (s *Service) Shutdown(ctx context.Context) error {
		_ = s.configWatcher.Stop()
	}

	// Stop background recalculator
	if s.recalculator != nil {
		s.recalculator.Stop()
	}

	// Stop pattern detector
	if s.patternDetector != nil {
		s.patternDetector.Stop()
	}

	// Shutdown all sessions
	s.sessionManager.ShutdownAll(ctx)

@@ -904,6 +1386,13 @@ func (s *Service) Shutdown(ctx context.Context) error {
		}
	}

	// Close reranking service
	if s.reranker != nil {
		if err := s.reranker.Close(); err != nil {
			log.Error().Err(err).Msg("Reranking service close error")
		}
	}

	// Close embedding service
	if s.embedSvc != nil {
		if err := s.embedSvc.Close(); err != nil {
@@ -0,0 +1,258 @@
// Package models contains domain models for claude-mnemonic.
package models

import (
	"regexp"
	"strings"
	"time"
)

// ConflictType represents the type of conflict between observations.
type ConflictType string

const (
	// ConflictSuperseded means newer observation supersedes older one (same topic, updated info).
	ConflictSuperseded ConflictType = "superseded"
	// ConflictContradicts means observations contain contradictory information.
	ConflictContradicts ConflictType = "contradicts"
	// ConflictOutdatedPattern means an outdated pattern/practice was identified.
	ConflictOutdatedPattern ConflictType = "outdated_pattern"
)

// ConflictResolution indicates which observation to prefer.
type ConflictResolution string

const (
	// ResolutionPreferNewer means prefer the newer observation.
	ResolutionPreferNewer ConflictResolution = "prefer_newer"
	// ResolutionPreferOlder means prefer the older observation (rare).
	ResolutionPreferOlder ConflictResolution = "prefer_older"
	// ResolutionManual means manual review is needed.
	ResolutionManual ConflictResolution = "manual"
)

// ObservationConflict tracks conflicting observations.
type ObservationConflict struct {
	ID              int64              `db:"id" json:"id"`
	NewerObsID      int64              `db:"newer_obs_id" json:"newer_obs_id"`
	OlderObsID      int64              `db:"older_obs_id" json:"older_obs_id"`
	ConflictType    ConflictType       `db:"conflict_type" json:"conflict_type"`
	Resolution      ConflictResolution `db:"resolution" json:"resolution"`
	Reason          string             `db:"reason" json:"reason"`
	DetectedAt      string             `db:"detected_at" json:"detected_at"`
	DetectedAtEpoch int64              `db:"detected_at_epoch" json:"detected_at_epoch"`
	Resolved        bool               `db:"resolved" json:"resolved"`
	ResolvedAt      *string            `db:"resolved_at" json:"resolved_at,omitempty"`
}

// ConflictDetectionResult contains the result of conflict detection.
type ConflictDetectionResult struct {
	HasConflict bool
	Type        ConflictType
	Resolution  ConflictResolution
	Reason      string
	OlderObsIDs []int64 // IDs of observations that conflict with the new one
}

// NewObservationConflict creates a new conflict record.
func NewObservationConflict(newerID, olderID int64, conflictType ConflictType, resolution ConflictResolution, reason string) *ObservationConflict {
	now := time.Now()
	return &ObservationConflict{
		NewerObsID:      newerID,
		OlderObsID:      olderID,
		ConflictType:    conflictType,
		Resolution:      resolution,
		Reason:          reason,
		DetectedAt:      now.Format(time.RFC3339),
		DetectedAtEpoch: now.UnixMilli(),
		Resolved:        false,
	}
}

// CorrectionPatterns contains regex patterns that indicate explicit corrections.
var CorrectionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)\bactually[,\s]+that\s+was\s+wrong\b`),
	regexp.MustCompile(`(?i)\bactually[,\s]+that's\s+(wrong|incorrect|not\s+right)\b`),
	regexp.MustCompile(`(?i)\bpreviously\s+(said|mentioned|noted)\s+.*\s+but\b`),
	regexp.MustCompile(`(?i)\bcorrection:\s*`),
	regexp.MustCompile(`(?i)\bignore\s+(the\s+)?(previous|earlier)\b`),
	regexp.MustCompile(`(?i)\bdisregard\s+(the\s+)?(previous|earlier)\b`),
	regexp.MustCompile(`(?i)\bwas\s+(wrong|incorrect|mistaken)\b`),
	regexp.MustCompile(`(?i)\bturns\s+out\s+.*(wrong|incorrect|not\s+the\s+case)\b`),
	regexp.MustCompile(`(?i)\b(supersedes|replaces|overrides)\s+(the\s+)?(previous|earlier|old)\b`),
	regexp.MustCompile(`(?i)\b(don't|do\s+not)\s+use\s+.*\s+anymore\b`),
	regexp.MustCompile(`(?i)\bno\s+longer\s+(valid|applicable|correct|recommended)\b`),
	regexp.MustCompile(`(?i)\bdeprecated\s+(approach|method|pattern|way)\b`),
	regexp.MustCompile(`(?i)\bshould\s+have\s+(been|used)\b.*instead\b`),
	regexp.MustCompile(`(?i)\bbetter\s+(approach|way|method|solution)\s+is\b`),
}

// OpposingChangePatterns detects add/remove conflicts.
var OpposingChangePatterns = map[string]string{
	"add":     "remove",
	"added":   "removed",
	"create":  "delete",
	"created": "deleted",
	"enable":  "disable",
	"enabled": "disabled",
	"include": "exclude",
	"allow":   "deny",
	"permit":  "block",
}

// DetectExplicitCorrection checks if text contains explicit correction language.
func DetectExplicitCorrection(text string) (bool, string) {
	for _, pattern := range CorrectionPatterns {
		if match := pattern.FindString(text); match != "" {
			return true, "Explicit correction detected: " + match
		}
	}
	return false, ""
}

// DetectOpposingFileChanges checks if two observations have opposing changes on the same file.
func DetectOpposingFileChanges(newer, older *Observation) (bool, string) {
	// Check for overlapping modified files
	newerFiles := make(map[string]bool)
	for _, f := range newer.FilesModified {
		newerFiles[f] = true
	}

	var overlappingFiles []string
	for _, f := range older.FilesModified {
		if newerFiles[f] {
			overlappingFiles = append(overlappingFiles, f)
		}
	}

	if len(overlappingFiles) == 0 {
		return false, ""
	}

	// Check for opposing action words in titles/narratives
	newerText := strings.ToLower(newer.Title.String + " " + newer.Narrative.String)
	olderText := strings.ToLower(older.Title.String + " " + older.Narrative.String)

	for action, opposite := range OpposingChangePatterns {
		if (strings.Contains(newerText, action) && strings.Contains(olderText, opposite)) ||
			(strings.Contains(newerText, opposite) && strings.Contains(olderText, action)) {
			return true, "Opposing changes on files: " + strings.Join(overlappingFiles, ", ")
		}
	}

	return false, ""
}

// DetectConceptTagMismatch checks if observations have same concepts but different recommendations.
func DetectConceptTagMismatch(newer, older *Observation) (bool, string) {
	// Find overlapping concepts
	newerConcepts := make(map[string]bool)
	for _, c := range newer.Concepts {
		newerConcepts[c] = true
	}

	var overlapping []string
	for _, c := range older.Concepts {
		if newerConcepts[c] {
			overlapping = append(overlapping, c)
		}
	}

	if len(overlapping) == 0 {
		return false, ""
	}

	// Check if same file was modified and concepts overlap
	// This suggests the newer observation may update the approach
	newerFiles := make(map[string]bool)
	for _, f := range newer.FilesModified {
		newerFiles[f] = true
	}
	for _, f := range older.FilesModified {
		if newerFiles[f] {
			// Same file modified with same concepts - likely an update
			return true, "Same concepts (" + strings.Join(overlapping, ", ") + ") with overlapping file changes"
		}
	}

	return false, ""
}

// DetectConflict performs comprehensive conflict detection between a new observation
// and an existing one. Returns detection result.
func DetectConflict(newer, older *Observation) *ConflictDetectionResult {
	result := &ConflictDetectionResult{
		HasConflict: false,
	}

	// 1. Check for explicit correction language in newer observation
	if newer.Narrative.Valid {
		if isCorrection, reason := DetectExplicitCorrection(newer.Narrative.String); isCorrection {
			result.HasConflict = true
			result.Type = ConflictContradicts
			result.Resolution = ResolutionPreferNewer
			result.Reason = reason
			result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
			return result
		}
	}

	// Check title as well
	if newer.Title.Valid {
		if isCorrection, reason := DetectExplicitCorrection(newer.Title.String); isCorrection {
			result.HasConflict = true
			result.Type = ConflictContradicts
			result.Resolution = ResolutionPreferNewer
			result.Reason = reason
			result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
			return result
		}
	}

	// 2. Check for opposing file changes
	if isOpposing, reason := DetectOpposingFileChanges(newer, older); isOpposing {
		result.HasConflict = true
		result.Type = ConflictSuperseded
		result.Resolution = ResolutionPreferNewer
		result.Reason = reason
		result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
		return result
	}

	// 3. Check for concept tag mismatches with same files
	if isMismatch, reason := DetectConceptTagMismatch(newer, older); isMismatch {
		result.HasConflict = true
		result.Type = ConflictSuperseded
		result.Resolution = ResolutionPreferNewer
		result.Reason = reason
		result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
		return result
	}

	return result
}

// DetectConflictsWithExisting checks a new observation against a list of existing observations.
// Returns all detected conflicts.
func DetectConflictsWithExisting(newer *Observation, existing []*Observation) []*ConflictDetectionResult {
	var results []*ConflictDetectionResult

	for _, older := range existing {
		// Skip self-comparison
		if older.ID == newer.ID {
			continue
		}

		// Only compare within the same project, unless either observation is global
		if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
			continue
		}

		result := DetectConflict(newer, older)
		if result.HasConflict {
			results = append(results, result)
		}
	}

	return results
}
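`DetectOpposingFileChanges` above looks for action words with `strings.Contains`, so a substring such as "add" inside "address" also counts as a match. A minimal sketch of a whole-word variant follows. It is not part of this commit: `hasOpposingActions`, `wordSet`, and the reduced `opposingActions` map are hypothetical names used only for illustration.

```go
package main

import (
	"strings"
	"unicode"
)

// opposingActions is a reduced, illustrative copy of OpposingChangePatterns.
var opposingActions = map[string]string{
	"add":    "remove",
	"enable": "disable",
}

// wordSet lowercases text and splits it on anything that is not a letter
// or digit, returning the set of whole words it contains.
func wordSet(text string) map[string]bool {
	words := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r)
	})
	set := make(map[string]bool, len(words))
	for _, w := range words {
		set[w] = true
	}
	return set
}

// hasOpposingActions (hypothetical helper) reports whether one text contains
// an action word and the other its opposite, comparing whole words only.
func hasOpposingActions(newerText, olderText string) bool {
	newer, older := wordSet(newerText), wordSet(olderText)
	for action, opposite := range opposingActions {
		if (newer[action] && older[opposite]) || (newer[opposite] && older[action]) {
			return true
		}
	}
	return false
}
```

Whole-word matching trades recall for precision: inflected forms such as "removed" only match if they are listed explicitly, which the original map already does for pairs like "added"/"removed".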
@@ -0,0 +1,421 @@
// Package models contains domain models for claude-mnemonic.
package models

import (
	"database/sql"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
)

// ConflictSuite is a test suite for conflict detection operations.
type ConflictSuite struct {
	suite.Suite
}

func TestConflictSuite(t *testing.T) {
	suite.Run(t, new(ConflictSuite))
}

// TestConflictTypeConstants tests conflict type constants.
func (s *ConflictSuite) TestConflictTypeConstants() {
	s.Equal(ConflictType("superseded"), ConflictSuperseded)
	s.Equal(ConflictType("contradicts"), ConflictContradicts)
	s.Equal(ConflictType("outdated_pattern"), ConflictOutdatedPattern)
}

// TestResolutionConstants tests resolution constants.
func (s *ConflictSuite) TestResolutionConstants() {
	s.Equal(ConflictResolution("prefer_newer"), ResolutionPreferNewer)
	s.Equal(ConflictResolution("prefer_older"), ResolutionPreferOlder)
	s.Equal(ConflictResolution("manual"), ResolutionManual)
}

// TestNewObservationConflict tests conflict creation.
func (s *ConflictSuite) TestNewObservationConflict() {
	conflict := NewObservationConflict(2, 1, ConflictSuperseded, ResolutionPreferNewer, "Test reason")

	s.Equal(int64(2), conflict.NewerObsID)
	s.Equal(int64(1), conflict.OlderObsID)
	s.Equal(ConflictSuperseded, conflict.ConflictType)
	s.Equal(ResolutionPreferNewer, conflict.Resolution)
	s.Equal("Test reason", conflict.Reason)
	s.False(conflict.Resolved)
	s.NotEmpty(conflict.DetectedAt)
	s.Greater(conflict.DetectedAtEpoch, int64(0))
}

// TestDetectExplicitCorrection_TableDriven tests explicit correction detection.
func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
	tests := []struct {
		name          string
		text          string
		expectMatch   bool
		expectPattern string
	}{
		{
			name:          "actually that was wrong",
			text:          "Actually, that was wrong - we should use a different approach",
			expectMatch:   true,
			expectPattern: "actually, that was wrong",
		},
		{
			name:          "correction prefix",
			text:          "Correction: the previous implementation had a bug",
			expectMatch:   true,
			expectPattern: "correction:",
		},
		{
			name:          "ignore previous",
			text:          "Please ignore the previous recommendation",
			expectMatch:   true,
			expectPattern: "ignore",
		},
		{
			name:          "disregard earlier",
			text:          "Disregard the earlier suggestion, it was flawed",
			expectMatch:   true,
			expectPattern: "disregard",
		},
		{
			name:          "was wrong",
			text:          "The original approach was wrong",
			expectMatch:   true,
			expectPattern: "was wrong",
		},
		{
			name:          "no longer valid",
			text:          "This method is no longer valid after the refactor",
			expectMatch:   true,
			expectPattern: "no longer valid",
		},
		{
			name:          "deprecated approach",
			text:          "This is a deprecated approach that should not be used",
			expectMatch:   true,
			expectPattern: "deprecated approach",
		},
		{
			name:          "better approach is",
			text:          "A better approach is to use the new API",
			expectMatch:   true,
			expectPattern: "better approach is",
		},
		{
			name:        "normal text - no correction",
			text:        "This is a normal observation about the code",
			expectMatch: false,
		},
		{
			name:        "empty text",
			text:        "",
			expectMatch: false,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			found, reason := DetectExplicitCorrection(tt.text)
			s.Equal(tt.expectMatch, found)
			if tt.expectMatch {
				s.Contains(reason, "Explicit correction detected")
			}
		})
	}
}

// TestDetectOpposingFileChanges_TableDriven tests opposing file change detection.
func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
	tests := []struct {
		name           string
		newerObs       *Observation
		olderObs       *Observation
		expectConflict bool
	}{
		{
			name: "add then remove - conflict",
			newerObs: &Observation{
				Title:         sql.NullString{String: "Remove authentication middleware", Valid: true},
				Narrative:     sql.NullString{String: "Removed the auth middleware from handlers", Valid: true},
				FilesModified: []string{"middleware.go", "handler.go"},
			},
			olderObs: &Observation{
				Title:         sql.NullString{String: "Add authentication middleware", Valid: true},
				Narrative:     sql.NullString{String: "Added auth middleware to secure endpoints", Valid: true},
				FilesModified: []string{"middleware.go", "handler.go"},
			},
			expectConflict: true,
		},
		{
			name: "enable then disable - conflict",
			newerObs: &Observation{
				Title:         sql.NullString{String: "Disable caching feature", Valid: true},
				Narrative:     sql.NullString{String: "Disabled caching due to issues", Valid: true},
				FilesModified: []string{"cache.go"},
			},
			olderObs: &Observation{
				Title:         sql.NullString{String: "Enable caching feature", Valid: true},
				Narrative:     sql.NullString{String: "Enabled caching for performance", Valid: true},
				FilesModified: []string{"cache.go"},
			},
			expectConflict: true,
		},
		{
			name: "different files - no conflict",
			newerObs: &Observation{
				Title:         sql.NullString{String: "Remove old code", Valid: true},
				Narrative:     sql.NullString{String: "Removed deprecated functions", Valid: true},
				FilesModified: []string{"old.go"},
			},
			olderObs: &Observation{
				Title:         sql.NullString{String: "Add new feature", Valid: true},
				Narrative:     sql.NullString{String: "Added new functions", Valid: true},
				FilesModified: []string{"new.go"},
			},
			expectConflict: false,
		},
		{
			name: "same files but no opposing keywords - no conflict",
			newerObs: &Observation{
				Title:         sql.NullString{String: "Update handler logic", Valid: true},
				Narrative:     sql.NullString{String: "Updated the handler implementation", Valid: true},
				FilesModified: []string{"handler.go"},
			},
			olderObs: &Observation{
				Title:         sql.NullString{String: "Fix handler bug", Valid: true},
				Narrative:     sql.NullString{String: "Fixed a bug in handler", Valid: true},
				FilesModified: []string{"handler.go"},
			},
			expectConflict: false,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			found, _ := DetectOpposingFileChanges(tt.newerObs, tt.olderObs)
			s.Equal(tt.expectConflict, found)
		})
	}
}

// TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection.
func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() {
	tests := []struct {
		name           string
		newerObs       *Observation
		olderObs       *Observation
		expectConflict bool
	}{
		{
			name: "same concepts and files - conflict",
			newerObs: &Observation{
				Concepts:      []string{"security", "authentication"},
				FilesModified: []string{"auth.go"},
			},
			olderObs: &Observation{
				Concepts:      []string{"security", "authentication"},
				FilesModified: []string{"auth.go"},
			},
			expectConflict: true,
		},
		{
			name: "overlapping concepts different files - no conflict",
			newerObs: &Observation{
				Concepts:      []string{"security"},
				FilesModified: []string{"new_auth.go"},
			},
			olderObs: &Observation{
				Concepts:      []string{"security"},
				FilesModified: []string{"old_auth.go"},
			},
			expectConflict: false,
		},
		{
			name: "different concepts same files - no conflict",
			newerObs: &Observation{
				Concepts:      []string{"performance"},
				FilesModified: []string{"handler.go"},
			},
			olderObs: &Observation{
				Concepts:      []string{"testing"},
				FilesModified: []string{"handler.go"},
			},
			expectConflict: false,
		},
		{
			name: "no overlapping concepts or files - no conflict",
			newerObs: &Observation{
				Concepts:      []string{"security"},
				FilesModified: []string{"auth.go"},
			},
			olderObs: &Observation{
				Concepts:      []string{"testing"},
				FilesModified: []string{"test.go"},
			},
			expectConflict: false,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			found, _ := DetectConceptTagMismatch(tt.newerObs, tt.olderObs)
			s.Equal(tt.expectConflict, found)
		})
	}
}

// TestDetectConflict tests comprehensive conflict detection.
func (s *ConflictSuite) TestDetectConflict() {
	// Test explicit correction takes precedence
	newer := &Observation{
		ID:        2,
		Project:   "test",
		Narrative: sql.NullString{String: "Actually, that was wrong. We should use a different approach.", Valid: true},
	}
	older := &Observation{
		ID:      1,
		Project: "test",
	}

	result := DetectConflict(newer, older)
	s.True(result.HasConflict)
	s.Equal(ConflictContradicts, result.Type)
	s.Equal(ResolutionPreferNewer, result.Resolution)
	s.Contains(result.Reason, "Explicit correction")
}

// TestDetectConflictsWithExisting tests conflict detection against multiple observations.
func (s *ConflictSuite) TestDetectConflictsWithExisting() {
	newer := &Observation{
		ID:            3,
		Project:       "test",
		Narrative:     sql.NullString{String: "Actually, that was wrong", Valid: true},
		Concepts:      []string{"security"},
		FilesModified: []string{"auth.go"},
	}

	existing := []*Observation{
		{
			ID:            1,
			Project:       "test",
			Concepts:      []string{"security"},
			FilesModified: []string{"auth.go"},
		},
		{
			ID:            2,
			Project:       "test",
			Concepts:      []string{"testing"},
			FilesModified: []string{"test.go"},
		},
		{
			ID:      3, // Same as newer - should be skipped
			Project: "test",
		},
	}

	results := DetectConflictsWithExisting(newer, existing)

	// The explicit correction language in newer is checked first and does not
	// depend on the older observation, so obs 1 and obs 2 are both reported;
	// obs 3 is skipped because it shares newer's ID.
	s.GreaterOrEqual(len(results), 1)

	// At least one result should reference obs 1
	foundObs1 := false
	for _, r := range results {
		for _, id := range r.OlderObsIDs {
			if id == 1 {
				foundObs1 = true
				break
			}
		}
	}
	s.True(foundObs1)
}

// TestDetectConflictsWithExisting_DifferentProjects tests that different projects don't conflict.
func (s *ConflictSuite) TestDetectConflictsWithExisting_DifferentProjects() {
	newer := &Observation{
		ID:        2,
		Project:   "project-a",
		Scope:     ScopeProject,
		Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
	}

	existing := []*Observation{
		{
			ID:      1,
			Project: "project-b",
			Scope:   ScopeProject,
		},
	}

	results := DetectConflictsWithExisting(newer, existing)
	s.Empty(results) // Different projects should not conflict
}

// TestDetectConflictsWithExisting_GlobalScope tests that global observations can conflict across projects.
func (s *ConflictSuite) TestDetectConflictsWithExisting_GlobalScope() {
	newer := &Observation{
		ID:            2,
		Project:       "project-a",
		Scope:         ScopeGlobal,
		Narrative:     sql.NullString{String: "Actually, that was wrong", Valid: true},
		Concepts:      []string{"security"},
		FilesModified: []string{"auth.go"},
	}

	existing := []*Observation{
		{
			ID:            1,
			Project:       "project-b",
			Scope:         ScopeGlobal, // Global scope allows cross-project conflict detection
			Concepts:      []string{"security"},
			FilesModified: []string{"auth.go"},
		},
	}

	results := DetectConflictsWithExisting(newer, existing)
	s.GreaterOrEqual(len(results), 1) // Global scope allows conflict detection
}

// TestCorrectionPatterns_Compiled ensures all patterns compile correctly.
func TestCorrectionPatterns_Compiled(t *testing.T) {
	// This test verifies that all correction patterns are valid regexps
	// If any pattern fails to compile, the package won't load
	assert.NotEmpty(t, CorrectionPatterns)
	for i, pattern := range CorrectionPatterns {
		assert.NotNil(t, pattern, "Pattern %d should not be nil", i)
	}
}

// TestOpposingChangePatterns tests the opposing change pattern map.
func TestOpposingChangePatterns(t *testing.T) {
	assert.NotEmpty(t, OpposingChangePatterns)
	assert.Equal(t, "remove", OpposingChangePatterns["add"])
	assert.Equal(t, "removed", OpposingChangePatterns["added"])
	assert.Equal(t, "delete", OpposingChangePatterns["create"])
	assert.Equal(t, "disable", OpposingChangePatterns["enable"])
}

// TestObservationConflict_Fields tests field access.
func TestObservationConflict_Fields(t *testing.T) {
	conflict := &ObservationConflict{
		ID:              1,
		NewerObsID:      10,
		OlderObsID:      5,
		ConflictType:    ConflictSuperseded,
		Resolution:      ResolutionPreferNewer,
		Reason:          "Test reason",
		DetectedAt:      "2024-01-01T00:00:00Z",
		DetectedAtEpoch: 1704067200000,
		Resolved:        false,
	}

	assert.Equal(t, int64(1), conflict.ID)
	assert.Equal(t, int64(10), conflict.NewerObsID)
	assert.Equal(t, int64(5), conflict.OlderObsID)
	assert.Equal(t, ConflictSuperseded, conflict.ConflictType)
	assert.Equal(t, ResolutionPreferNewer, conflict.Resolution)
	assert.False(t, conflict.Resolved)
}
@@ -139,6 +139,16 @@ type Observation struct {
	CreatedAt      string `db:"created_at" json:"created_at"`
	CreatedAtEpoch int64  `db:"created_at_epoch" json:"created_at_epoch"`
	IsStale        bool   `db:"-" json:"is_stale,omitempty"`

	// Importance scoring fields
	ImportanceScore float64       `db:"importance_score" json:"importance_score"`
	UserFeedback    int           `db:"user_feedback" json:"user_feedback"`
	RetrievalCount  int           `db:"retrieval_count" json:"retrieval_count"`
	LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
	ScoreUpdatedAt  sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`

	// Conflict detection fields
	IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
}

// ParsedObservation represents an observation parsed from SDK response XML.
@@ -205,6 +215,16 @@ type ObservationJSON struct {
	CreatedAt      string `json:"created_at"`
	CreatedAtEpoch int64  `json:"created_at_epoch"`
	IsStale        bool   `json:"is_stale,omitempty"`

	// Importance scoring fields
	ImportanceScore float64 `json:"importance_score"`
	UserFeedback    int     `json:"user_feedback"`
	RetrievalCount  int     `json:"retrieval_count"`
	LastRetrievedAt int64   `json:"last_retrieved_at_epoch,omitempty"`
	ScoreUpdatedAt  int64   `json:"score_updated_at_epoch,omitempty"`

	// Conflict detection fields
	IsSuperseded bool `json:"is_superseded,omitempty"`
}

// MarshalJSON implements json.Marshaler for Observation.
@@ -225,6 +245,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
		CreatedAt:      o.CreatedAt,
		CreatedAtEpoch: o.CreatedAtEpoch,
		IsStale:        o.IsStale,
		// Importance scoring fields
		ImportanceScore: o.ImportanceScore,
		UserFeedback:    o.UserFeedback,
		RetrievalCount:  o.RetrievalCount,
		// Conflict detection fields
		IsSuperseded: o.IsSuperseded,
	}
	if o.Title.Valid {
		j.Title = o.Title.String
@@ -238,6 +264,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
	if o.PromptNumber.Valid {
		j.PromptNumber = o.PromptNumber.Int64
	}
	if o.LastRetrievedAt.Valid {
		j.LastRetrievedAt = o.LastRetrievedAt.Int64
	}
	if o.ScoreUpdatedAt.Valid {
		j.ScoreUpdatedAt = o.ScoreUpdatedAt.Int64
	}
	return json.Marshal(j)
}

@@ -268,9 +300,28 @@ func NewObservation(sdkSessionID, project string, parsed *ParsedObservation, pro
		DiscoveryTokens: discoveryTokens,
		CreatedAt:       now.Format(time.RFC3339),
		CreatedAtEpoch:  now.UnixMilli(),
		// Importance scoring: new observations start with score 1.0
		ImportanceScore: 1.0,
		UserFeedback:    0,
		RetrievalCount:  0,
	}
}

// ToMap converts the observation to a map for JSON response building.
// This allows adding extra fields like similarity scores.
func (o *Observation) ToMap() map[string]interface{} {
	// Marshal to JSON then unmarshal to map (uses MarshalJSON for proper conversion)
	data, err := json.Marshal(o)
	if err != nil {
		return map[string]interface{}{"id": o.ID, "error": err.Error()}
	}
	var result map[string]interface{}
	if err := json.Unmarshal(data, &result); err != nil {
		return map[string]interface{}{"id": o.ID, "error": err.Error()}
	}
	return result
}

// CheckStaleness checks if an observation is stale based on current file mtimes.
// Returns true if any tracked file has been modified since the observation was created.
func (o *Observation) CheckStaleness(currentMtimes map[string]int64) bool {
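The `ToMap` helper in the hunk above round-trips the observation through JSON so that callers can attach extra response fields. A minimal, self-contained sketch of that pattern follows, assuming a hypothetical `observation` struct and hypothetical helpers `toMap` and `withSimilarity`; none of these names come from the commit.

```go
package main

import "encoding/json"

// observation is a hypothetical stand-in for models.Observation.
type observation struct {
	ID    int64  `json:"id"`
	Title string `json:"title"`
}

// toMap marshals v to JSON and decodes it into a generic map so callers
// can attach extra response fields.
func toMap(v interface{}) (map[string]interface{}, error) {
	data, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	var out map[string]interface{}
	if err := json.Unmarshal(data, &out); err != nil {
		return nil, err
	}
	return out, nil
}

// withSimilarity (hypothetical) returns the observation as a map with an
// added "similarity" field.
func withSimilarity(o observation, score float64) (map[string]interface{}, error) {
	m, err := toMap(o)
	if err != nil {
		return nil, err
	}
	m["similarity"] = score
	return m, nil
}
```

One consequence of decoding into `map[string]interface{}` is that `encoding/json` stores every JSON number as `float64`, so an `int64` field such as `id` comes back as a `float64` in the map and needs a conversion if it is read back as an integer.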
@@ -0,0 +1,398 @@
// Package models contains domain models for claude-mnemonic.
package models

import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"time"
)

// PatternType represents the category of detected pattern.
type PatternType string

const (
	// PatternTypeBug represents recurring bug patterns (e.g., "nil handling oversight").
	PatternTypeBug PatternType = "bug"
	// PatternTypeRefactor represents recurring refactoring approaches (e.g., "interface extraction").
	PatternTypeRefactor PatternType = "refactor"
	// PatternTypeArchitecture represents consistent architectural patterns.
	PatternTypeArchitecture PatternType = "architecture"
	// PatternTypeAntiPattern represents identified anti-patterns to avoid.
	PatternTypeAntiPattern PatternType = "anti-pattern"
	// PatternTypeBestPractice represents best practices that work consistently.
	PatternTypeBestPractice PatternType = "best-practice"
)

// PatternStatus represents the lifecycle status of a pattern.
type PatternStatus string

const (
	// PatternStatusActive means the pattern is actively being tracked and can be referenced.
	PatternStatusActive PatternStatus = "active"
	// PatternStatusDeprecated means the pattern has been superseded or is no longer relevant.
	PatternStatusDeprecated PatternStatus = "deprecated"
	// PatternStatusMerged means this pattern was merged into another pattern.
	PatternStatusMerged PatternStatus = "merged"
)

// Pattern represents a recurring pattern detected across observations.
// This enables Claude to reference historical insights: "I've encountered this pattern 12 times."
type Pattern struct {
	ID             int64           `db:"id" json:"id"`
	Name           string          `db:"name" json:"name"` // e.g., "State Management Anti-Pattern"
	Type           PatternType     `db:"type" json:"type"` // bug, refactor, architecture, etc.
	Description    sql.NullString  `db:"description" json:"description"` // Detailed description
	Signature      JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection
	Recommendation sql.NullString  `db:"recommendation" json:"recommendation"` // What works for this pattern
	Frequency      int             `db:"frequency" json:"frequency"` // How many times encountered
	Projects       JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen
	ObservationIDs JSONInt64Array  `db:"observation_ids" json:"observation_ids"` // Source observation IDs
	Status         PatternStatus   `db:"status" json:"status"` // active, deprecated, merged
	MergedIntoID   sql.NullInt64   `db:"merged_into_id" json:"merged_into_id,omitempty"`
	Confidence     float64         `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0)
	LastSeenAt     string          `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
	LastSeenEpoch  int64           `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
	CreatedAt      string          `db:"created_at" json:"created_at"`
	CreatedAtEpoch int64           `db:"created_at_epoch" json:"created_at_epoch"`
}

// JSONInt64Array is a custom type for handling JSON int64 arrays in SQLite.
type JSONInt64Array []int64

// Scan implements sql.Scanner for JSONInt64Array.
func (j *JSONInt64Array) Scan(src interface{}) error {
	if src == nil {
		*j = nil
		return nil
	}

	var data []byte
	switch v := src.(type) {
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		// Unsupported source types are silently ignored; the receiver is left unchanged.
		return nil
	}

	if len(data) == 0 {
		*j = nil
		return nil
	}

	return json.Unmarshal(data, j)
}

// Value implements driver.Valuer for JSONInt64Array.
// A nil slice is stored as SQL NULL; any other slice is stored as a JSON array.
func (j JSONInt64Array) Value() (driver.Value, error) {
	if j == nil {
		return nil, nil
	}
	return json.Marshal(j)
}

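The Scan/Value pair above can be exercised without a database. The following is a minimal sketch, not the repository's code: `int64Array` is a stand-in redeclared here so the example compiles on its own, and, unlike the original, it returns an error for unsupported source types (an assumption made for illustration).

```go
package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

// int64Array is a hypothetical stand-in for JSONInt64Array.
type int64Array []int64

// Scan decodes a JSON array from a string or []byte column value.
func (a *int64Array) Scan(src interface{}) error {
	var data []byte
	switch v := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		// Assumption: report unsupported types instead of ignoring them.
		return fmt.Errorf("int64Array: unsupported source type %T", src)
	}
	if len(data) == 0 {
		*a = nil
		return nil
	}
	return json.Unmarshal(data, a)
}

// Value encodes the slice as JSON; a nil slice maps to SQL NULL.
func (a int64Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	return json.Marshal(a)
}

func main() {
	// Round trip: encode as the driver would store it, then decode again.
	stored, err := int64Array{1, 2, 3}.Value()
	if err != nil {
		panic(err)
	}
	var loaded int64Array
	if err := loaded.Scan(stored); err != nil {
		panic(err)
	}
	fmt.Println(string(stored.([]byte)), loaded)
}
```

Whether to fail loudly on an unexpected column type is a design choice; the original keeps the receiver unchanged in that case.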
// PatternJSON is a JSON-friendly representation of Pattern.
type PatternJSON struct {
	ID             int64         `json:"id"`
	Name           string        `json:"name"`
	Type           PatternType   `json:"type"`
	Description    string        `json:"description,omitempty"`
	Signature      []string      `json:"signature,omitempty"`
	Recommendation string        `json:"recommendation,omitempty"`
	Frequency      int           `json:"frequency"`
	Projects       []string      `json:"projects,omitempty"`
	ObservationIDs []int64       `json:"observation_ids,omitempty"`
	Status         PatternStatus `json:"status"`
	MergedIntoID   int64         `json:"merged_into_id,omitempty"`
	Confidence     float64       `json:"confidence"`
	LastSeenAt     string        `json:"last_seen_at"`
	LastSeenEpoch  int64         `json:"last_seen_at_epoch"`
	CreatedAt      string        `json:"created_at"`
	CreatedAtEpoch int64         `json:"created_at_epoch"`
}

// MarshalJSON implements json.Marshaler for Pattern.
func (p *Pattern) MarshalJSON() ([]byte, error) {
	j := PatternJSON{
		ID:             p.ID,
		Name:           p.Name,
		Type:           p.Type,
		Signature:      p.Signature,
		Frequency:      p.Frequency,
		Projects:       p.Projects,
		ObservationIDs: p.ObservationIDs,
		Status:         p.Status,
		Confidence:     p.Confidence,
		LastSeenAt:     p.LastSeenAt,
		LastSeenEpoch:  p.LastSeenEpoch,
		CreatedAt:      p.CreatedAt,
		CreatedAtEpoch: p.CreatedAtEpoch,
	}
	if p.Description.Valid {
		j.Description = p.Description.String
	}
	if p.Recommendation.Valid {
		j.Recommendation = p.Recommendation.String
	}
	if p.MergedIntoID.Valid {
		j.MergedIntoID = p.MergedIntoID.Int64
	}
	return json.Marshal(j)
}

// NewPattern creates a new pattern from detected data.
func NewPattern(name string, patternType PatternType, description string, signature []string, project string, observationID int64) *Pattern {
	now := time.Now()
	return &Pattern{
		Name:           name,
		Type:           patternType,
		Description:    sql.NullString{String: description, Valid: description != ""},
		Signature:      signature,
		Frequency:      1,
		Projects:       []string{project},
		ObservationIDs: []int64{observationID},
		Status:         PatternStatusActive,
		Confidence:     0.5, // Initial confidence
		LastSeenAt:     now.Format(time.RFC3339),
		LastSeenEpoch:  now.UnixMilli(),
		CreatedAt:      now.Format(time.RFC3339),
		CreatedAtEpoch: now.UnixMilli(),
	}
}

// AddOccurrence records a new occurrence of this pattern.
func (p *Pattern) AddOccurrence(project string, observationID int64) {
	p.Frequency++

	// Add project if not already tracked
	found := false
	for _, proj := range p.Projects {
		if proj == project {
			found = true
			break
		}
	}
	if !found {
		p.Projects = append(p.Projects, project)
	}

	// Add observation ID if not already tracked
	seen := false
	for _, id := range p.ObservationIDs {
		if id == observationID {
			seen = true
			break
		}
	}
	if !seen {
		p.ObservationIDs = append(p.ObservationIDs, observationID)
	}

	// Update confidence based on frequency and cross-project occurrence.
	// This also runs for an already-tracked observation ID, because Frequency
	// (and possibly Projects) changed above.
	p.updateConfidence()

	// Update last seen timestamp
	now := time.Now()
	p.LastSeenAt = now.Format(time.RFC3339)
	p.LastSeenEpoch = now.UnixMilli()
}

// updateConfidence adjusts confidence based on frequency and cross-project validation.
func (p *Pattern) updateConfidence() {
	// Base confidence from frequency (linear scaling, capped at 10 occurrences):
	// 0.3 at frequency 0, rising to 0.7 at frequency 10 and above.
	freqConfidence := 0.3 + (0.4 * (float64(min(p.Frequency, 10)) / 10.0))

	// Cross-project bonus: patterns seen across multiple projects are more reliable
	projectBonus := 0.0
	if len(p.Projects) >= 2 {
		projectBonus = 0.1
	}
	if len(p.Projects) >= 5 {
		projectBonus = 0.2
	}

	p.Confidence = min(1.0, freqConfidence+projectBonus)
}

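To make the arithmetic in updateConfidence concrete, here is a small standalone sketch. The helper name `confidence` and the sample inputs are illustrative assumptions; only the formula is taken from the code above.

```go
package main

import "fmt"

// confidence reproduces the arithmetic of updateConfidence for a given
// occurrence count and number of distinct projects.
func confidence(frequency, projects int) float64 {
	capped := frequency
	if capped > 10 {
		capped = 10
	}
	// Linear frequency term: 0.3 base plus up to 0.4.
	score := 0.3 + 0.4*(float64(capped)/10.0)
	// Cross-project bonus: 0.1 for 2-4 projects, 0.2 for 5 or more.
	switch {
	case projects >= 5:
		score += 0.2
	case projects >= 2:
		score += 0.1
	}
	if score > 1.0 {
		score = 1.0
	}
	return score
}

func main() {
	cases := []struct{ freq, proj int }{{2, 1}, {10, 1}, {3, 2}, {10, 5}}
	for _, c := range cases {
		fmt.Printf("frequency=%d projects=%d confidence=%.2f\n",
			c.freq, c.proj, confidence(c.freq, c.proj))
	}
}
```

With these constants the highest reachable value is 0.7 + 0.2 = 0.9, so the final cap at 1.0 is a safeguard rather than an active limit.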
// PatternMatch represents a match between an observation and a potential pattern.
type PatternMatch struct {
	PatternID     int64   `json:"pattern_id"`
	Score         float64 `json:"score"`      // Match score (0.0-1.0)
	MatchedOn     string  `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
	IsNew         bool    `json:"is_new"`     // Whether this would create a new pattern
	SuggestedName string  `json:"suggested_name,omitempty"`
}

// PatternSignatureKeywords are common keywords used in pattern detection.
var PatternSignatureKeywords = map[PatternType][]string{
	PatternTypeBug: {
		"nil", "null", "undefined", "panic", "crash", "error handling",
		"race condition", "deadlock", "memory leak", "overflow",
		"off-by-one", "boundary", "timeout", "concurrency",
	},
	PatternTypeRefactor: {
		"extract", "inline", "rename", "move", "split", "merge",
		"interface", "abstraction", "decouple", "simplify",
		"consolidate", "modularize", "encapsulate",
	},
	PatternTypeArchitecture: {
		"layer", "service", "repository", "controller", "handler",
		"middleware", "dependency injection", "factory", "singleton",
		"observer", "strategy", "adapter", "facade", "builder",
	},
	PatternTypeAntiPattern: {
		"god class", "spaghetti", "copy paste", "magic number",
		"hardcoded", "circular dependency", "premature optimization",
		"over-engineering", "feature envy", "data clump",
	},
	PatternTypeBestPractice: {
		"test", "validation", "logging", "monitoring", "documentation",
		"error handling", "retry", "timeout", "circuit breaker",
		"graceful shutdown", "health check", "metrics",
	},
}

// DetectPatternType analyzes concepts and content to determine pattern type.
// Explicit concepts win; otherwise bug keywords in the title or narrative
// select PatternTypeBug, and everything else falls back to PatternTypeRefactor.
func DetectPatternType(concepts []string, title, narrative string) PatternType {
	// Check concepts first
	for _, concept := range concepts {
		switch concept {
		case "anti-pattern":
			return PatternTypeAntiPattern
		case "best-practice":
			return PatternTypeBestPractice
		case "architecture":
			return PatternTypeArchitecture
		case "refactor":
			return PatternTypeRefactor
		}
	}

	// Check for bug-related patterns in content
	content := title + " " + narrative
	for _, keyword := range PatternSignatureKeywords[PatternTypeBug] {
		if containsIgnoreCase(content, keyword) {
			return PatternTypeBug
		}
	}

	// Default to refactor for other patterns
	return PatternTypeRefactor
}

// containsIgnoreCase checks if text contains substr (case-insensitive, ASCII only).
func containsIgnoreCase(text, substr string) bool {
	textLower := toLower(text)
	substrLower := toLower(substr)
	return contains(textLower, substrLower)
}

// Simple implementations to avoid a strings package dependency in this file.
// toLower lower-cases ASCII letters only; all other bytes are copied unchanged.
func toLower(s string) string {
	b := make([]byte, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			c += 'a' - 'A'
		}
		b[i] = c
	}
	return string(b)
}

// contains reports whether substr occurs in s (byte-wise comparison).
func contains(s, substr string) bool {
	if len(substr) > len(s) {
		return false
	}
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

// ExtractSignature creates a signature from observation content.
// Only the concepts and the title contribute; narrative is currently unused.
func ExtractSignature(concepts []string, title, narrative string) []string {
	var signature []string

	// Add all concepts
	signature = append(signature, concepts...)

	// Extract key terms from title (simple word extraction)
	for _, word := range splitWords(title) {
		if len(word) > 3 && isSignificantWord(word) {
			signature = append(signature, toLower(word))
		}
	}

	return uniqueStrings(signature)
}

// splitWords is a simple word splitter.
// It splits on spaces, hyphens, underscores, periods, and commas.
func splitWords(s string) []string {
	var words []string
	word := ""
	for _, r := range s {
		if r == ' ' || r == '-' || r == '_' || r == '.' || r == ',' {
			if word != "" {
				words = append(words, word)
				word = ""
			}
		} else {
			word += string(r)
		}
	}
	if word != "" {
		words = append(words, word)
	}
	return words
}

// isSignificantWord filters out common stop words.
func isSignificantWord(word string) bool {
	stopWords := map[string]bool{
		"the": true, "and": true, "for": true, "with": true, "that": true,
		"this": true, "from": true, "have": true, "not": true, "are": true,
		"was": true, "but": true, "all": true, "can": true, "had": true,
		"were": true, "been": true, "will": true, "when": true, "what": true,
	}
	return !stopWords[toLower(word)]
}

// uniqueStrings returns a slice with duplicate strings removed.
// The first occurrence of each string keeps its position.
func uniqueStrings(s []string) []string {
	seen := make(map[string]bool)
	var result []string
	for _, v := range s {
		if !seen[v] {
			seen[v] = true
			result = append(result, v)
		}
	}
	return result
}

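The title-term step of ExtractSignature can be illustrated with the standard library. This is a sketch under stated assumptions, not the file's approach: it uses the `strings` package, which the file above deliberately avoids, and `titleTerms` and the shortened `stopWords` list are hypothetical names introduced for the example. Note that `strings.ToLower` is Unicode-aware, whereas the original `toLower` handles ASCII only.

```go
package main

import (
	"fmt"
	"strings"
)

// stopWords is a shortened, illustrative stop-word list.
var stopWords = map[string]bool{"with": true, "that": true, "this": true, "from": true}

// titleTerms lower-cases the title, splits it on the same separators as
// splitWords, and keeps unique words longer than three bytes that are not
// stop words.
func titleTerms(title string) []string {
	isSep := func(r rune) bool {
		return r == ' ' || r == '-' || r == '_' || r == '.' || r == ','
	}
	seen := make(map[string]bool)
	var terms []string
	for _, w := range strings.FieldsFunc(strings.ToLower(title), isSep) {
		if len(w) > 3 && !stopWords[w] && !seen[w] {
			seen[w] = true
			terms = append(terms, w)
		}
	}
	return terms
}

func main() {
	// "nil" is dropped because it is only three characters long.
	fmt.Println(titleTerms("Nil Pointer Validation Pattern"))
}
```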
// CalculateMatchScore computes similarity between two signatures.
// The comparison is case-insensitive and returns a value in [0.0, 1.0].
func CalculateMatchScore(sig1, sig2 []string) float64 {
	if len(sig1) == 0 || len(sig2) == 0 {
		return 0.0
	}

	set1 := make(map[string]bool)
	for _, s := range sig1 {
		set1[toLower(s)] = true
	}

	matches := 0
	for _, s := range sig2 {
		if set1[toLower(s)] {
			matches++
		}
	}

	// Jaccard similarity: |intersection| / |union|.
	// The union size is derived from the slice lengths, so the result is exact
	// only when neither signature contains (case-insensitive) duplicates.
	unionSize := len(sig1) + len(sig2) - matches
	if unionSize == 0 {
		return 0.0
	}
	return float64(matches) / float64(unionSize)
}
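For comparison, a duplicate-safe Jaccard computation could look like the sketch below. It is an illustrative variant rather than the function above: `jaccard` is a hypothetical name, and it builds a set for each signature so repeated terms are counted once.

```go
package main

import (
	"fmt"
	"strings"
)

// jaccard returns |A ∩ B| / |A ∪ B| over lower-cased, de-duplicated terms.
func jaccard(a, b []string) float64 {
	setA := make(map[string]bool, len(a))
	for _, s := range a {
		setA[strings.ToLower(s)] = true
	}
	setB := make(map[string]bool, len(b))
	for _, s := range b {
		setB[strings.ToLower(s)] = true
	}
	if len(setA) == 0 || len(setB) == 0 {
		return 0
	}
	shared := 0
	for s := range setB {
		if setA[s] {
			shared++
		}
	}
	// Union size by inclusion-exclusion on the two sets.
	return float64(shared) / float64(len(setA)+len(setB)-shared)
}

func main() {
	// Two of four distinct terms are shared, so the score is 2/4.
	fmt.Println(jaccard(
		[]string{"nil", "pointer", "panic"},
		[]string{"nil", "pointer", "timeout"},
	))
}
```

Signatures produced by ExtractSignature are already de-duplicated by exact string, so the two versions differ only for inputs that repeat a term or differ solely in letter case.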
@@ -0,0 +1,357 @@
package models

import (
	"database/sql"
	"encoding/json"
	"testing"
	"time"
)

func TestNewPattern(t *testing.T) {
	pattern := NewPattern(
		"Test Pattern",
		PatternTypeBug,
		"A test pattern description",
		[]string{"nil", "error", "handling"},
		"test-project",
		123,
	)

	if pattern.Name != "Test Pattern" {
		t.Errorf("Expected name 'Test Pattern', got '%s'", pattern.Name)
	}
	if pattern.Type != PatternTypeBug {
		t.Errorf("Expected type PatternTypeBug, got '%s'", pattern.Type)
	}
	if !pattern.Description.Valid || pattern.Description.String != "A test pattern description" {
		t.Errorf("Description not set correctly")
	}
	if len(pattern.Signature) != 3 {
		t.Errorf("Expected 3 signature elements, got %d", len(pattern.Signature))
	}
	if pattern.Frequency != 1 {
		t.Errorf("Expected frequency 1, got %d", pattern.Frequency)
	}
	if len(pattern.Projects) != 1 || pattern.Projects[0] != "test-project" {
		t.Errorf("Projects not set correctly")
	}
	if len(pattern.ObservationIDs) != 1 || pattern.ObservationIDs[0] != 123 {
		t.Errorf("ObservationIDs not set correctly")
	}
	if pattern.Status != PatternStatusActive {
		t.Errorf("Expected status Active, got '%s'", pattern.Status)
	}
	if pattern.Confidence != 0.5 {
		t.Errorf("Expected initial confidence 0.5, got %f", pattern.Confidence)
	}
}

func TestPattern_AddOccurrence(t *testing.T) {
	pattern := NewPattern("Test", PatternTypeBug, "desc", []string{"test"}, "project1", 1)

	// Add same project occurrence
	pattern.AddOccurrence("project1", 2)
	if pattern.Frequency != 2 {
		t.Errorf("Expected frequency 2, got %d", pattern.Frequency)
	}
	if len(pattern.Projects) != 1 {
		t.Errorf("Expected 1 project (no duplicates), got %d", len(pattern.Projects))
	}

	// Add different project occurrence
	pattern.AddOccurrence("project2", 3)
	if pattern.Frequency != 3 {
		t.Errorf("Expected frequency 3, got %d", pattern.Frequency)
	}
	if len(pattern.Projects) != 2 {
		t.Errorf("Expected 2 projects, got %d", len(pattern.Projects))
	}

	// Add duplicate observation ID - should not duplicate
	pattern.AddOccurrence("project2", 3)
	if len(pattern.ObservationIDs) != 3 {
		t.Errorf("Expected 3 observation IDs (no duplicate), got %d", len(pattern.ObservationIDs))
	}

	// Check confidence increased
	if pattern.Confidence <= 0.5 {
		t.Errorf("Expected confidence to increase above 0.5, got %f", pattern.Confidence)
	}
}

func TestPattern_ConfidenceCalculation(t *testing.T) {
	tests := []struct {
		name          string
		frequency     int
		projectCount  int
		minConfidence float64
		maxConfidence float64
	}{
		{"low_frequency", 2, 1, 0.3, 0.5},
		{"high_frequency", 10, 1, 0.6, 0.8},
		{"multi_project", 3, 3, 0.4, 0.7},
		{"high_freq_multi_proj", 10, 5, 0.7, 1.0},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pattern := NewPattern("Test", PatternTypeBug, "", []string{}, "proj1", 1)

			// Simulate occurrences
			for i := 1; i < tt.frequency; i++ {
				projIdx := i % tt.projectCount
				if projIdx == 0 {
					projIdx = 1
				}
				pattern.AddOccurrence("proj"+string(rune('0'+projIdx)), int64(i+1))
			}

			if pattern.Confidence < tt.minConfidence || pattern.Confidence > tt.maxConfidence {
				t.Errorf("Expected confidence between %f and %f, got %f",
					tt.minConfidence, tt.maxConfidence, pattern.Confidence)
			}
		})
	}
}

func TestPatternType_Detection(t *testing.T) {
	tests := []struct {
		concepts  []string
		title     string
		narrative string
		expected  PatternType
	}{
		{[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern},
		{[]string{"best-practice"}, "", "", PatternTypeBestPractice},
		{[]string{"architecture"}, "", "", PatternTypeArchitecture},
		{[]string{"refactor"}, "", "", PatternTypeRefactor},
		{[]string{}, "nil pointer bug", "", PatternTypeBug},
		{[]string{}, "Deadlock in concurrent code", "", PatternTypeBug},
		{[]string{}, "Extract interface", "", PatternTypeRefactor},
	}

	for _, tt := range tests {
		t.Run(tt.title+"_"+tt.expected.String(), func(t *testing.T) {
			result := DetectPatternType(tt.concepts, tt.title, tt.narrative)
			if result != tt.expected {
				t.Errorf("Expected %s, got %s", tt.expected, result)
			}
		})
	}
}

func (pt PatternType) String() string {
	return string(pt)
}

func TestExtractSignature(t *testing.T) {
	concepts := []string{"error-handling", "security"}
	title := "Nil Pointer Validation Pattern"
	narrative := "Always validate before dereferencing"

	signature := ExtractSignature(concepts, title, narrative)

	// Should contain concepts
	found := false
	for _, s := range signature {
		if s == "error-handling" {
			found = true
			break
		}
	}
	if !found {
		t.Errorf("Expected signature to contain concepts, got %v", signature)
	}

	// Should contain significant words from title
	found = false
	for _, s := range signature {
		if s == "validation" || s == "pattern" || s == "pointer" {
			found = true
			break
		}
	}
	if !found {
		t.Errorf("Expected signature to contain title keywords, got %v", signature)
	}
}

func TestCalculateMatchScore(t *testing.T) {
	tests := []struct {
		name     string
		sig1     []string
		sig2     []string
		minScore float64
		maxScore float64
	}{
		{"identical", []string{"a", "b", "c"}, []string{"a", "b", "c"}, 1.0, 1.0},
		{"partial", []string{"a", "b", "c"}, []string{"a", "b", "d"}, 0.4, 0.6},
		{"no_match", []string{"a", "b", "c"}, []string{"x", "y", "z"}, 0.0, 0.0},
		{"empty", []string{}, []string{"a", "b"}, 0.0, 0.0},
		{"subset", []string{"a", "b"}, []string{"a", "b", "c", "d"}, 0.4, 0.6},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			score := CalculateMatchScore(tt.sig1, tt.sig2)
			if score < tt.minScore || score > tt.maxScore {
				t.Errorf("Expected score between %f and %f, got %f",
					tt.minScore, tt.maxScore, score)
			}
		})
	}
}

func TestPattern_MarshalJSON(t *testing.T) {
	pattern := &Pattern{
		ID:             1,
		Name:           "Test Pattern",
		Type:           PatternTypeBug,
		Description:    sql.NullString{String: "A description", Valid: true},
		Signature:      []string{"a", "b"},
		Recommendation: sql.NullString{String: "Do this", Valid: true},
		Frequency:      5,
		Projects:       []string{"proj1", "proj2"},
		ObservationIDs: []int64{1, 2, 3},
		Status:         PatternStatusActive,
		MergedIntoID:   sql.NullInt64{Int64: 0, Valid: false},
		Confidence:     0.8,
		LastSeenAt:     time.Now().Format(time.RFC3339),
		LastSeenEpoch:  time.Now().UnixMilli(),
		CreatedAt:      time.Now().Format(time.RFC3339),
		CreatedAtEpoch: time.Now().UnixMilli(),
	}

	data, err := json.Marshal(pattern)
	if err != nil {
		t.Fatalf("Failed to marshal pattern: %v", err)
	}

	var result PatternJSON
	if err := json.Unmarshal(data, &result); err != nil {
		t.Fatalf("Failed to unmarshal pattern: %v", err)
	}

	if result.Name != pattern.Name {
		t.Errorf("Expected name %s, got %s", pattern.Name, result.Name)
	}
	if result.Description != pattern.Description.String {
		t.Errorf("Expected description %s, got %s", pattern.Description.String, result.Description)
	}
	if result.Frequency != pattern.Frequency {
		t.Errorf("Expected frequency %d, got %d", pattern.Frequency, result.Frequency)
	}
	if result.MergedIntoID != 0 {
		t.Errorf("Expected merged_into_id 0 for invalid NullInt64, got %d", result.MergedIntoID)
	}
}

func TestJSONInt64Array_Scan(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		expected JSONInt64Array
		wantErr  bool
	}{
		{"string_array", "[1, 2, 3]", JSONInt64Array{1, 2, 3}, false},
		{"bytes_array", []byte("[4, 5, 6]"), JSONInt64Array{4, 5, 6}, false},
		{"nil", nil, nil, false},
		{"empty_string", "", nil, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var arr JSONInt64Array
			err := arr.Scan(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if len(arr) != len(tt.expected) {
				t.Errorf("Expected length %d, got %d", len(tt.expected), len(arr))
			}
		})
	}
}

func TestJSONInt64Array_Value(t *testing.T) {
	arr := JSONInt64Array{1, 2, 3}
	val, err := arr.Value()
	if err != nil {
		t.Fatalf("Value() error = %v", err)
	}

	bytes, ok := val.([]byte)
	if !ok {
		t.Fatalf("Expected []byte, got %T", val)
	}

	var result []int64
	if err := json.Unmarshal(bytes, &result); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if len(result) != 3 || result[0] != 1 || result[1] != 2 || result[2] != 3 {
		t.Errorf("Expected [1, 2, 3], got %v", result)
	}
}

func TestPatternSignatureKeywords(t *testing.T) {
	// Verify keywords exist for each type
	types := []PatternType{
		PatternTypeBug,
		PatternTypeRefactor,
		PatternTypeArchitecture,
		PatternTypeAntiPattern,
		PatternTypeBestPractice,
	}

	for _, pt := range types {
		keywords := PatternSignatureKeywords[pt]
		if len(keywords) == 0 {
			t.Errorf("No keywords defined for pattern type %s", pt)
		}
	}
}

func TestUniqueStrings(t *testing.T) {
	tests := []struct {
		input    []string
		expected int
	}{
		{[]string{"a", "b", "c"}, 3},
		{[]string{"a", "a", "b"}, 2},
		{[]string{"a", "a", "a"}, 1},
		{[]string{}, 0},
	}

	for _, tt := range tests {
		result := uniqueStrings(tt.input)
		if len(result) != tt.expected {
			t.Errorf("uniqueStrings(%v) = %v (len=%d), expected len=%d",
				tt.input, result, len(result), tt.expected)
		}
	}
}

func TestContainsIgnoreCase(t *testing.T) {
	tests := []struct {
		text     string
		substr   string
		expected bool
	}{
		{"Hello World", "hello", true},
		{"Hello World", "WORLD", true},
		{"Hello World", "xyz", false},
		{"", "a", false},
		{"a", "", true},
	}

	for _, tt := range tests {
		result := containsIgnoreCase(tt.text, tt.substr)
		if result != tt.expected {
			t.Errorf("containsIgnoreCase(%q, %q) = %v, expected %v",
				tt.text, tt.substr, result, tt.expected)
		}
	}
}
@@ -0,0 +1,489 @@
// Package models contains domain models for claude-mnemonic.
package models

import (
	"strings"
	"time"
)

// RelationType represents the type of relationship between observations.
type RelationType string

const (
	// RelationCauses means the source observation caused the target observation.
	// Example: "This architectural decision caused this bug"
	RelationCauses RelationType = "causes"
	// RelationFixes means the source observation fixes the target observation.
	// Example: "This bugfix addresses that discovered issue"
	RelationFixes RelationType = "fixes"
	// RelationSupersedes means the source observation supersedes the target observation.
	// Example: "This new approach replaces the old workaround"
	RelationSupersedes RelationType = "supersedes"
	// RelationDependsOn means the source observation depends on the target observation.
	// Example: "This feature relies on that architectural decision"
	RelationDependsOn RelationType = "depends_on"
	// RelationRelatesTo means the observations are related but have no causal relationship.
	// Example: "Both deal with authentication"
	RelationRelatesTo RelationType = "relates_to"
	// RelationEvolvesFrom means the source observation evolved from the target observation.
	// Example: "This refined pattern evolved from that initial discovery"
	RelationEvolvesFrom RelationType = "evolves_from"
)

// AllRelationTypes is the list of all valid relation types.
var AllRelationTypes = []RelationType{
	RelationCauses,
	RelationFixes,
	RelationSupersedes,
	RelationDependsOn,
	RelationRelatesTo,
	RelationEvolvesFrom,
}

// RelationDetectionSource indicates how a relationship was detected.
type RelationDetectionSource string

const (
	// DetectionSourceFileOverlap means the relationship was detected via shared file references.
	DetectionSourceFileOverlap RelationDetectionSource = "file_overlap"
	// DetectionSourceEmbeddingSimilarity means the relationship was detected via vector similarity.
	DetectionSourceEmbeddingSimilarity RelationDetectionSource = "embedding_similarity"
	// DetectionSourceTemporalProximity means the relationship was detected via close timestamps.
	DetectionSourceTemporalProximity RelationDetectionSource = "temporal_proximity"
	// DetectionSourceNarrativeMention means the relationship was detected via explicit mentions.
	DetectionSourceNarrativeMention RelationDetectionSource = "narrative_mention"
	// DetectionSourceConceptOverlap means the relationship was detected via shared concepts.
	DetectionSourceConceptOverlap RelationDetectionSource = "concept_overlap"
	// DetectionSourceTypeProgression means the relationship was detected via a type progression pattern.
	DetectionSourceTypeProgression RelationDetectionSource = "type_progression"
)

// ObservationRelation represents a directed relationship between two observations.
type ObservationRelation struct {
	ID              int64                   `db:"id" json:"id"`
	SourceID        int64                   `db:"source_id" json:"source_id"`
	TargetID        int64                   `db:"target_id" json:"target_id"`
	RelationType    RelationType            `db:"relation_type" json:"relation_type"`
	Confidence      float64                 `db:"confidence" json:"confidence"`
	DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"`
	Reason          string                  `db:"reason" json:"reason,omitempty"`
	CreatedAt       string                  `db:"created_at" json:"created_at"`
	CreatedAtEpoch  int64                   `db:"created_at_epoch" json:"created_at_epoch"`
}

// NewObservationRelation creates a new observation relation.
func NewObservationRelation(sourceID, targetID int64, relType RelationType, confidence float64, source RelationDetectionSource, reason string) *ObservationRelation {
	now := time.Now()
	return &ObservationRelation{
		SourceID:        sourceID,
		TargetID:        targetID,
		RelationType:    relType,
		Confidence:      confidence,
		DetectionSource: source,
		Reason:          reason,
		CreatedAt:       now.Format(time.RFC3339),
		CreatedAtEpoch:  now.UnixMilli(),
	}
}

// RelationDetectionResult contains the result of relation detection.
type RelationDetectionResult struct {
	SourceID        int64
	TargetID        int64
	RelationType    RelationType
	Confidence      float64
	DetectionSource RelationDetectionSource
	Reason          string
}

// DetectFileOverlapRelation checks if observations share file references and determines relationship type.
func DetectFileOverlapRelation(newer, older *Observation) *RelationDetectionResult {
	// Check for overlapping modified files
	newerModified := make(map[string]bool)
	for _, f := range newer.FilesModified {
		newerModified[f] = true
	}

	olderModified := make(map[string]bool)
	for _, f := range older.FilesModified {
		olderModified[f] = true
	}

	// Files modified by both
	var sharedModified []string
	for f := range newerModified {
		if olderModified[f] {
			sharedModified = append(sharedModified, f)
		}
	}

	// Files that newer reads which older modified
	var newerReadsOlderModified []string
	for _, f := range newer.FilesRead {
		if olderModified[f] {
			newerReadsOlderModified = append(newerReadsOlderModified, f)
		}
	}

	// Calculate overlap score
	overlap := len(sharedModified) + len(newerReadsOlderModified)
	if overlap == 0 {
		return nil
	}

	// Determine relationship type based on observation types and file overlap
	relType := RelationRelatesTo
	confidence := 0.5 + float64(overlap)*0.1 // Base 0.5, +0.1 per overlapping file

	// Type-based relationship inference
	switch {
	case newer.Type == ObsTypeBugfix && (older.Type == ObsTypeDecision || older.Type == ObsTypeFeature):
		relType = RelationFixes
		confidence += 0.2
	case newer.Type == ObsTypeRefactor && older.Type == ObsTypeDiscovery:
		relType = RelationEvolvesFrom
		confidence += 0.15
	case newer.Type == older.Type && len(sharedModified) > 0:
		relType = RelationSupersedes
		confidence += 0.1
	case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
		relType = RelationDependsOn
		confidence += 0.15
	}

	if confidence > 1.0 {
		confidence = 1.0
	}

	reason := buildFileOverlapReason(sharedModified, newerReadsOlderModified)

	return &RelationDetectionResult{
		SourceID:        newer.ID,
		TargetID:        older.ID,
		RelationType:    relType,
		Confidence:      confidence,
		DetectionSource: DetectionSourceFileOverlap,
		Reason:          reason,
	}
}

// buildFileOverlapReason creates a human-readable reason for file overlap relation.
func buildFileOverlapReason(shared, readsModified []string) string {
	parts := []string{}
	if len(shared) > 0 {
		parts = append(parts, "both modified: "+strings.Join(truncateList(shared, 3), ", "))
	}
	if len(readsModified) > 0 {
		parts = append(parts, "reads files modified by older: "+strings.Join(truncateList(readsModified, 3), ", "))
	}
	return strings.Join(parts, "; ")
}

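The confidence arithmetic in DetectFileOverlapRelation can be checked in isolation. The sketch below is an illustration, not part of the package: `overlapConfidence` is a hypothetical helper, and the type-based bonus is passed in as a plain number (0.2 for a bugfix over a decision or feature, 0.15 or 0.1 for the other cases, 0 when no case matches).

```go
package main

import "fmt"

// overlapConfidence mirrors the scoring in DetectFileOverlapRelation:
// base 0.5, plus 0.1 per overlapping file, plus a type-based bonus,
// capped at 1.0.
func overlapConfidence(overlappingFiles int, typeBonus float64) float64 {
	c := 0.5 + float64(overlappingFiles)*0.1 + typeBonus
	if c > 1.0 {
		c = 1.0
	}
	return c
}

func main() {
	// One shared file, no type match: plain "relates_to" confidence.
	fmt.Printf("%.2f\n", overlapConfidence(1, 0))
	// Two overlapping files where a bugfix touches an older feature.
	fmt.Printf("%.2f\n", overlapConfidence(2, 0.2))
	// Four overlapping files plus the bugfix bonus would exceed 1.0 and is capped.
	fmt.Printf("%.2f\n", overlapConfidence(4, 0.2))
}
```

Because each overlapping file adds 0.1, five or more overlapping files reach the cap even without a type bonus.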
// DetectConceptOverlapRelation checks if observations share concepts.
|
||||
func DetectConceptOverlapRelation(newer, older *Observation) *RelationDetectionResult {
|
||||
newerConcepts := make(map[string]bool)
|
||||
for _, c := range newer.Concepts {
|
||||
newerConcepts[c] = true
|
||||
}
|
||||
|
||||
var shared []string
|
||||
for _, c := range older.Concepts {
|
||||
if newerConcepts[c] {
|
||||
shared = append(shared, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(shared) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate confidence based on overlap ratio
|
||||
totalUniqueConcepts := len(newerConcepts)
|
||||
for _, c := range older.Concepts {
|
||||
if !newerConcepts[c] {
|
||||
totalUniqueConcepts++
|
||||
}
|
||||
}
|
||||
|
||||
overlapRatio := float64(len(shared)) / float64(totalUniqueConcepts)
|
||||
confidence := 0.3 + overlapRatio*0.5 // Base 0.3, scale with overlap
|
||||
|
||||
// Boost for important concepts
|
||||
for _, c := range shared {
|
||||
if isHighValueConcept(c) {
|
||||
confidence += 0.1
|
||||
}
|
||||
}
|
||||
if confidence > 1.0 {
|
||||
confidence = 1.0
|
||||
}
|
||||
|
||||
return &RelationDetectionResult{
|
||||
SourceID: newer.ID,
|
||||
TargetID: older.ID,
|
||||
RelationType: RelationRelatesTo,
|
||||
Confidence: confidence,
|
||||
DetectionSource: DetectionSourceConceptOverlap,
|
||||
Reason: "shared concepts: " + strings.Join(truncateList(shared, 5), ", "),
|
||||
}
|
||||
}
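The overlap ratio computed above is a Jaccard-style measure: shared concepts divided by the number of unique concepts across both observations. The sketch below reproduces that calculation on plain string slices; `conceptConfidence` and its `highValue` parameter are assumptions made for this example, not functions from the package.

```go
package main

import "fmt"

// conceptConfidence reproduces the ratio used above: |shared| / |union|,
// scaled into 0.3..0.8, plus 0.1 per shared high-value concept, capped at 1.0.
// The name and the highValue parameter are assumptions for this sketch.
func conceptConfidence(newer, older []string, highValue map[string]bool) float64 {
	newerSet := make(map[string]bool, len(newer))
	for _, c := range newer {
		newerSet[c] = true
	}
	union := len(newerSet)
	var shared []string
	for _, c := range older {
		if newerSet[c] {
			shared = append(shared, c)
		} else {
			union++
		}
	}
	if len(shared) == 0 {
		return 0 // the detector returns nil in this case
	}
	confidence := 0.3 + float64(len(shared))/float64(union)*0.5
	for _, c := range shared {
		if highValue[c] {
			confidence += 0.1
		}
	}
	if confidence > 1.0 {
		confidence = 1.0
	}
	return confidence
}

func main() {
	hv := map[string]bool{"security": true}
	c := conceptConfidence(
		[]string{"security", "auth"},
		[]string{"security", "validation"},
		hv,
	)
	// 1 shared of 3 unique concepts: 0.3 + (1/3)*0.5 + 0.1
	fmt.Printf("%.3f\n", c) // 0.567
}
```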
|
||||
|
||||
// isHighValueConcept returns true for concepts that strongly indicate relationships.
|
||||
func isHighValueConcept(concept string) bool {
|
||||
highValue := map[string]bool{
|
||||
"security": true,
|
||||
"architecture": true,
|
||||
"gotcha": true,
|
||||
"anti-pattern": true,
|
||||
"best-practice": true,
|
||||
"error-handling": true,
|
||||
}
|
||||
return highValue[concept]
|
||||
}
|
||||
|
||||
// DetectTypeProgressionRelation checks for natural type progressions.
|
||||
// Example: discovery -> decision -> feature -> bugfix
|
||||
func DetectTypeProgressionRelation(newer, older *Observation) *RelationDetectionResult {
|
||||
// Define natural type progressions
|
||||
progressions := map[ObservationType][]ObservationType{
|
||||
ObsTypeBugfix: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeDecision},
|
||||
ObsTypeFeature: {ObsTypeDiscovery, ObsTypeDecision},
|
||||
ObsTypeRefactor: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeBugfix},
|
||||
ObsTypeDecision: {ObsTypeDiscovery},
|
||||
ObsTypeChange: {ObsTypeDiscovery, ObsTypeDecision},
|
||||
}
|
||||
|
||||
validPredecessors, ok := progressions[newer.Type]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
isValidProgression := false
|
||||
for _, pred := range validPredecessors {
|
||||
if older.Type == pred {
|
||||
isValidProgression = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidProgression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine relationship type based on progression
|
||||
var relType RelationType
|
||||
confidence := 0.4 // default for valid progressions without a specific case below
|
||||
|
||||
switch {
|
||||
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeDiscovery:
|
||||
relType = RelationFixes
|
||||
confidence = 0.6
|
||||
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeFeature:
|
||||
relType = RelationFixes
|
||||
confidence = 0.5
|
||||
case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
|
||||
relType = RelationDependsOn
|
||||
confidence = 0.6
|
||||
case newer.Type == ObsTypeRefactor:
|
||||
relType = RelationEvolvesFrom
|
||||
confidence = 0.5
|
||||
default:
|
||||
relType = RelationRelatesTo
|
||||
}
|
||||
|
||||
return &RelationDetectionResult{
|
||||
SourceID: newer.ID,
|
||||
TargetID: older.ID,
|
||||
RelationType: relType,
|
||||
Confidence: confidence,
|
||||
DetectionSource: DetectionSourceTypeProgression,
|
||||
Reason: string(older.Type) + " -> " + string(newer.Type) + " progression",
|
||||
}
|
||||
}
|
||||
|
||||
// DetectTemporalProximityRelation checks if observations are temporally close (same session).
|
||||
func DetectTemporalProximityRelation(newer, older *Observation) *RelationDetectionResult {
|
||||
// Only relate observations from the same session
|
||||
if newer.SDKSessionID != older.SDKSessionID {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check temporal proximity (within 5 minutes)
|
||||
timeDiffMs := newer.CreatedAtEpoch - older.CreatedAtEpoch
|
||||
if timeDiffMs < 0 {
|
||||
timeDiffMs = -timeDiffMs
|
||||
}
|
||||
|
||||
fiveMinutesMs := int64(5 * 60 * 1000)
|
||||
if timeDiffMs > fiveMinutesMs {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate confidence based on temporal proximity
|
||||
// Closer = higher confidence
|
||||
proximityRatio := 1.0 - (float64(timeDiffMs) / float64(fiveMinutesMs))
|
||||
confidence := 0.3 + proximityRatio*0.4
|
||||
|
||||
return &RelationDetectionResult{
|
||||
SourceID: newer.ID,
|
||||
TargetID: older.ID,
|
||||
RelationType: RelationRelatesTo,
|
||||
Confidence: confidence,
|
||||
DetectionSource: DetectionSourceTemporalProximity,
|
||||
Reason: "same session, close timestamps",
|
||||
}
|
||||
}
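The temporal detector maps the gap between two timestamps linearly onto a confidence between 0.7 (identical timestamps) and 0.3 (exactly five minutes apart). A small sketch of that mapping, using a hypothetical `temporalConfidence` helper that takes raw epoch milliseconds instead of observations:

```go
package main

import "fmt"

const windowMs = int64(5 * 60 * 1000) // five-minute window, as in the detector

// temporalConfidence returns 0.7 for identical timestamps, falling linearly
// to 0.3 at the edge of the window. ok is false outside the window.
// The helper name and signature are assumptions for this sketch.
func temporalConfidence(newerMs, olderMs int64) (confidence float64, ok bool) {
	diff := newerMs - olderMs
	if diff < 0 {
		diff = -diff
	}
	if diff > windowMs {
		return 0, false
	}
	ratio := 1.0 - float64(diff)/float64(windowMs)
	return 0.3 + ratio*0.4, true
}

func main() {
	base := int64(1700000000000)
	for _, gap := range []int64{0, 60_000, 300_000, 600_000} {
		c, ok := temporalConfidence(base+gap, base)
		fmt.Printf("gap=%dms ok=%v confidence=%.2f\n", gap, ok, c)
	}
}
```

With these inputs the one-minute gap gives 0.62, the five-minute gap gives 0.30, and the ten-minute gap is rejected.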
|
||||
|
||||
// NarrativeMentionPatterns are patterns that indicate explicit relationships in narratives.
|
||||
var NarrativeMentionPatterns = []struct {
|
||||
Pattern string
|
||||
RelationType RelationType
|
||||
ConfBoost float64
|
||||
}{
|
||||
{" caused ", RelationCauses, 0.3},
|
||||
{" causes ", RelationCauses, 0.3},
|
||||
{" because of ", RelationCauses, 0.25},
|
||||
{" due to ", RelationCauses, 0.2},
|
||||
{" fixes ", RelationFixes, 0.3},
|
||||
{" fixed ", RelationFixes, 0.3},
|
||||
{" resolves ", RelationFixes, 0.3},
|
||||
{" addresses ", RelationFixes, 0.25},
|
||||
{" replaces ", RelationSupersedes, 0.3},
|
||||
{" supersedes ", RelationSupersedes, 0.35},
|
||||
{" instead of ", RelationSupersedes, 0.25},
|
||||
{" depends on ", RelationDependsOn, 0.3},
|
||||
{" requires ", RelationDependsOn, 0.25},
|
||||
{" builds on ", RelationDependsOn, 0.25},
|
||||
{" based on ", RelationDependsOn, 0.2},
|
||||
{" related to ", RelationRelatesTo, 0.2},
|
||||
{" similar to ", RelationRelatesTo, 0.2},
|
||||
{" evolved from ", RelationEvolvesFrom, 0.3},
|
||||
{" improved from ", RelationEvolvesFrom, 0.25},
|
||||
{" refined from ", RelationEvolvesFrom, 0.25},
|
||||
}
|
||||
|
||||
// DetectNarrativeMentionRelation checks if newer observation's narrative mentions relationship.
|
||||
func DetectNarrativeMentionRelation(newer, older *Observation) *RelationDetectionResult {
|
||||
if !newer.Narrative.Valid || newer.Narrative.String == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
narrative := strings.ToLower(newer.Narrative.String)
|
||||
|
||||
// Check for patterns
|
||||
for _, p := range NarrativeMentionPatterns {
|
||||
if strings.Contains(narrative, p.Pattern) {
|
||||
// Found a pattern - this is a potential relationship
|
||||
confidence := 0.4 + p.ConfBoost
|
||||
if confidence > 1.0 {
|
||||
confidence = 1.0
|
||||
}
|
||||
|
||||
return &RelationDetectionResult{
|
||||
SourceID: newer.ID,
|
||||
TargetID: older.ID,
|
||||
RelationType: p.RelationType,
|
||||
Confidence: confidence,
|
||||
DetectionSource: DetectionSourceNarrativeMention,
|
||||
Reason: "narrative contains '" + strings.TrimSpace(p.Pattern) + "' language",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
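Each pattern in NarrativeMentionPatterns is delimited by spaces, so a keyword at the very start or end of a narrative does not match with a plain `strings.Contains`. The sketch below shows the effect and one possible workaround, padding the lowercased narrative with spaces; `containsPhrase` is a hypothetical helper, not something the package defines.

```go
package main

import (
	"fmt"
	"strings"
)

// containsPhrase reports whether narrative contains the space-delimited
// pattern, treating the start and end of the text as word boundaries.
// This helper is hypothetical; the detector above calls strings.Contains
// on the lowercased narrative directly.
func containsPhrase(narrative, pattern string) bool {
	padded := " " + strings.ToLower(narrative) + " "
	return strings.Contains(padded, pattern)
}

func main() {
	narrative := "Fixed security issue in auth module"
	plain := strings.Contains(strings.ToLower(narrative), " fixed ")
	fmt.Println(plain)                                // false: nothing precedes "fixed"
	fmt.Println(containsPhrase(narrative, " fixed ")) // true
}
```

Padding still does not handle punctuation such as "fixes." or "fixed,"; a tokenizer or word-boundary regular expression would be needed for that.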
|
||||
|
||||
// DetectRelationsWithExisting checks a new observation against existing ones and returns detected relations.
|
||||
// This is the main entry point for relation detection.
|
||||
func DetectRelationsWithExisting(newer *Observation, existing []*Observation, minConfidence float64) []*RelationDetectionResult {
|
||||
var results []*RelationDetectionResult
|
||||
seen := make(map[int64]bool)
|
||||
|
||||
for _, older := range existing {
|
||||
// Skip self
|
||||
if older.ID == newer.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already superseded
|
||||
if older.IsSuperseded {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only compare within the same project, unless either observation is global
|
||||
if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
|
||||
continue
|
||||
}
|
||||
|
||||
// Run all detection methods and keep highest confidence result per target
|
||||
var bestResult *RelationDetectionResult
|
||||
|
||||
// 1. File overlap detection
|
||||
if result := DetectFileOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Concept overlap detection
|
||||
if result := DetectConceptOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Type progression detection
|
||||
if result := DetectTypeProgressionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Temporal proximity detection
|
||||
if result := DetectTemporalProximityRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||
// Only use temporal proximity if no better detection found
|
||||
if bestResult == nil {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Narrative mention detection (can upgrade relation type)
|
||||
if result := DetectNarrativeMentionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||
bestResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// Add best result if found and not already seen
|
||||
if bestResult != nil && !seen[older.ID] {
|
||||
results = append(results, bestResult)
|
||||
seen[older.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
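The selection rule in DetectRelationsWithExisting can be read as: discard results below the threshold, keep the highest-confidence result per target, and let temporal proximity fill the slot only when nothing else has. The sketch below isolates that rule with a simplified `candidate` type; the type, the `pickBest` function and the source strings are illustrative assumptions.

```go
package main

import "fmt"

// candidate is a simplified stand-in for RelationDetectionResult.
type candidate struct {
	source       string
	confidence   float64
	fallbackOnly bool // true for detectors that must not displace an existing result
}

// pickBest walks candidates in detector order and applies the same rules as
// the loop above: drop anything under minConfidence, let any candidate fill
// an empty slot, and let only non-fallback candidates replace the current
// best when their confidence is strictly higher.
func pickBest(cands []candidate, minConfidence float64) (best candidate, found bool) {
	for _, c := range cands {
		if c.confidence < minConfidence {
			continue
		}
		switch {
		case !found:
			best, found = c, true
		case !c.fallbackOnly && c.confidence > best.confidence:
			best = c
		}
	}
	return best, found
}

func main() {
	cands := []candidate{
		{"file_overlap", 0.60, false},
		{"concept_overlap", 0.65, false},
		{"temporal_proximity", 0.70, true}, // higher, but only a fallback
		{"narrative_mention", 0.35, false}, // below the threshold
	}
	best, ok := pickBest(cands, 0.4)
	fmt.Println(best.source, ok) // concept_overlap true
}
```

Note that the signals are not combined: a pair matched by three detectors gets the confidence of the strongest one, not a sum.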
|
||||
|
||||
// truncateList truncates a list to maxLen items.
|
||||
func truncateList(items []string, maxLen int) []string {
|
||||
if len(items) <= maxLen {
|
||||
return items
|
||||
}
|
||||
result := items[:maxLen:maxLen] // limit capacity so append allocates instead of overwriting items[maxLen]
|
||||
return append(result, "...")
|
||||
}
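In Go, appending to a sub-slice that still has spare capacity writes into the original backing array, which matters for any helper that appends a marker to a prefix of its input. The sketch below demonstrates the aliasing and a copy-based alternative; `truncateCopy` is a hypothetical name used only for this example.

```go
package main

import "fmt"

// truncateCopy returns at most maxLen items followed by "...", always in a
// fresh slice when truncation happens, so the caller's data is never modified.
// The name is illustrative and not part of the package.
func truncateCopy(items []string, maxLen int) []string {
	if len(items) <= maxLen {
		return items
	}
	out := make([]string, 0, maxLen+1)
	out = append(out, items[:maxLen]...)
	return append(out, "...")
}

func main() {
	files := []string{"a.go", "b.go", "c.go", "d.go"}

	// Appending to a sub-slice with spare capacity overwrites files[2].
	aliased := append(files[:2], "...")
	fmt.Println(aliased, files[2]) // [a.go b.go ...] ...

	files[2] = "c.go" // restore the original value
	safe := truncateCopy(files, 2)
	fmt.Println(safe, files[2]) // [a.go b.go ...] c.go
}
```

A full slice expression such as `items[:maxLen:maxLen]` achieves the same protection without an explicit copy, because it leaves no spare capacity for `append` to reuse.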
|
||||
|
||||
// RelationWithDetails contains a relation with its observation details.
|
||||
type RelationWithDetails struct {
|
||||
Relation *ObservationRelation `json:"relation"`
|
||||
SourceTitle string `json:"source_title"`
|
||||
TargetTitle string `json:"target_title"`
|
||||
SourceType ObservationType `json:"source_type"`
|
||||
TargetType ObservationType `json:"target_type"`
|
||||
}
|
||||
|
||||
// RelationGraph represents a graph of related observations.
|
||||
type RelationGraph struct {
|
||||
CenterID int64 `json:"center_id"`
|
||||
Relations []*RelationWithDetails `json:"relations"`
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetectFileOverlapRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newer *Observation
|
||||
older *Observation
|
||||
wantRelation bool
|
||||
wantRelType RelationType
|
||||
wantMinConfid float64
|
||||
}{
|
||||
{
|
||||
name: "no file overlap",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
FilesModified: []string{"file1.go", "file2.go"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
FilesModified: []string{"file3.go", "file4.go"},
|
||||
},
|
||||
wantRelation: false,
|
||||
},
|
||||
{
|
||||
name: "shared modified files",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Type: ObsTypeRefactor,
|
||||
FilesModified: []string{"shared.go", "file2.go"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Type: ObsTypeRefactor,
|
||||
FilesModified: []string{"shared.go", "file4.go"},
|
||||
},
|
||||
wantRelation: true,
|
||||
wantRelType: RelationSupersedes,
|
||||
wantMinConfid: 0.5,
|
||||
},
|
||||
{
|
||||
name: "bugfix on feature file",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Type: ObsTypeBugfix,
|
||||
FilesModified: []string{"feature.go"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Type: ObsTypeFeature,
|
||||
FilesModified: []string{"feature.go"},
|
||||
},
|
||||
wantRelation: true,
|
||||
wantRelType: RelationFixes,
|
||||
wantMinConfid: 0.6,
|
||||
},
|
||||
{
|
||||
name: "newer reads older modified",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Type: ObsTypeChange,
|
||||
FilesRead: []string{"dep.go"},
|
||||
FilesModified: []string{"caller.go"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Type: ObsTypeDecision,
|
||||
FilesModified: []string{"dep.go"},
|
||||
},
|
||||
wantRelation: true,
|
||||
wantMinConfid: 0.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DetectFileOverlapRelation(tt.newer, tt.older)
|
||||
|
||||
if tt.wantRelation {
|
||||
if result == nil {
|
||||
t.Fatal("expected relation, got nil")
|
||||
}
|
||||
if tt.wantRelType != "" && result.RelationType != tt.wantRelType {
|
||||
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||
}
|
||||
if result.Confidence < tt.wantMinConfid {
|
||||
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
|
||||
}
|
||||
if result.DetectionSource != DetectionSourceFileOverlap {
|
||||
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceFileOverlap)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("expected no relation, got %+v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectConceptOverlapRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newer *Observation
|
||||
older *Observation
|
||||
wantRelation bool
|
||||
wantMinConfid float64
|
||||
}{
|
||||
{
|
||||
name: "no concept overlap",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Concepts: []string{"auth", "api"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Concepts: []string{"database", "caching"},
|
||||
},
|
||||
wantRelation: false,
|
||||
},
|
||||
{
|
||||
name: "shared concepts",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Concepts: []string{"security", "auth"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Concepts: []string{"security", "validation"},
|
||||
},
|
||||
wantRelation: true,
|
||||
wantMinConfid: 0.4, // security is a high-value concept
|
||||
},
|
||||
{
|
||||
name: "multiple shared concepts",
|
||||
newer: &Observation{
|
||||
ID: 1,
|
||||
Concepts: []string{"auth", "api", "validation"},
|
||||
},
|
||||
older: &Observation{
|
||||
ID: 2,
|
||||
Concepts: []string{"auth", "api", "database"},
|
||||
},
|
||||
wantRelation: true,
|
||||
wantMinConfid: 0.5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DetectConceptOverlapRelation(tt.newer, tt.older)
|
||||
|
||||
if tt.wantRelation {
|
||||
if result == nil {
|
||||
t.Fatal("expected relation, got nil")
|
||||
}
|
||||
if result.Confidence < tt.wantMinConfid {
|
||||
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
|
||||
}
|
||||
if result.DetectionSource != DetectionSourceConceptOverlap {
|
||||
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceConceptOverlap)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("expected no relation, got %+v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectTypeProgressionRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newerType ObservationType
|
||||
olderType ObservationType
|
||||
wantRelation bool
|
||||
wantRelType RelationType
|
||||
}{
|
||||
{
|
||||
name: "bugfix fixes discovery",
|
||||
newerType: ObsTypeBugfix,
|
||||
olderType: ObsTypeDiscovery,
|
||||
wantRelation: true,
|
||||
wantRelType: RelationFixes,
|
||||
},
|
||||
{
|
||||
name: "bugfix fixes feature",
|
||||
newerType: ObsTypeBugfix,
|
||||
olderType: ObsTypeFeature,
|
||||
wantRelation: true,
|
||||
wantRelType: RelationFixes,
|
||||
},
|
||||
{
|
||||
name: "feature depends on decision",
|
||||
newerType: ObsTypeFeature,
|
||||
olderType: ObsTypeDecision,
|
||||
wantRelation: true,
|
||||
wantRelType: RelationDependsOn,
|
||||
},
|
||||
{
|
||||
name: "refactor evolves from discovery",
|
||||
newerType: ObsTypeRefactor,
|
||||
olderType: ObsTypeDiscovery,
|
||||
wantRelation: true,
|
||||
wantRelType: RelationEvolvesFrom,
|
||||
},
|
||||
{
|
||||
name: "no progression discovery to bugfix",
|
||||
newerType: ObsTypeDiscovery,
|
||||
olderType: ObsTypeBugfix,
|
||||
wantRelation: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
newer := &Observation{ID: 1, Type: tt.newerType}
|
||||
older := &Observation{ID: 2, Type: tt.olderType}
|
||||
result := DetectTypeProgressionRelation(newer, older)
|
||||
|
||||
if tt.wantRelation {
|
||||
if result == nil {
|
||||
t.Fatal("expected relation, got nil")
|
||||
}
|
||||
if result.RelationType != tt.wantRelType {
|
||||
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||
}
|
||||
if result.DetectionSource != DetectionSourceTypeProgression {
|
||||
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTypeProgression)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("expected no relation, got %+v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectTemporalProximityRelation(t *testing.T) {
|
||||
baseTime := int64(1700000000000) // some base epoch ms
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
newerSession string
|
||||
olderSession string
|
||||
newerTime int64
|
||||
olderTime int64
|
||||
wantRelation bool
|
||||
}{
|
||||
{
|
||||
name: "same session close time",
|
||||
newerSession: "session-1",
|
||||
olderSession: "session-1",
|
||||
newerTime: baseTime + 60000, // 1 minute later
|
||||
olderTime: baseTime,
|
||||
wantRelation: true,
|
||||
},
|
||||
{
|
||||
name: "same session far apart",
|
||||
newerSession: "session-1",
|
||||
olderSession: "session-1",
|
||||
newerTime: baseTime + 600000, // 10 minutes later
|
||||
olderTime: baseTime,
|
||||
wantRelation: false, // > 5 minutes
|
||||
},
|
||||
{
|
||||
name: "different sessions close time",
|
||||
newerSession: "session-1",
|
||||
olderSession: "session-2",
|
||||
newerTime: baseTime + 30000,
|
||||
olderTime: baseTime,
|
||||
wantRelation: false, // different sessions
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
newer := &Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: tt.newerSession,
|
||||
CreatedAtEpoch: tt.newerTime,
|
||||
}
|
||||
older := &Observation{
|
||||
ID: 2,
|
||||
SDKSessionID: tt.olderSession,
|
||||
CreatedAtEpoch: tt.olderTime,
|
||||
}
|
||||
result := DetectTemporalProximityRelation(newer, older)
|
||||
|
||||
if tt.wantRelation {
|
||||
if result == nil {
|
||||
t.Fatal("expected relation, got nil")
|
||||
}
|
||||
if result.DetectionSource != DetectionSourceTemporalProximity {
|
||||
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTemporalProximity)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("expected no relation, got %+v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectNarrativeMentionRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
narrative string
|
||||
wantRelation bool
|
||||
wantRelType RelationType
|
||||
}{
|
||||
{
|
||||
name: "fixes language",
|
||||
narrative: "This change fixes the issue with authentication",
|
||||
wantRelation: true,
|
||||
wantRelType: RelationFixes,
|
||||
},
|
||||
{
|
||||
name: "causes language",
|
||||
narrative: "This decision caused unexpected side effects",
|
||||
wantRelation: true,
|
||||
wantRelType: RelationCauses,
|
||||
},
|
||||
{
|
||||
name: "supersedes language",
|
||||
narrative: "This approach supersedes the previous workaround",
|
||||
wantRelation: true,
|
||||
wantRelType: RelationSupersedes,
|
||||
},
|
||||
{
|
||||
name: "depends on language",
|
||||
narrative: "This feature depends on the authentication module",
|
||||
wantRelation: true,
|
||||
wantRelType: RelationDependsOn,
|
||||
},
|
||||
{
|
||||
name: "no relationship language",
|
||||
narrative: "Added new feature for user management",
|
||||
wantRelation: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
newer := &Observation{
|
||||
ID: 1,
|
||||
Narrative: sql.NullString{String: tt.narrative, Valid: true},
|
||||
}
|
||||
older := &Observation{ID: 2}
|
||||
result := DetectNarrativeMentionRelation(newer, older)
|
||||
|
||||
if tt.wantRelation {
|
||||
if result == nil {
|
||||
t.Fatal("expected relation, got nil")
|
||||
}
|
||||
if result.RelationType != tt.wantRelType {
|
||||
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||
}
|
||||
if result.DetectionSource != DetectionSourceNarrativeMention {
|
||||
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceNarrativeMention)
|
||||
}
|
||||
} else {
|
||||
if result != nil {
|
||||
t.Errorf("expected no relation, got %+v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectRelationsWithExisting(t *testing.T) {
|
||||
newer := &Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: "session-1",
|
||||
Project: "test-project",
|
||||
Type: ObsTypeBugfix,
|
||||
FilesModified: []string{"auth.go"},
|
||||
Concepts: []string{"security", "auth"},
|
||||
Narrative: sql.NullString{String: "Fixed security issue in auth module", Valid: true},
|
||||
}
|
||||
|
||||
existing := []*Observation{
|
||||
{
|
||||
ID: 2,
|
||||
SDKSessionID: "session-1",
|
||||
Project: "test-project",
|
||||
Type: ObsTypeDiscovery,
|
||||
FilesModified: []string{"auth.go"},
|
||||
Concepts: []string{"security"},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
SDKSessionID: "session-2",
|
||||
Project: "test-project",
|
||||
Type: ObsTypeFeature,
|
||||
FilesModified: []string{"other.go"},
|
||||
Concepts: []string{"api"},
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
SDKSessionID: "session-1",
|
||||
Project: "other-project", // different project
|
||||
Type: ObsTypeDiscovery,
|
||||
},
|
||||
}
|
||||
|
||||
results := DetectRelationsWithExisting(newer, existing, 0.4)
|
||||
|
||||
// Should find relation with observation 2 (file overlap + concept overlap + type progression)
|
||||
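// Observation 3 shares no files or concepts, but the feature -> bugfix type progression still yields a "fixes" relation at 0.5, which clears the 0.4 threshold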
|
||||
// Should not find relation with observation 4 (different project)
|
||||
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected at least one relation")
|
||||
}
|
||||
|
||||
// Check that we found relation with observation 2
|
||||
foundObs2 := false
|
||||
for _, r := range results {
|
||||
if r.TargetID == 2 {
|
||||
foundObs2 = true
|
||||
// The strongest single detector should clear 0.5 (the best match is kept; signals are not summed)
|
||||
if r.Confidence < 0.5 {
|
||||
t.Errorf("expected higher confidence for obs 2, got %v", r.Confidence)
|
||||
}
|
||||
}
|
||||
// Should not find relation with obs 4
|
||||
if r.TargetID == 4 {
|
||||
t.Error("should not find relation with different project")
|
||||
}
|
||||
}
|
||||
|
||||
if !foundObs2 {
|
||||
t.Error("expected to find relation with observation 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewObservationRelation(t *testing.T) {
|
||||
rel := NewObservationRelation(1, 2, RelationFixes, 0.8, DetectionSourceFileOverlap, "test reason")
|
||||
|
||||
if rel.SourceID != 1 {
|
||||
t.Errorf("SourceID = %v, want 1", rel.SourceID)
|
||||
}
|
||||
if rel.TargetID != 2 {
|
||||
t.Errorf("TargetID = %v, want 2", rel.TargetID)
|
||||
}
|
||||
if rel.RelationType != RelationFixes {
|
||||
t.Errorf("RelationType = %v, want %v", rel.RelationType, RelationFixes)
|
||||
}
|
||||
if rel.Confidence != 0.8 {
|
||||
t.Errorf("Confidence = %v, want 0.8", rel.Confidence)
|
||||
}
|
||||
if rel.DetectionSource != DetectionSourceFileOverlap {
|
||||
t.Errorf("DetectionSource = %v, want %v", rel.DetectionSource, DetectionSourceFileOverlap)
|
||||
}
|
||||
if rel.Reason != "test reason" {
|
||||
t.Errorf("Reason = %v, want 'test reason'", rel.Reason)
|
||||
}
|
||||
if rel.CreatedAt == "" {
|
||||
t.Error("CreatedAt should be set")
|
||||
}
|
||||
if rel.CreatedAtEpoch == 0 {
|
||||
t.Error("CreatedAtEpoch should be set")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
// ConceptWeight represents a configurable weight for a concept.
|
||||
type ConceptWeight struct {
|
||||
Concept string `db:"concept" json:"concept"`
|
||||
Weight float64 `db:"weight" json:"weight"`
|
||||
UpdatedAt string `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// UserFeedbackType represents the type of user feedback.
|
||||
type UserFeedbackType int
|
||||
|
||||
const (
|
||||
// FeedbackNegative represents a thumbs down.
|
||||
FeedbackNegative UserFeedbackType = -1
|
||||
// FeedbackNeutral represents no feedback.
|
||||
FeedbackNeutral UserFeedbackType = 0
|
||||
// FeedbackPositive represents a thumbs up.
|
||||
FeedbackPositive UserFeedbackType = 1
|
||||
)
|
||||
|
||||
// DefaultConceptWeights contains the default weights for concepts.
|
||||
// Higher weights indicate more important concepts.
|
||||
var DefaultConceptWeights = map[string]float64{
|
||||
// Critical concepts (0.25-0.30)
|
||||
"security": 0.30, // Security issues are most critical
|
||||
|
||||
// High importance (0.20-0.25)
|
||||
"gotcha": 0.25, // Gotchas prevent future mistakes
|
||||
"best-practice": 0.20, // Best practices guide development
|
||||
"anti-pattern": 0.20, // Anti-patterns prevent bad code
|
||||
|
||||
// Medium importance (0.10-0.15)
|
||||
"architecture": 0.15, // Architectural decisions have lasting impact
|
||||
"performance": 0.15, // Performance patterns matter for scale
|
||||
"error-handling": 0.15, // Error handling prevents failures
|
||||
"pattern": 0.10, // General patterns are useful
|
||||
"testing": 0.10, // Testing knowledge helps quality
|
||||
"debugging": 0.10, // Debugging tips save time
|
||||
"problem-solution": 0.10, // Problem-solution pairs are actionable
|
||||
"trade-off": 0.10, // Trade-offs inform decisions
|
||||
|
||||
// Lower importance (0.05)
|
||||
"workflow": 0.05, // Workflow optimizations are nice-to-have
|
||||
"tooling": 0.05, // Tooling preferences are subjective
|
||||
"how-it-works": 0.05, // Understanding is foundational but less urgent
|
||||
"why-it-exists": 0.05, // Context is helpful but less actionable
|
||||
"what-changed": 0.05, // Changes are informational
|
||||
}
|
||||
|
||||
// TypeBaseScores contains the base importance multipliers for each observation type.
|
||||
// These are multiplied with the core score to weight different observation types.
|
||||
var TypeBaseScores = map[ObservationType]float64{
|
||||
ObsTypeBugfix: 1.3, // Bugfixes are valuable - prevent regressions
|
||||
ObsTypeFeature: 1.2, // New features expand capabilities
|
||||
ObsTypeDiscovery: 1.1, // Discoveries inform future work
|
||||
ObsTypeDecision: 1.1, // Architectural decisions guide development
|
||||
ObsTypeRefactor: 1.0, // Refactoring is neutral
|
||||
ObsTypeChange: 0.9, // Minor changes are slightly less important
|
||||
}
|
||||
|
||||
// ScoringConfig contains all scoring weights and parameters.
|
||||
type ScoringConfig struct {
|
||||
// RecencyHalfLifeDays is the number of days for the importance score to halve.
|
||||
// With 7 days, a 7-day old observation has 50% of a new observation's recency score.
|
||||
RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
|
||||
|
||||
// FeedbackWeight scales the user feedback contribution to final score.
|
||||
// With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30.
|
||||
FeedbackWeight float64 `json:"feedback_weight"`
|
||||
|
||||
// ConceptWeight scales the concept boost contribution.
|
||||
// The sum of matching concept weights is multiplied by this.
|
||||
ConceptWeight float64 `json:"concept_weight"`
|
||||
|
||||
// RetrievalWeight scales the retrieval boost contribution.
|
||||
// Popular observations get a logarithmic bonus.
|
||||
RetrievalWeight float64 `json:"retrieval_weight"`
|
||||
|
||||
// ConceptWeights maps concept names to their importance weights.
|
||||
ConceptWeights map[string]float64 `json:"concept_weights"`
|
||||
|
||||
// MinScore is the minimum allowed importance score.
|
||||
// Prevents observations from completely disappearing.
|
||||
MinScore float64 `json:"min_score"`
|
||||
}
|
||||
|
||||
// DefaultScoringConfig returns the default scoring configuration.
|
||||
func DefaultScoringConfig() *ScoringConfig {
|
||||
conceptWeights := make(map[string]float64, len(DefaultConceptWeights))
|
||||
for k, v := range DefaultConceptWeights {
|
||||
conceptWeights[k] = v
|
||||
}
|
||||
|
||||
return &ScoringConfig{
|
||||
RecencyHalfLifeDays: 7.0, // Score halves every 7 days
|
||||
FeedbackWeight: 0.30, // Feedback has moderate impact
|
||||
ConceptWeight: 0.20, // Concept weights have smaller impact
|
||||
RetrievalWeight: 0.15, // Retrieval has smallest impact
|
||||
ConceptWeights: conceptWeights,
|
||||
MinScore: 0.01, // Never completely disappear
|
||||
}
|
||||
}
|
||||
|
||||
// TypeBaseScore returns the base weight for an observation type.
|
||||
func TypeBaseScore(t ObservationType) float64 {
|
||||
if score, ok := TypeBaseScores[t]; ok {
|
||||
return score
|
||||
}
|
||||
return 1.0 // Default for unknown types
|
||||
}
|
||||
Executable
+121
@@ -0,0 +1,121 @@
|
||||
#!/bin/bash
|
||||
# Download BGE-small-en-v1.5 model for embedding
|
||||
# Usage: ./download-bge-model.sh [--force]
|
||||
# Use --force to re-download even if files exist
|
||||
|
||||
set -e
|
||||
|
||||
MODEL_NAME="bge-small-en-v1.5"
|
||||
MODEL_REPO="BAAI/bge-small-en-v1.5"
|
||||
ASSETS_DIR="internal/embedding/assets"
|
||||
VERSION_FILE="${ASSETS_DIR}/.model_version"
|
||||
FORCE_DOWNLOAD=false
|
||||
|
||||
# Check for --force flag
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" = "--force" ]; then
|
||||
FORCE_DOWNLOAD=true
|
||||
fi
|
||||
done
|
||||
|
||||
# Temporary directory for downloads
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
trap 'rm -rf "${TEMP_DIR}"' EXIT
|
||||
|
||||
# Check if model already exists
|
||||
model_exists() {
|
||||
[ -f "${ASSETS_DIR}/model.onnx" ] && [ -f "${ASSETS_DIR}/tokenizer.json" ]
|
||||
}
|
||||
|
||||
# Get installed version
|
||||
get_installed_version() {
|
||||
if [ -f "$VERSION_FILE" ]; then
|
||||
cat "$VERSION_FILE"
|
||||
else
|
||||
echo ""
|
||||
fi
|
||||
}
|
||||
|
||||
# Write version file
|
||||
write_version_file() {
|
||||
echo "${MODEL_NAME}" > "$VERSION_FILE"
|
||||
}
|
||||
|
||||
download_model() {
|
||||
echo "Downloading ${MODEL_NAME} from Hugging Face..."
|
||||
|
||||
# Create assets directory
|
||||
mkdir -p "${ASSETS_DIR}"
|
||||
|
||||
# Download ONNX model
|
||||
# BGE models have ONNX exports available in the repo
|
||||
echo "Downloading ONNX model..."
|
||||
curl -fsSL \
|
||||
"https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \
|
||||
-o "${TEMP_DIR}/model.onnx"
|
||||
|
||||
# Download tokenizer.json
|
||||
echo "Downloading tokenizer..."
|
||||
curl -fsSL \
|
||||
"https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \
|
||||
-o "${TEMP_DIR}/tokenizer.json"
|
||||
|
||||
# Verify files exist and have content
|
||||
if [ ! -s "${TEMP_DIR}/model.onnx" ]; then
|
||||
echo "Error: Failed to download model.onnx or file is empty" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -s "${TEMP_DIR}/tokenizer.json" ]; then
|
||||
echo "Error: Failed to download tokenizer.json or file is empty"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Move to assets directory (backup old files first)
|
||||
if [ -f "${ASSETS_DIR}/model.onnx" ]; then
|
||||
mv "${ASSETS_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx.bak"
|
||||
fi
|
||||
if [ -f "${ASSETS_DIR}/tokenizer.json" ]; then
|
||||
mv "${ASSETS_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json.bak"
|
||||
fi
|
||||
|
||||
mv "${TEMP_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx"
|
||||
mv "${TEMP_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json"
|
||||
|
||||
# Remove backups on success
|
||||
rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak"
|
||||
|
||||
# Write version file
|
||||
write_version_file
|
||||
|
||||
echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)"
|
||||
echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)"
|
||||
}
|
||||
|
||||
echo "BGE Model Downloader - ${MODEL_NAME}"
|
||||
echo "=================================="
|
||||
|
||||
need_download=false
|
||||
reason=""
|
||||
|
||||
if [ "$FORCE_DOWNLOAD" = true ]; then
|
||||
need_download=true
|
||||
reason="forced"
|
||||
elif ! model_exists; then
|
||||
need_download=true
|
||||
reason="not found"
|
||||
elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then
|
||||
need_download=true
|
||||
reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})"
|
||||
fi
|
||||
|
||||
if [ "$need_download" = true ]; then
|
||||
if [ -n "$reason" ] && [ "$reason" != "not found" ]; then
|
||||
echo "Re-downloading: ${reason}"
|
||||
fi
|
||||
download_model
|
||||
echo "Done! ${MODEL_NAME} installed successfully."
|
||||
else
|
||||
echo "Model ${MODEL_NAME} already exists, skipping download."
|
||||
echo "Use --force to re-download."
|
||||
fi
|
||||
Generated
+108
-2
@@ -1,13 +1,15 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "0ddacaa-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "0ddacaa-dirty",
|
||||
"dependencies": {
|
||||
"vis-data": "^7.1.9",
|
||||
"vis-network": "^9.1.9",
|
||||
"vue": "^3.5.13"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -81,6 +83,19 @@
|
||||
"node": ">=6.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@egjs/hammerjs": {
|
||||
"version": "2.0.17",
|
||||
"resolved": "https://registry.npmjs.org/@egjs/hammerjs/-/hammerjs-2.0.17.tgz",
|
||||
"integrity": "sha512-XQsZgjm2EcVUiZQf11UBJQfmZeEmOW8DpI1gsFeln6w0ae0ii4dMQEQ0kjl6DspdWX1aGY1/loyXnP0JS06e/A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/hammerjs": "^2.0.36"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@esbuild/aix-ppc64": {
|
||||
"version": "0.25.12",
|
||||
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
|
||||
@@ -924,6 +939,13 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/hammerjs": {
|
||||
"version": "2.0.46",
|
||||
"resolved": "https://registry.npmjs.org/@types/hammerjs/-/hammerjs-2.0.46.tgz",
|
||||
"integrity": "sha512-ynRvcq6wvqexJ9brDMS4BnBLzmr0e14d6ZJTEShTBWKymQiHwlAyGu0ZPEFI2Fh1U53F7tN9ufClWM5KvqkKOw==",
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "22.19.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.2.tgz",
|
||||
@@ -1352,6 +1374,19 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/component-emitter": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/component-emitter/-/component-emitter-2.0.0.tgz",
|
||||
"integrity": "sha512-4m5s3Me2xxlVKG9PkZpQqHQR7bgpnN7joDMJ4yvVkVXngjoITG76IaZmzmywSeRTeTpc6N6r3H3+KyUurV8OYw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/cssesc": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz",
|
||||
@@ -1669,6 +1704,13 @@
|
||||
"jiti": "bin/jiti.js"
|
||||
}
|
||||
},
|
||||
"node_modules/keycharm": {
|
||||
"version": "0.4.0",
|
||||
"resolved": "https://registry.npmjs.org/keycharm/-/keycharm-0.4.0.tgz",
|
||||
"integrity": "sha512-TyQTtsabOVv3MeOpR92sIKk/br9wxS+zGj4BG7CR8YbK4jM3tyIBaF0zhzeBUMx36/Q/iQLOKKOT+3jOQtemRQ==",
|
||||
"license": "(Apache-2.0 OR MIT)",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/lilconfig": {
|
||||
"version": "3.1.3",
|
||||
"resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz",
|
||||
@@ -2412,6 +2454,70 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "11.1.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz",
|
||||
"integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==",
|
||||
"funding": [
|
||||
"https://github.com/sponsors/broofa",
|
||||
"https://github.com/sponsors/ctavan"
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"uuid": "dist/esm/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/vis-data": {
|
||||
"version": "7.1.10",
|
||||
"resolved": "https://registry.npmjs.org/vis-data/-/vis-data-7.1.10.tgz",
|
||||
"integrity": "sha512-23juM9tdCaHTX5vyIQ7XBzsfZU0Hny+gSTwniLrfFcmw9DOm7pi3+h9iEBsoZMp5rX6KNqWwc1MF0fkAmWVuoQ==",
|
||||
"license": "(Apache-2.0 OR MIT)",
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/visjs"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
|
||||
"vis-util": "^5.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/vis-network": {
|
||||
"version": "9.1.13",
|
||||
"resolved": "https://registry.npmjs.org/vis-network/-/vis-network-9.1.13.tgz",
|
||||
"integrity": "sha512-HLeHd5KZS92qzO1kC59qMh1/FWAZxMUEwUWBwDMoj6RKj/Ajkrgy/heEYo0Zc8SZNQ2J+u6omvK2+a28GX1QuQ==",
|
||||
"license": "(Apache-2.0 OR MIT)",
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/visjs"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@egjs/hammerjs": "^2.0.0",
|
||||
"component-emitter": "^1.3.0 || ^2.0.0",
|
||||
"keycharm": "^0.2.0 || ^0.3.0 || ^0.4.0",
|
||||
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
|
||||
"vis-data": "^6.3.0 || ^7.0.0",
|
||||
"vis-util": "^5.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/vis-util": {
|
||||
"version": "5.0.7",
|
||||
"resolved": "https://registry.npmjs.org/vis-util/-/vis-util-5.0.7.tgz",
|
||||
"integrity": "sha512-E3L03G3+trvc/X4LXvBfih3YIHcKS2WrP0XTdZefr6W6Qi/2nNCqZfe4JFfJU6DcQLm6Gxqj2Pfl+02859oL5A==",
|
||||
"license": "(Apache-2.0 OR MIT)",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/visjs"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@egjs/hammerjs": "^2.0.0",
|
||||
"component-emitter": "^1.3.0 || ^2.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "6.4.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
||||
|
||||
+3
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "0ddacaa-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
@@ -10,6 +10,8 @@
|
||||
"type-check": "vue-tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"vis-data": "^7.1.9",
|
||||
"vis-network": "^9.1.9",
|
||||
"vue": "^3.5.13"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
+1
-1
@@ -30,7 +30,7 @@ const {
|
||||
// Pass currentProject ref to useStats for project-specific retrieval stats
|
||||
const { stats } = useStats(currentProject)
|
||||
|
||||
// Note: Timeline refresh is handled by useTimeline's SSE watcher
|
||||
// Note: Feedback is handled directly in ObservationCard component
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
@@ -1,17 +1,66 @@
|
||||
<script setup lang="ts">
|
||||
import type { ObservationFeedItem } from '@/types'
|
||||
import type { ObservationFeedItem, RelationWithDetails } from '@/types'
|
||||
import { TYPE_CONFIG, CONCEPT_CONFIG } from '@/types/observation'
|
||||
import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
|
||||
import { formatRelativeTime } from '@/utils/formatters'
|
||||
import { fetchObservationRelations } from '@/utils/api'
|
||||
import Card from './Card.vue'
|
||||
import IconBox from './IconBox.vue'
|
||||
import Badge from './Badge.vue'
|
||||
import { computed } from 'vue'
|
||||
import RelationGraph from './RelationGraph.vue'
|
||||
import { computed, ref, onMounted } from 'vue'
|
||||
|
||||
const props = defineProps<{
|
||||
observation: ObservationFeedItem
|
||||
highlight?: boolean
|
||||
showFeedback?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
navigateToObservation: [id: number]
|
||||
}>()
|
||||
|
||||
// Local feedback and score state (optimistic updates)
|
||||
const localFeedback = ref<number | null>(null)
|
||||
const localScore = ref<number | null>(null)
|
||||
const isSubmitting = ref(false)
|
||||
|
||||
const currentFeedback = computed(() =>
|
||||
localFeedback.value !== null ? localFeedback.value : (props.observation.user_feedback || 0)
|
||||
)
|
||||
|
||||
const currentScore = computed(() =>
|
||||
localScore.value !== null ? localScore.value : (props.observation.importance_score || 1)
|
||||
)
|
||||
|
||||
const submitFeedback = async (value: number) => {
  if (isSubmitting.value) return

  // Toggle off if clicking same button
  const newValue = currentFeedback.value === value ? 0 : value

  // Remember the previous local state so a failed request can be rolled back
  const previousFeedback = localFeedback.value
  localFeedback.value = newValue
  isSubmitting.value = true

  try {
    const response = await fetch(`/api/observations/${props.observation.id}/feedback`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ feedback: newValue })
    })
    if (response.ok) {
      const data = await response.json()
      if (data.score !== undefined) {
        localScore.value = data.score
      }
    } else {
      // The server rejected the vote: undo the optimistic update
      localFeedback.value = previousFeedback
      console.error('Error submitting feedback: HTTP status', response.status)
    }
  } catch (error) {
    // The request failed: undo the optimistic update
    localFeedback.value = previousFeedback
    console.error('Error submitting feedback:', error)
  } finally {
    isSubmitting.value = false
  }
}
|
||||
const config = computed(() => TYPE_CONFIG[props.observation.type] || TYPE_CONFIG.change)
|
||||
|
||||
const concepts = computed(() => {
|
||||
@@ -40,6 +89,60 @@ const filesModified = computed(() => {
|
||||
|
||||
const hasFiles = computed(() => filesRead.value.length > 0 || filesModified.value.length > 0)
|
||||
|
||||
// Relations state
|
||||
const relations = ref<RelationWithDetails[]>([])
|
||||
const relationsLoading = ref(false)
|
||||
const relationsExpanded = ref(false)
|
||||
const showGraph = ref(false)
|
||||
|
||||
const hasRelations = computed(() => relations.value.length > 0)
|
||||
const relationCount = computed(() => relations.value.length)
|
||||
|
||||
// Load relations on mount
|
||||
const loadRelations = async () => {
|
||||
relationsLoading.value = true
|
||||
try {
|
||||
relations.value = await fetchObservationRelations(props.observation.id)
|
||||
} catch (err) {
|
||||
console.error('Failed to load relations:', err)
|
||||
} finally {
|
||||
relationsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadRelations()
|
||||
})
|
||||
|
||||
// Toggle relations expansion
|
||||
const toggleRelations = () => {
|
||||
relationsExpanded.value = !relationsExpanded.value
|
||||
}
|
||||
|
||||
// Open graph modal
|
||||
const openGraph = (e: Event) => {
|
||||
e.stopPropagation()
|
||||
showGraph.value = true
|
||||
}
|
||||
|
||||
// Handle navigation from graph
|
||||
const handleNavigateTo = (id: number) => {
|
||||
showGraph.value = false
|
||||
emit('navigateToObservation', id)
|
||||
}
|
||||
|
||||
// Get relation display info (whether we're source or target)
|
||||
const getRelationDisplay = (rel: RelationWithDetails) => {
|
||||
const isSource = rel.relation.source_id === props.observation.id
|
||||
return {
|
||||
type: rel.relation.relation_type,
|
||||
otherTitle: isSource ? rel.target_title : rel.source_title,
|
||||
otherId: isSource ? rel.relation.target_id : rel.relation.source_id,
|
||||
direction: isSource ? 'outgoing' : 'incoming',
|
||||
confidence: rel.relation.confidence
|
||||
}
|
||||
}
|
||||
|
||||
// Split path into project root and relative path for styling
|
||||
// e.g., /Users/foo/project/src/file.go → { root: 'project', path: 'src/file.go' }
|
||||
const splitPath = (path: string, components = 3) => {
|
||||
@@ -140,7 +243,145 @@ const splitPath = (path: string, components = 3) => {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Relations -->
|
||||
<div v-if="hasRelations || relationsLoading" class="mt-3 pt-3 border-t border-slate-700/50">
|
||||
<!-- Header with count and graph button -->
|
||||
<div class="flex items-center justify-between">
|
||||
<button
|
||||
@click="toggleRelations"
|
||||
class="flex items-center gap-2 text-xs text-slate-400 hover:text-slate-300 transition-colors"
|
||||
:disabled="relationsLoading"
|
||||
>
|
||||
<i class="fas fa-diagram-project text-cyan-500/70" />
|
||||
<span v-if="relationsLoading" class="text-slate-500">
|
||||
<i class="fas fa-circle-notch fa-spin mr-1" />
|
||||
Loading relations...
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ relationCount }} related observation{{ relationCount !== 1 ? 's' : '' }}
|
||||
</span>
|
||||
<i
|
||||
v-if="!relationsLoading && hasRelations"
|
||||
class="fas text-[10px] transition-transform"
|
||||
:class="relationsExpanded ? 'fa-chevron-up' : 'fa-chevron-down'"
|
||||
/>
|
||||
</button>
|
||||
|
||||
<!-- View Graph button -->
|
||||
<button
|
||||
v-if="hasRelations"
|
||||
@click="openGraph"
|
||||
class="flex items-center gap-1.5 px-2 py-1 text-xs text-cyan-400 hover:text-cyan-300 bg-cyan-500/10 hover:bg-cyan-500/20 rounded transition-colors"
|
||||
title="View knowledge graph"
|
||||
>
|
||||
<i class="fas fa-project-diagram" />
|
||||
<span>View Graph</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expanded relations list -->
|
||||
<div
|
||||
v-if="relationsExpanded && hasRelations"
|
||||
class="mt-2 space-y-1.5"
|
||||
>
|
||||
<div
|
||||
v-for="rel in relations"
|
||||
:key="rel.relation.id"
|
||||
class="flex items-center gap-2 text-xs p-1.5 rounded bg-slate-800/30 hover:bg-slate-800/50 transition-colors group"
|
||||
>
|
||||
<!-- Relation type icon -->
|
||||
<i
|
||||
class="fas w-4 text-center"
|
||||
:class="[
|
||||
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.icon || 'fa-link',
|
||||
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.colorClass || 'text-slate-400'
|
||||
]"
|
||||
:title="RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.label"
|
||||
/>
|
||||
|
||||
<!-- Direction arrow -->
|
||||
<i
|
||||
class="fas text-[10px] text-slate-600"
|
||||
:class="getRelationDisplay(rel).direction === 'outgoing' ? 'fa-arrow-right' : 'fa-arrow-left'"
|
||||
/>
|
||||
|
||||
<!-- Related observation title -->
|
||||
<span
|
||||
class="flex-1 truncate text-slate-300 cursor-pointer hover:text-amber-300 transition-colors"
|
||||
:title="getRelationDisplay(rel).otherTitle"
|
||||
@click="emit('navigateToObservation', getRelationDisplay(rel).otherId)"
|
||||
>
|
||||
{{ getRelationDisplay(rel).otherTitle || 'Untitled' }}
|
||||
</span>
|
||||
|
||||
<!-- Confidence -->
|
||||
<span
|
||||
class="text-[10px] text-slate-500 font-mono"
|
||||
:title="`${Math.round(getRelationDisplay(rel).confidence * 100)}% confidence`"
|
||||
>
|
||||
{{ Math.round(getRelationDisplay(rel).confidence * 100) }}%
|
||||
</span>
|
||||
|
||||
<!-- Detection source icon -->
|
||||
<i
|
||||
class="fas text-[10px] text-slate-600"
|
||||
:class="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.icon || 'fa-question'"
|
||||
:title="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.label"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Feedback buttons (right side) -->
|
||||
<div v-if="showFeedback" class="flex flex-col items-center gap-1 ml-2 flex-shrink-0">
|
||||
<button
|
||||
@click="submitFeedback(1)"
|
||||
:disabled="isSubmitting"
|
||||
:class="[
|
||||
'p-1.5 rounded-lg transition-all duration-200',
|
||||
currentFeedback === 1
|
||||
? 'bg-green-500/30 text-green-300 shadow-green-500/20 shadow-sm'
|
||||
: 'text-slate-500 hover:text-green-400 hover:bg-green-500/10'
|
||||
]"
|
||||
title="Helpful"
|
||||
>
|
||||
<i class="fas fa-thumbs-up text-sm" />
|
||||
</button>
|
||||
|
||||
<span
|
||||
class="text-[10px] font-mono px-1.5 py-0.5 rounded bg-slate-800/50 text-slate-400 flex items-center gap-1 transition-all duration-300"
|
||||
:class="{ 'text-green-400': localScore !== null && localScore > (observation.importance_score || 1), 'text-red-400': localScore !== null && localScore < (observation.importance_score || 1) }"
|
||||
:title="`Importance Score: ${currentScore.toFixed(3)}\nRetrieval Count: ${observation.retrieval_count || 0}`"
|
||||
>
|
||||
<i class="fas fa-scale-balanced text-amber-500/60" />
|
||||
{{ currentScore.toFixed(2) }}
|
||||
</span>
|
||||
|
||||
<button
|
||||
@click="submitFeedback(-1)"
|
||||
:disabled="isSubmitting"
|
||||
:class="[
|
||||
'p-1.5 rounded-lg transition-all duration-200',
|
||||
currentFeedback === -1
|
||||
? 'bg-red-500/30 text-red-300 shadow-red-500/20 shadow-sm'
|
||||
: 'text-slate-500 hover:text-red-400 hover:bg-red-500/10'
|
||||
]"
|
||||
title="Not helpful"
|
||||
>
|
||||
<i class="fas fa-thumbs-down text-sm" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Relation Graph Modal -->
|
||||
<RelationGraph
|
||||
:observation-id="observation.id"
|
||||
:observation-title="observation.title || 'Untitled'"
|
||||
:show="showGraph"
|
||||
@close="showGraph = false"
|
||||
@navigate-to="handleNavigateTo"
|
||||
/>
|
||||
</Card>
|
||||
</template>
|
||||
|
||||
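Note on `getRelationDisplay` in the component above: the direction logic can be read in isolation as a pure function. The following is a minimal sketch, not the project's code. The interfaces are hypothetical shapes inferred only from the fields the component reads, and `describeRelation`, `RelationDisplay`, and `confidencePercent` are illustrative names that do not appear in the commit.

```typescript
// Hypothetical shapes, inferred from the fields getRelationDisplay reads.
interface Relation {
  source_id: number
  target_id: number
  relation_type: string
  confidence: number
}

interface RelationWithDetails {
  relation: Relation
  source_title: string
  target_title: string
}

interface RelationDisplay {
  type: string
  otherId: number
  otherTitle: string
  direction: 'outgoing' | 'incoming'
  confidencePercent: number
}

// Describe a relation from the point of view of the observation `selfId`:
// if we are the source, the "other" side is the target, and vice versa.
function describeRelation(rel: RelationWithDetails, selfId: number): RelationDisplay {
  const isSource = rel.relation.source_id === selfId
  return {
    type: rel.relation.relation_type,
    otherId: isSource ? rel.relation.target_id : rel.relation.source_id,
    // Same fallback the template applies when a title is empty
    otherTitle: (isSource ? rel.target_title : rel.source_title) || 'Untitled',
    direction: isSource ? 'outgoing' : 'incoming',
    // The template rounds confidence to a whole percentage
    confidencePercent: Math.round(rel.relation.confidence * 100),
  }
}
```

Computing the display object once per relation, rather than calling the helper several times per row in the template, would avoid repeated work, though for short relation lists the difference is likely negligible.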
@@ -0,0 +1,409 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted, watch, computed } from 'vue'
|
||||
import { Network, type Data, type Options } from 'vis-network'
|
||||
import type { RelationGraph, RelationWithDetails } from '@/types'
|
||||
import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
|
||||
import { fetchObservationGraph } from '@/utils/api'
|
||||
|
||||
const props = defineProps<{
|
||||
observationId: number
|
||||
observationTitle: string
|
||||
show: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
navigateTo: [id: number]
|
||||
}>()
|
||||
|
||||
const graphContainer = ref<HTMLElement | null>(null)
|
||||
const loading = ref(true)
|
||||
const error = ref<string | null>(null)
|
||||
const graphData = ref<RelationGraph | null>(null)
|
||||
const selectedRelation = ref<RelationWithDetails | null>(null)
|
||||
const depth = ref(2)
|
||||
|
||||
let network: Network | null = null
|
||||
|
||||
// Node colors based on observation type
|
||||
const getNodeColor = (type: string) => {
|
||||
const colors: Record<string, { background: string; border: string; highlight: { background: string; border: string } }> = {
|
||||
bugfix: { background: '#ef4444', border: '#dc2626', highlight: { background: '#f87171', border: '#ef4444' } },
|
||||
feature: { background: '#a855f7', border: '#9333ea', highlight: { background: '#c084fc', border: '#a855f7' } },
|
||||
refactor: { background: '#3b82f6', border: '#2563eb', highlight: { background: '#60a5fa', border: '#3b82f6' } },
|
||||
discovery: { background: '#06b6d4', border: '#0891b2', highlight: { background: '#22d3ee', border: '#06b6d4' } },
|
||||
decision: { background: '#eab308', border: '#ca8a04', highlight: { background: '#facc15', border: '#eab308' } },
|
||||
change: { background: '#64748b', border: '#475569', highlight: { background: '#94a3b8', border: '#64748b' } },
|
||||
}
|
||||
return colors[type] || colors.change
|
||||
}
|
||||
|
||||
// Edge colors based on relation type
|
||||
const getEdgeColor = (type: string) => {
|
||||
const colors: Record<string, string> = {
|
||||
causes: '#f97316',
|
||||
fixes: '#22c55e',
|
||||
supersedes: '#a855f7',
|
||||
depends_on: '#3b82f6',
|
||||
relates_to: '#64748b',
|
||||
evolves_from: '#06b6d4',
|
||||
}
|
||||
return colors[type] || '#64748b'
|
||||
}
|
||||
|
||||
const loadGraph = async () => {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
graphData.value = await fetchObservationGraph(props.observationId, depth.value)
|
||||
renderGraph()
|
||||
} catch (err) {
|
||||
error.value = err instanceof Error ? err.message : 'Failed to load graph'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const renderGraph = () => {
|
||||
if (!graphContainer.value || !graphData.value) return
|
||||
|
||||
// Build nodes and edges from relations
|
||||
const nodeMap = new Map<number, { id: number; label: string; type: string }>()
|
||||
const edgesList: { from: number; to: number; label: string; color: string; arrows: string; relation: RelationWithDetails }[] = []
|
||||
|
||||
// Add center node
|
||||
nodeMap.set(props.observationId, {
|
||||
id: props.observationId,
|
||||
label: truncateLabel(props.observationTitle),
|
||||
type: 'center'
|
||||
})
|
||||
|
||||
// Process relations
|
||||
for (const rel of graphData.value.relations) {
|
||||
// Add source node
|
||||
if (!nodeMap.has(rel.relation.source_id)) {
|
||||
nodeMap.set(rel.relation.source_id, {
|
||||
id: rel.relation.source_id,
|
||||
label: truncateLabel(rel.source_title),
|
||||
type: rel.source_type
|
||||
})
|
||||
}
|
||||
|
||||
// Add target node
|
||||
if (!nodeMap.has(rel.relation.target_id)) {
|
||||
nodeMap.set(rel.relation.target_id, {
|
||||
id: rel.relation.target_id,
|
||||
label: truncateLabel(rel.target_title),
|
||||
type: rel.target_type
|
||||
})
|
||||
}
|
||||
|
||||
// Add edge
|
||||
edgesList.push({
|
||||
from: rel.relation.source_id,
|
||||
to: rel.relation.target_id,
|
||||
label: rel.relation.relation_type.replace(/_/g, ' '),
|
||||
color: getEdgeColor(rel.relation.relation_type),
|
||||
arrows: 'to',
|
||||
relation: rel
|
||||
})
|
||||
}
|
||||
|
||||
// Create vis-network data using plain arrays (simpler type compatibility)
|
||||
const nodes = Array.from(nodeMap.values()).map(node => ({
|
||||
id: node.id,
|
||||
label: node.label,
|
||||
color: node.id === props.observationId
|
||||
? { background: '#f59e0b', border: '#d97706', highlight: { background: '#fbbf24', border: '#f59e0b' } }
|
||||
: getNodeColor(node.type),
|
||||
font: { color: '#fff', size: 12 },
|
||||
shape: 'box' as const,
|
||||
borderWidth: node.id === props.observationId ? 3 : 2,
|
||||
margin: { top: 10, right: 10, bottom: 10, left: 10 },
|
||||
shadow: true
|
||||
}))
|
||||
|
||||
const edges = edgesList.map((edge, index) => ({
|
||||
id: index,
|
||||
from: edge.from,
|
||||
to: edge.to,
|
||||
label: edge.label,
|
||||
color: { color: edge.color, highlight: edge.color },
|
||||
font: { color: '#94a3b8', size: 10, strokeWidth: 0 },
|
||||
arrows: edge.arrows,
|
||||
width: 2,
|
||||
smooth: { enabled: true, type: 'curvedCW' as const, roundness: 0.2 }
|
||||
}))
|
||||
|
||||
// Cleanup existing network
|
||||
if (network) {
|
||||
network.destroy()
|
||||
}
|
||||
|
||||
// Create network data
|
||||
const data: Data = { nodes, edges }
|
||||
|
||||
const options: Options = {
|
||||
physics: {
|
||||
enabled: true,
|
||||
solver: 'forceAtlas2Based',
|
||||
forceAtlas2Based: {
|
||||
gravitationalConstant: -50,
|
||||
centralGravity: 0.01,
|
||||
springLength: 150,
|
||||
springConstant: 0.08
|
||||
},
|
||||
stabilization: { iterations: 100 }
|
||||
},
|
||||
interaction: {
|
||||
hover: true,
|
||||
tooltipDelay: 200,
|
||||
zoomView: true,
|
||||
dragView: true
|
||||
},
|
||||
layout: {
|
||||
improvedLayout: true
|
||||
}
|
||||
}
|
||||
|
||||
// Create network
|
||||
network = new Network(graphContainer.value, data, options)
|
||||
|
||||
// Handle edge click to show details
|
||||
network.on('selectEdge', (params: { edges: (string | number)[] }) => {
|
||||
if (params.edges.length > 0) {
|
||||
const edgeId = params.edges[0] as number
|
||||
selectedRelation.value = edgesList[edgeId]?.relation || null
|
||||
}
|
||||
})
|
||||
|
||||
// Handle node double-click to navigate
|
||||
network.on('doubleClick', (params: { nodes: (string | number)[] }) => {
|
||||
if (params.nodes.length > 0) {
|
||||
const nodeId = params.nodes[0] as number
|
||||
if (nodeId !== props.observationId) {
|
||||
emit('navigateTo', nodeId)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Clear selection when clicking background
|
||||
network.on('click', (params: { nodes: (string | number)[]; edges: (string | number)[] }) => {
|
||||
if (params.nodes.length === 0 && params.edges.length === 0) {
|
||||
selectedRelation.value = null
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
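const truncateLabel = (label: string | null | undefined, maxLen = 30) => {
  // Titles can be empty (ObservationCard falls back to 'Untitled'), so guard before measuring
  const text = label || 'Untitled'
  if (text.length <= maxLen) return text
  return text.substring(0, maxLen - 3) + '...'
}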
|
||||
|
||||
const relationCount = computed(() => graphData.value?.relations.length || 0)
|
||||
|
||||
const closeModal = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
// Watch for show prop changes
|
||||
watch(() => props.show, (newVal) => {
|
||||
if (newVal) {
|
||||
loadGraph()
|
||||
} else {
|
||||
selectedRelation.value = null
|
||||
if (network) {
|
||||
network.destroy()
|
||||
network = null
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Watch for depth changes
|
||||
watch(depth, () => {
|
||||
if (props.show) {
|
||||
loadGraph()
|
||||
}
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
if (props.show) {
|
||||
loadGraph()
|
||||
}
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (network) {
|
||||
network.destroy()
|
||||
network = null
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<!-- Modal backdrop -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="show"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center p-4"
|
||||
@click.self="closeModal"
|
||||
>
|
||||
<!-- Backdrop -->
|
||||
<div class="absolute inset-0 bg-black/70 backdrop-blur-sm" @click="closeModal" />
|
||||
|
||||
<!-- Modal content -->
|
||||
<div class="relative bg-slate-900 border border-slate-700 rounded-xl shadow-2xl w-full max-w-5xl max-h-[90vh] flex flex-col overflow-hidden">
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between p-4 border-b border-slate-700">
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="p-2 rounded-lg bg-amber-500/20">
|
||||
<i class="fas fa-diagram-project text-amber-400" />
|
||||
</div>
|
||||
<div>
|
||||
<h2 class="text-lg font-semibold text-amber-100">Knowledge Graph</h2>
|
||||
<p class="text-sm text-slate-400">
|
||||
{{ relationCount }} relation{{ relationCount !== 1 ? 's' : '' }} for "{{ truncateLabel(observationTitle, 50) }}"
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Controls -->
|
||||
<div class="flex items-center gap-4">
|
||||
<!-- Depth selector -->
|
||||
<div class="flex items-center gap-2">
|
||||
<label class="text-xs text-slate-400">Depth:</label>
|
||||
<select
|
||||
v-model="depth"
|
||||
class="bg-slate-800 border border-slate-600 rounded px-2 py-1 text-sm text-slate-200 focus:outline-none focus:border-amber-500"
|
||||
>
|
||||
<option :value="1">1</option>
|
||||
<option :value="2">2</option>
|
||||
<option :value="3">3</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Close button -->
|
||||
<button
|
||||
@click="closeModal"
|
||||
class="p-2 rounded-lg text-slate-400 hover:text-slate-200 hover:bg-slate-800 transition-colors"
|
||||
>
|
||||
<i class="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Graph container -->
|
||||
<div class="relative" style="height: 60vh; min-height: 500px;">
|
||||
<!-- Loading state -->
|
||||
<div v-if="loading" class="absolute inset-0 flex items-center justify-center bg-slate-900/50">
|
||||
<div class="flex items-center gap-3 text-amber-400">
|
||||
<i class="fas fa-circle-notch fa-spin text-xl" />
|
||||
<span>Loading graph...</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Error state -->
|
||||
<div v-else-if="error" class="absolute inset-0 flex items-center justify-center">
|
||||
<div class="text-center">
|
||||
<i class="fas fa-exclamation-triangle text-3xl text-red-400 mb-2" />
|
||||
<p class="text-red-300">{{ error }}</p>
|
||||
<button
|
||||
@click="loadGraph"
|
||||
class="mt-3 px-4 py-2 bg-slate-800 hover:bg-slate-700 rounded-lg text-sm text-slate-200 transition-colors"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Empty state -->
|
||||
<div v-else-if="graphData && graphData.relations.length === 0" class="absolute inset-0 flex items-center justify-center">
|
||||
<div class="text-center">
|
||||
<i class="fas fa-diagram-project text-4xl text-slate-600 mb-3" />
|
||||
<p class="text-slate-400">No relations found for this observation</p>
|
||||
<p class="text-sm text-slate-500 mt-1">Relations are detected automatically when observations share files, concepts, or patterns</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Graph -->
|
||||
<div ref="graphContainer" class="absolute inset-0" />
|
||||
</div>
|
||||
|
||||
<!-- Relation details panel -->
|
||||
<div v-if="selectedRelation" class="border-t border-slate-700 p-4 bg-slate-800/50">
|
||||
<div class="flex items-start gap-4">
|
||||
<!-- Relation type icon -->
|
||||
<div
|
||||
class="p-3 rounded-lg"
|
||||
:class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.bgClass"
|
||||
>
|
||||
<i
|
||||
class="fas"
|
||||
:class="[
|
||||
RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.icon,
|
||||
RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass
|
||||
]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Details -->
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="flex items-center gap-2 mb-1">
|
||||
<span class="font-medium" :class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass">
|
||||
{{ RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.label }}
|
||||
</span>
|
||||
<span class="text-xs text-slate-500">
|
||||
({{ Math.round(selectedRelation.relation.confidence * 100) }}% confidence)
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-2 text-sm text-slate-300 mb-2">
|
||||
<span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.source_title, 40) }}</span>
|
||||
<i class="fas fa-arrow-right text-slate-500 text-xs" />
|
||||
<span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.target_title, 40) }}</span>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-4 text-xs text-slate-500">
|
||||
<span class="flex items-center gap-1">
|
||||
<i :class="['fas', DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.icon]" />
|
||||
{{ DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.label }}
|
||||
</span>
|
||||
<span v-if="selectedRelation.relation.reason">
|
||||
{{ selectedRelation.relation.reason }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Legend -->
|
||||
<div class="border-t border-slate-700 p-3 bg-slate-800/30">
|
||||
<div class="flex flex-wrap items-center justify-center gap-4 text-xs text-slate-400">
|
||||
<span class="font-medium text-slate-300">Legend:</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-amber-500" /> Center
|
||||
</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-red-500" /> Bugfix
|
||||
</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-purple-500" /> Feature
|
||||
</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-blue-500" /> Refactor
|
||||
</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-cyan-500" /> Discovery
|
||||
</span>
|
||||
<span class="flex items-center gap-1">
|
||||
<span class="w-3 h-3 rounded bg-yellow-500" /> Decision
|
||||
</span>
|
||||
<span class="text-slate-500">|</span>
|
||||
<span>Double-click node to navigate</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
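The legend in the template above pairs each observation type with a Tailwind background class, with amber reserved for the center node. A minimal sketch of that mapping as a lookup follows. The function name `nodeColorClass` and the slate fallback for types the legend does not list (such as `change`) are assumptions for illustration, not code from this commit.

```typescript
// Colors taken from the legend markup; everything else here is hypothetical.
const LEGEND_CLASSES: Record<string, string> = {
  bugfix: 'bg-red-500',
  feature: 'bg-purple-500',
  refactor: 'bg-blue-500',
  discovery: 'bg-cyan-500',
  decision: 'bg-yellow-500'
}

const CENTER_CLASS = 'bg-amber-500'
// Assumption: the legend has no entry for unlisted types, so pick a neutral color.
const FALLBACK_CLASS = 'bg-slate-500'

function nodeColorClass(observationType: string, isCenter: boolean): string {
  // The center node always wins, whatever its observation type is.
  if (isCenter) return CENTER_CLASS
  // Own-property check so keys like "constructor" never resolve to a prototype member.
  if (Object.prototype.hasOwnProperty.call(LEGEND_CLASSES, observationType)) {
    return LEGEND_CLASSES[observationType]
  }
  return FALLBACK_CLASS
}
```

Keeping the colors in one table means the legend and the graph nodes could share a single source of truth instead of repeating class names in the markup.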
@@ -31,6 +31,7 @@ defineProps<{
|
||||
v-if="item.itemType === 'observation'"
|
||||
:observation="item"
|
||||
:highlight="index === 0"
|
||||
:show-feedback="true"
|
||||
/>
|
||||
<PromptCard
|
||||
v-else-if="item.itemType === 'prompt'"
|
||||
|
||||
@@ -2,3 +2,4 @@ export * from './observation'
|
||||
export * from './prompt'
|
||||
export * from './summary'
|
||||
export * from './api'
|
||||
export * from './relation'
|
||||
|
||||
@@ -30,6 +30,12 @@ export interface Observation {
|
||||
created_at: string
|
||||
created_at_epoch: number
|
||||
is_stale?: boolean
|
||||
// Importance scoring fields
|
||||
importance_score: number
|
||||
user_feedback: number // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
|
||||
retrieval_count: number
|
||||
last_retrieved_at_epoch?: number
|
||||
score_updated_at_epoch?: number
|
||||
}
|
||||
|
||||
export const OBSERVATION_TYPES: ObservationType[] = ['bugfix', 'feature', 'refactor', 'discovery', 'decision', 'change']
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
export type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'
|
||||
export type DetectionSource = 'file_overlap' | 'embedding_similarity' | 'temporal_proximity' | 'narrative_mention' | 'concept_overlap' | 'type_progression'
|
||||
|
||||
export interface ObservationRelation {
|
||||
id: number
|
||||
source_id: number
|
||||
target_id: number
|
||||
relation_type: RelationType
|
||||
confidence: number
|
||||
detection_source: DetectionSource
|
||||
reason: string
|
||||
created_at: string
|
||||
created_at_epoch: number
|
||||
}
|
||||
|
||||
export interface RelationWithDetails {
|
||||
relation: ObservationRelation
|
||||
source_title: string
|
||||
target_title: string
|
||||
source_type: string
|
||||
target_type: string
|
||||
}
|
||||
|
||||
export interface RelationGraph {
|
||||
center_id: number
|
||||
relations: RelationWithDetails[]
|
||||
}
|
||||
|
||||
export interface RelationStats {
|
||||
total_count: number
|
||||
high_confidence: number
|
||||
by_type: Record<RelationType, number>
|
||||
min_confidence_used: number
|
||||
}
|
||||
|
||||
// Configuration for relation type display
|
||||
export const RELATION_TYPE_CONFIG: Record<RelationType, { icon: string; label: string; colorClass: string; bgClass: string; description: string }> = {
|
||||
causes: {
|
||||
icon: 'fa-arrow-right',
|
||||
label: 'Causes',
|
||||
colorClass: 'text-orange-300',
|
||||
bgClass: 'bg-orange-500/20',
|
||||
description: 'This observation caused the related issue'
|
||||
},
|
||||
fixes: {
|
||||
icon: 'fa-wrench',
|
||||
label: 'Fixes',
|
||||
colorClass: 'text-green-300',
|
||||
bgClass: 'bg-green-500/20',
|
||||
description: 'This observation fixes the related issue'
|
||||
},
|
||||
supersedes: {
|
||||
icon: 'fa-layer-group',
|
||||
label: 'Supersedes',
|
||||
colorClass: 'text-purple-300',
|
||||
bgClass: 'bg-purple-500/20',
|
||||
description: 'This observation replaces the older one'
|
||||
},
|
||||
depends_on: {
|
||||
icon: 'fa-link',
|
||||
label: 'Depends On',
|
||||
colorClass: 'text-blue-300',
|
||||
bgClass: 'bg-blue-500/20',
|
||||
description: 'This observation depends on the related one'
|
||||
},
|
||||
relates_to: {
|
||||
icon: 'fa-arrows-left-right',
|
||||
label: 'Related',
|
||||
colorClass: 'text-slate-300',
|
||||
bgClass: 'bg-slate-500/20',
|
||||
description: 'These observations are related'
|
||||
},
|
||||
evolves_from: {
|
||||
icon: 'fa-code-branch',
|
||||
label: 'Evolves From',
|
||||
colorClass: 'text-cyan-300',
|
||||
bgClass: 'bg-cyan-500/20',
|
||||
description: 'This observation evolved from the related one'
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration for detection source display
|
||||
export const DETECTION_SOURCE_CONFIG: Record<DetectionSource, { icon: string; label: string }> = {
|
||||
file_overlap: { icon: 'fa-file-code', label: 'Shared files' },
|
||||
embedding_similarity: { icon: 'fa-brain', label: 'Semantic similarity' },
|
||||
temporal_proximity: { icon: 'fa-clock', label: 'Close in time' },
|
||||
narrative_mention: { icon: 'fa-quote-left', label: 'Mentioned in text' },
|
||||
concept_overlap: { icon: 'fa-tags', label: 'Shared concepts' },
|
||||
type_progression: { icon: 'fa-diagram-next', label: 'Natural progression' }
|
||||
}
|
||||
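The relation details panel reads `RELATION_TYPE_CONFIG` with optional chaining, so a relation type missing from the config renders as blank rather than throwing. The sketch below shows one way to turn a relation into a one-line summary with an explicit fallback label. The helper `describeRelation`, the `'Unknown'` label, and the clamping of `confidence` to the range 0..1 are assumptions; only the label strings come from the config above.

```typescript
type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'

// Labels copied from RELATION_TYPE_CONFIG; the other display fields are omitted for brevity.
const LABELS: Record<RelationType, string> = {
  causes: 'Causes',
  fixes: 'Fixes',
  supersedes: 'Supersedes',
  depends_on: 'Depends On',
  relates_to: 'Related',
  evolves_from: 'Evolves From'
}

// Hypothetical input shape: relation_type is a plain string because the
// backend might send a value this UI build does not know about yet.
interface RelationSummaryInput {
  relation_type: string
  confidence: number // assumed to be a fraction in 0..1
  source_title: string
  target_title: string
}

// Type guard that narrows an arbitrary string to a known RelationType.
function isRelationType(value: string): value is RelationType {
  return Object.prototype.hasOwnProperty.call(LABELS, value)
}

function describeRelation(r: RelationSummaryInput): string {
  const label = isRelationType(r.relation_type) ? LABELS[r.relation_type] : 'Unknown'
  // Clamp before rounding so out-of-range values never display as 150% or -20%.
  const clamped = Math.min(1, Math.max(0, r.confidence))
  const percent = Math.round(clamped * 100)
  return `${r.source_title} -> ${r.target_title}: ${label} (${percent}% confidence)`
}
```

For example, `describeRelation({ relation_type: 'fixes', confidence: 0.87, source_title: 'A', target_title: 'B' })` yields `A -> B: Fixes (87% confidence)`.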
+18
-1
@@ -1,4 +1,4 @@
|
||||
import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem } from '@/types'
|
||||
import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem, RelationWithDetails, RelationGraph, RelationStats } from '@/types'
|
||||
|
||||
const API_BASE = '/api'
|
||||
const DEFAULT_TIMEOUT = 10000 // 10 seconds
|
||||
@@ -147,3 +147,20 @@ export function combineTimeline(
|
||||
return [...obsItems, ...promptItems, ...summaryItems]
|
||||
.sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime())
|
||||
}
|
||||
|
||||
// Relation API functions
|
||||
export async function fetchObservationRelations(observationId: number, signal?: AbortSignal): Promise<RelationWithDetails[]> {
|
||||
return fetchWithRetry<RelationWithDetails[]>(`${API_BASE}/observations/${observationId}/relations`, { signal })
|
||||
}
|
||||
|
||||
export async function fetchObservationGraph(observationId: number, depth: number = 2, signal?: AbortSignal): Promise<RelationGraph> {
|
||||
return fetchWithRetry<RelationGraph>(`${API_BASE}/observations/${observationId}/graph?depth=${depth}`, { signal })
|
||||
}
|
||||
|
||||
export async function fetchRelatedObservations(observationId: number, minConfidence: number = 0.4, signal?: AbortSignal): Promise<Observation[]> {
|
||||
return fetchWithRetry<Observation[]>(`${API_BASE}/observations/${observationId}/related?min_confidence=${minConfidence}`, { signal })
|
||||
}
|
||||
|
||||
export async function fetchRelationStats(signal?: AbortSignal): Promise<RelationStats> {
|
||||
return fetchWithRetry<RelationStats>(`${API_BASE}/relations/stats`, { signal })
|
||||
}
|
||||
|
||||
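The relation API functions above interpolate `depth` and `min_confidence` straight into the URL. That is fine for numbers, but a small builder keeps the encoding in one place if more parameters are added later. The following is a sketch under that assumption: `relationUrl` is a hypothetical helper, not part of this commit, and it only builds the string; it does not call `fetchWithRetry`.

```typescript
const API_BASE = '/api'

type RelationResource = 'relations' | 'graph' | 'related'

// Builds the endpoint paths used by the fetch helpers in the diff above.
// Parameters whose value is undefined are skipped, so callers can pass
// optional values through unchanged.
function relationUrl(
  observationId: number,
  resource: RelationResource,
  params: Record<string, number | undefined> = {}
): string {
  const parts: string[] = []
  for (const key of Object.keys(params)) {
    const value = params[key]
    if (value !== undefined) {
      parts.push(`${encodeURIComponent(key)}=${encodeURIComponent(String(value))}`)
    }
  }
  const path = `${API_BASE}/observations/${observationId}/${resource}`
  return parts.length > 0 ? `${path}?${parts.join('&')}` : path
}
```

With the defaults from the diff, `relationUrl(42, 'graph', { depth: 2 })` gives `/api/observations/42/graph?depth=2` and `relationUrl(42, 'related', { min_confidence: 0.4 })` gives `/api/observations/42/related?min_confidence=0.4`.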
@@ -1 +1 @@
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
|
||||