Mirror of https://github.com/lukaszraczylo/claude-mnemonic.git (synced 2026-06-05 23:03:55 +00:00)
Release dec 2025 (#15)
* Resolves issue #13
  - Switched model to bge-small-en-v1.5
  - Added lazy re-embedding
  - Added model version tracking per vector
  - Added conversion of vectors to the new model
* Add lfs support to the workflow.
* Implements importance scoring with decay + voting #6
* Resolves issue #5 by marking observations as superseded and scheduled for deletion
* Implement pattern detection #7
* Improve injections and observations accuracy
  - Session start: Recent observations for project context (recency-based)
  - User prompt: Semantically relevant observations (similarity-based with threshold)
* Added two stage retrieval with bi and cross encoder #8
* Implement query expansion and reformulation #9
* Knowledge graph and relationships ( resolves #4 )
  - File Overlap Detection: Detects relationships when observations modify/read the same files
  - Concept Overlap Detection: Detects relationships based on shared semantic concepts
  - Type Progression Detection: Infers relationships from natural observation type progressions (e.g., discovery → bugfix = "fixes")
  - Temporal Proximity Detection: Detects relationships between observations in the same session within 5 minutes
  - Narrative Mention Detection: Detects explicit relationship language in narratives (e.g., "fixes", "depends on", "supersedes")
* Add visualisation of the relations to the dashboard.
* fixup! Add visualisation of the relations to the dashboard.
* Update documentation with new settings and screenshots.
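
For readers skimming the commit message, the "Temporal Proximity Detection" rule (same session, within 5 minutes) can be sketched roughly as below. This is only an illustration under stated assumptions: the `Observation` type, its fields, the `TemporallyClose` function, and the inclusive 5-minute boundary are all hypothetical and are not taken from this commit. Timestamps are assumed to be Unix milliseconds, matching the `*_epoch` columns used elsewhere in the diff.

```go
package main

import "fmt"

// Observation is a hypothetical, trimmed-down stand-in for the real model.
type Observation struct {
	ID             int64
	SessionID      string
	CreatedAtEpoch int64 // Unix time in milliseconds (assumed)
}

// proximityWindowMs is the 5-minute window from the commit message, in milliseconds.
const proximityWindowMs int64 = 5 * 60 * 1000

// TemporallyClose reports whether two observations share a session and were
// created no more than 5 minutes apart. The inclusive bound is an assumption.
func TemporallyClose(a, b Observation) bool {
	if a.SessionID == "" || a.SessionID != b.SessionID {
		return false
	}
	delta := a.CreatedAtEpoch - b.CreatedAtEpoch
	if delta < 0 {
		delta = -delta
	}
	return delta <= proximityWindowMs
}

func main() {
	a := Observation{ID: 1, SessionID: "s1", CreatedAtEpoch: 0}
	b := Observation{ID: 2, SessionID: "s1", CreatedAtEpoch: 4 * 60 * 1000}
	c := Observation{ID: 3, SessionID: "s2", CreatedAtEpoch: 1000}
	fmt.Println(TemporallyClose(a, b)) // same session, 4 minutes apart
	fmt.Println(TemporallyClose(a, c)) // different sessions
}
```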
@@ -20,4 +20,5 @@ jobs:
|
|||||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||||
with:
|
with:
|
||||||
go-version: ">=1.24"
|
go-version: ">=1.24"
|
||||||
|
lfs: true
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|||||||
@@ -12,7 +12,21 @@ Claude Code forgets everything when your session ends. Claude Mnemonic fixes tha
|
|||||||
|
|
||||||
It captures what Claude learns during your coding sessions - bug fixes, architecture decisions, patterns that work - and brings that knowledge back in future conversations. No more re-explaining your codebase.
|
It captures what Claude learns during your coding sessions - bug fixes, architecture decisions, patterns that work - and brings that knowledge back in future conversations. No more re-explaining your codebase.
|
||||||
|
|
||||||
## What's New in v0.6
|

|
||||||
|
|
||||||
|
## What's New in v0.7
|
||||||
|
|
||||||
|
- **Two-Stage Retrieval** - Cross-encoder reranking for dramatically improved search relevance
|
||||||
|
- **Knowledge Graph** - Automatic relationship detection between observations with visual graph in dashboard
|
||||||
|

|
||||||
|
- **Pattern Detection** - Identifies recurring patterns across sessions and projects
|
||||||
|
- **Importance Scoring** - Time decay and voting system to surface the most valuable memories
|
||||||
|
- **Query Expansion** - Reformulates searches to find semantically related content
|
||||||
|
- **Conflict Detection** - Identifies and resolves contradictory observations
|
||||||
|
- **Observation Lifecycle** - Memories can be superseded when better information arrives
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Previous: v0.6</summary>
|
||||||
|
|
||||||
- **Auto-Updates** - Automatically stays up-to-date with the latest version
|
- **Auto-Updates** - Automatically stays up-to-date with the latest version
|
||||||
- **Slash Command: `/restart`** - Restart the worker directly from Claude Code
|
- **Slash Command: `/restart`** - Restart the worker directly from Claude Code
|
||||||
@@ -20,6 +34,7 @@ It captures what Claude learns during your coding sessions - bug fixes, architec
|
|||||||
- **Async Queue Processing** - Non-blocking observation capture for faster sessions
|
- **Async Queue Processing** - Non-blocking observation capture for faster sessions
|
||||||
- **Smarter Storage** - Filters out system/agent summaries to keep knowledge relevant
|
- **Smarter Storage** - Filters out system/agent summaries to keep knowledge relevant
|
||||||
- **Improved Reliability** - Better handling of connectivity issues and dead connections
|
- **Improved Reliability** - Better handling of connectivity issues and dead connections
|
||||||
|
</details>
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
@@ -110,16 +125,39 @@ Config file: `~/.claude-mnemonic/settings.json`
|
|||||||
{
|
{
|
||||||
"CLAUDE_MNEMONIC_WORKER_PORT": 37777,
|
"CLAUDE_MNEMONIC_WORKER_PORT": 37777,
|
||||||
"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
|
"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
|
||||||
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25
|
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25,
|
||||||
|
"CLAUDE_MNEMONIC_RERANKING_ENABLED": true
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Core Settings
|
||||||
|
|
||||||
| Variable | Default | What it does |
|
| Variable | Default | What it does |
|
||||||
|----------|---------|--------------|
|
|----------|---------|--------------|
|
||||||
| `WORKER_PORT` | `37777` | Dashboard & API port |
|
| `WORKER_PORT` | `37777` | Dashboard & API port |
|
||||||
| `CONTEXT_OBSERVATIONS` | `100` | Max memories per session |
|
| `CONTEXT_OBSERVATIONS` | `100` | Max memories per session |
|
||||||
| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed to title only) |
|
| `CONTEXT_FULL_COUNT` | `25` | Full detail memories (rest are condensed) |
|
||||||
| `CONTEXT_SESSION_COUNT` | `10` | Recent sessions to reference |
|
| `CONTEXT_SESSION_COUNT` | `10` | Recent sessions to reference |
|
||||||
|
| `CONTEXT_RELEVANCE_THRESHOLD` | `0.3` | Minimum similarity score (0.0-1.0) for inclusion |
|
||||||
|
| `CONTEXT_MAX_PROMPT_RESULTS` | `10` | Max results per prompt search |
|
||||||
|
|
||||||
|
### Reranking Settings (Two-Stage Retrieval)
|
||||||
|
|
||||||
|
| Variable | Default | What it does |
|
||||||
|
|----------|---------|--------------|
|
||||||
|
| `RERANKING_ENABLED` | `true` | Enable cross-encoder reranking |
|
||||||
|
| `RERANKING_CANDIDATES` | `100` | Candidates to retrieve before reranking |
|
||||||
|
| `RERANKING_RESULTS` | `10` | Final results after reranking |
|
||||||
|
| `RERANKING_ALPHA` | `0.7` | Score blend: alpha×rerank + (1-alpha)×original |
|
||||||
|
| `RERANKING_PURE_MODE` | `false` | Use pure cross-encoder scores only |
|
||||||
|
|
||||||
|
### Embedding Settings
|
||||||
|
|
||||||
|
| Variable | Default | What it does |
|
||||||
|
|----------|---------|--------------|
|
||||||
|
| `EMBEDDING_MODEL` | `bge-v1.5` | Embedding model for semantic search |
|
||||||
|
|
||||||
|
All variables are prefixed with `CLAUDE_MNEMONIC_` in the config file.
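
The `RERANKING_ALPHA` and `RERANKING_PURE_MODE` rows describe a score blend of the form alpha×rerank + (1-alpha)×original. A minimal sketch of that arithmetic follows. The function name `BlendScore` and its signature are hypothetical, and the real reranker may normalize scores before blending, which this sketch does not attempt.

```go
package main

import "fmt"

// BlendScore combines a cross-encoder (rerank) score with the original
// bi-encoder similarity. Hypothetical helper, not the project's API.
func BlendScore(rerank, original, alpha float64, pureMode bool) float64 {
	if pureMode {
		// Pure mode: ignore the original score entirely.
		return rerank
	}
	return alpha*rerank + (1-alpha)*original
}

func main() {
	// With the documented default alpha of 0.7:
	fmt.Printf("%.2f\n", BlendScore(0.9, 0.5, 0.7, false)) // 0.7*0.9 + 0.3*0.5
	fmt.Printf("%.2f\n", BlendScore(0.9, 0.5, 0.7, true))  // rerank score only
}
```

With alpha at 0.7 the cross-encoder dominates, but a strong original similarity can still lift a candidate. Setting alpha to 1.0 gives the same result as pure mode in this sketch.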
|
||||||
|
|
||||||
## Project vs Global scope
|
## Project vs Global scope
|
||||||
|
|
||||||
@@ -202,10 +240,11 @@ curl -sSL https://raw.githubusercontent.com/lukaszraczylo/claude-mnemonic/main/s
|
|||||||
|
|
||||||
- **SQLite + FTS5** - Full-text search for exact matches
|
- **SQLite + FTS5** - Full-text search for exact matches
|
||||||
- **sqlite-vec** - Vector database embedded in SQLite
|
- **sqlite-vec** - Vector database embedded in SQLite
|
||||||
- **all-MiniLM-L6-v2** - Local embedding model (384 dimensions) via ONNX Runtime
|
- **Two-Stage Retrieval** - Bi-encoder (embedding) + cross-encoder (reranking) for high accuracy
|
||||||
|
- **Local Models** - all-MiniLM-L6-v2 for embeddings, BGE reranker for relevance scoring
|
||||||
- **Go** - Single binary, no external dependencies
|
- **Go** - Single binary, no external dependencies
|
||||||
|
|
||||||
Everything runs locally. No Python. No external vector database. No API calls for embeddings.
|
Everything runs locally. No Python. No external vector database. No API calls.
|
||||||
|
|
||||||
## Platform support
|
## Platform support
|
||||||
|
|
||||||
|
|||||||
@@ -61,12 +61,8 @@ func main() {
|
|||||||
|
|
||||||
searchResult, _ := hooks.GET(port, searchURL)
|
searchResult, _ := hooks.GET(port, searchURL)
|
||||||
if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
|
if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
|
||||||
// Limit to top 5 most relevant observations
|
// Results are already filtered by relevance threshold and capped by max_results
|
||||||
maxObs := 5
|
// from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults)
|
||||||
if len(observations) < maxObs {
|
|
||||||
maxObs = len(observations)
|
|
||||||
}
|
|
||||||
observations = observations[:maxObs]
|
|
||||||
observationCount = len(observations)
|
observationCount = len(observations)
|
||||||
|
|
||||||
// Build context from search results
|
// Build context from search results
|
||||||
|
|||||||
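
The hunk above drops the client-side cap of 5 and relies on the server to apply `ContextRelevanceThreshold` and `ContextMaxPromptResults`. The server-side code is not part of this excerpt, so the following is only a sketch of what such filtering could look like. The `Scored` type and `FilterByRelevance` function are hypothetical. The one behaviour taken from the diff is that a max of 0 means "threshold only, no cap".

```go
package main

import (
	"fmt"
	"sort"
)

// Scored is a hypothetical search hit with a similarity score in [0, 1].
type Scored struct {
	ID    int64
	Score float64
}

// FilterByRelevance keeps hits at or above threshold, orders them by
// descending score, and truncates to maxResults when maxResults > 0.
func FilterByRelevance(results []Scored, threshold float64, maxResults int) []Scored {
	kept := make([]Scored, 0, len(results))
	for _, r := range results {
		if r.Score >= threshold {
			kept = append(kept, r)
		}
	}
	sort.SliceStable(kept, func(i, j int) bool { return kept[i].Score > kept[j].Score })
	if maxResults > 0 && len(kept) > maxResults {
		kept = kept[:maxResults]
	}
	return kept
}

func main() {
	hits := []Scored{{1, 0.9}, {2, 0.2}, {3, 0.5}, {4, 0.31}}
	fmt.Println(FilterByRelevance(hits, 0.3, 2)) // threshold plus cap
	fmt.Println(FilterByRelevance(hits, 0.3, 0)) // threshold only
}
```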
Binary file not shown.
|
After Width: | Height: | Size: 252 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 123 KiB |
+37
-7
@@ -29,6 +29,21 @@
|
|||||||
subtitle="Stop re-explaining your codebase. Claude Mnemonic captures bug fixes, architecture decisions, and coding patterns - then brings them back exactly when you need them."
|
subtitle="Stop re-explaining your codebase. Claude Mnemonic captures bug fixes, architecture decisions, and coding patterns - then brings them back exactly when you need them."
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<!-- Dashboard Preview -->
|
||||||
|
<section class="py-12 lg:py-16 px-4 sm:px-6 relative">
|
||||||
|
<div class="max-w-6xl mx-auto">
|
||||||
|
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-amber-500/5">
|
||||||
|
<div class="absolute inset-0 bg-gradient-to-t from-slate-950 via-transparent to-transparent pointer-events-none z-10"></div>
|
||||||
|
<img
|
||||||
|
src="/claude-mnemonic.jpg"
|
||||||
|
alt="Claude Mnemonic Dashboard"
|
||||||
|
class="w-full h-auto"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<p class="text-center text-slate-500 text-sm mt-4">The dashboard at localhost:37777 - browse, search, and manage your memories</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
<!-- Problem Section -->
|
<!-- Problem Section -->
|
||||||
<section class="py-20 lg:py-28 px-4 sm:px-6 relative">
|
<section class="py-20 lg:py-28 px-4 sm:px-6 relative">
|
||||||
<div class="max-w-6xl mx-auto grid lg:grid-cols-2 gap-8 lg:gap-16 items-center">
|
<div class="max-w-6xl mx-auto grid lg:grid-cols-2 gap-8 lg:gap-16 items-center">
|
||||||
@@ -85,6 +100,21 @@
|
|||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
|
|
||||||
|
<!-- Knowledge Graph Preview -->
|
||||||
|
<section class="py-16 lg:py-20 px-4 sm:px-6 relative">
|
||||||
|
<div class="max-w-5xl mx-auto">
|
||||||
|
<SectionHeader title="See how knowledge connects" subtitle="The knowledge graph reveals relationships between your memories automatically" />
|
||||||
|
<div class="relative rounded-xl overflow-hidden border border-slate-700/50 shadow-2xl shadow-purple-500/5">
|
||||||
|
<img
|
||||||
|
src="/observation-relation-graph.jpg"
|
||||||
|
alt="Knowledge Graph Visualization"
|
||||||
|
class="w-full h-auto"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<p class="text-center text-slate-500 text-sm mt-4">Click any observation to explore its relationships - see what causes what, what fixes what, and how concepts evolve</p>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
<!-- Before/After Section -->
|
<!-- Before/After Section -->
|
||||||
<section class="py-20 lg:py-28 px-4 sm:px-6">
|
<section class="py-20 lg:py-28 px-4 sm:px-6">
|
||||||
<div class="max-w-6xl mx-auto">
|
<div class="max-w-6xl mx-auto">
|
||||||
@@ -288,8 +318,8 @@
|
|||||||
<p class="text-slate-400 text-xs sm:text-sm">Embedded vector database. No external services required.</p>
|
<p class="text-slate-400 text-xs sm:text-sm">Embedded vector database. No external services required.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">MiniLM</div>
|
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">BGE</div>
|
||||||
<p class="text-slate-400 text-xs sm:text-sm">Local embeddings via ONNX. "Fix auth" finds "JWT issue" automatically.</p>
|
<p class="text-slate-400 text-xs sm:text-sm">Two-stage retrieval: bi-encoder embeddings + cross-encoder reranking for high accuracy.</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -386,10 +416,10 @@ const activeTab = ref('macos')
|
|||||||
|
|
||||||
const features = [
|
const features = [
|
||||||
{ icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
|
{ icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
|
||||||
|
{ icon: 'fas fa-search', title: 'Two-stage retrieval', description: 'Cross-encoder reranking delivers highly relevant results. Finds what you need even with vague queries like "that auth thing".' },
|
||||||
|
{ icon: 'fas fa-project-diagram', title: 'Knowledge graph', description: 'Automatically discovers relationships between memories. See how concepts connect in the visual graph dashboard.' },
|
||||||
{ icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
|
{ icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
|
||||||
{ icon: 'fas fa-globe', title: 'Shared best practices', description: 'Security patterns, performance tips, and universal learnings automatically available across all your projects.' },
|
{ icon: 'fas fa-chart-line', title: 'Smart scoring', description: 'Importance decay, pattern detection, and conflict resolution ensure the most valuable memories surface first.' },
|
||||||
{ icon: 'fas fa-search', title: 'Finds what matters', description: 'Semantic search finds relevant memories even when you don\'t remember the exact words. "That auth thing" just works.' },
|
|
||||||
{ icon: 'fas fa-chart-line', title: 'Live statusline', description: 'Real-time metrics right in Claude Code: memories served, searches performed, and project memory count at a glance.' },
|
|
||||||
{ icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
|
{ icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -415,8 +445,8 @@ const installCommands = {
|
|||||||
const configOptions = [
|
const configOptions = [
|
||||||
{ name: 'CLAUDE_MNEMONIC_WORKER_PORT', description: 'HTTP port for the worker service (default: 37777)', icon: 'fas fa-network-wired' },
|
{ name: 'CLAUDE_MNEMONIC_WORKER_PORT', description: 'HTTP port for the worker service (default: 37777)', icon: 'fas fa-network-wired' },
|
||||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
|
{ name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
|
||||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT', description: 'Observations with full narrative detail, rest are condensed (default: 25)', icon: 'fas fa-expand' },
|
{ name: 'CLAUDE_MNEMONIC_RERANKING_ENABLED', description: 'Enable cross-encoder reranking for improved search relevance (default: true)', icon: 'fas fa-sort-amount-down' },
|
||||||
{ name: 'CLAUDE_MNEMONIC_MODEL', description: 'Model for processing observations (default: haiku)', icon: 'fas fa-microchip' },
|
{ name: 'CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD', description: 'Minimum similarity score for inclusion, 0.0-1.0 (default: 0.3)', icon: 'fas fa-filter' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const requiredDeps = [
|
const requiredDeps = [
|
||||||
|
|||||||
+74
-22
@@ -48,16 +48,29 @@ type Config struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
ClaudeCodePath string `json:"claude_code_path"`
|
ClaudeCodePath string `json:"claude_code_path"`
|
||||||
|
|
||||||
|
// Embedding settings
|
||||||
|
EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
|
||||||
|
|
||||||
|
// Reranking settings (cross-encoder)
|
||||||
|
RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking
|
||||||
|
RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100)
|
||||||
|
RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10)
|
||||||
|
RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7)
|
||||||
|
RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank)
|
||||||
|
RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false)
|
||||||
|
|
||||||
// Context injection settings
|
// Context injection settings
|
||||||
ContextObservations int `json:"context_observations"`
|
ContextObservations int `json:"context_observations"`
|
||||||
ContextFullCount int `json:"context_full_count"`
|
ContextFullCount int `json:"context_full_count"`
|
||||||
ContextSessionCount int `json:"context_session_count"`
|
ContextSessionCount int `json:"context_session_count"`
|
||||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||||
ContextFullField string `json:"context_full_field"`
|
ContextFullField string `json:"context_full_field"`
|
||||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||||
ContextObsTypes []string `json:"context_obs_types"`
|
ContextObsTypes []string `json:"context_obs_types"`
|
||||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||||
|
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
|
||||||
|
ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -119,22 +132,33 @@ func EnsureAll() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultEmbeddingModel is the default embedding model to use.
|
||||||
|
const DefaultEmbeddingModel = "bge-v1.5"
|
||||||
|
|
||||||
// Default returns a Config with default values.
|
// Default returns a Config with default values.
|
||||||
func Default() *Config {
|
func Default() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
WorkerPort: DefaultWorkerPort,
|
WorkerPort: DefaultWorkerPort,
|
||||||
DBPath: DBPath(),
|
DBPath: DBPath(),
|
||||||
MaxConns: 4,
|
MaxConns: 4,
|
||||||
Model: DefaultModel,
|
Model: DefaultModel,
|
||||||
ContextObservations: 100,
|
EmbeddingModel: DefaultEmbeddingModel,
|
||||||
ContextFullCount: 25,
|
RerankingEnabled: true, // Enable by default for improved relevance
|
||||||
ContextSessionCount: 10,
|
RerankingCandidates: 100, // Retrieve top 100 candidates
|
||||||
ContextShowReadTokens: true,
|
RerankingResults: 10, // Return top 10 after reranking
|
||||||
ContextShowWorkTokens: true,
|
RerankingAlpha: 0.7, // Favor cross-encoder score
|
||||||
ContextFullField: "narrative",
|
RerankingMinImprovement: 0, // Always apply reranking
|
||||||
ContextShowLastSummary: true,
|
ContextObservations: 100,
|
||||||
ContextObsTypes: DefaultObservationTypes,
|
ContextFullCount: 25,
|
||||||
ContextObsConcepts: DefaultObservationConcepts,
|
ContextSessionCount: 10,
|
||||||
|
ContextShowReadTokens: true,
|
||||||
|
ContextShowWorkTokens: true,
|
||||||
|
ContextFullField: "narrative",
|
||||||
|
ContextShowLastSummary: true,
|
||||||
|
ContextObsTypes: DefaultObservationTypes,
|
||||||
|
ContextObsConcepts: DefaultObservationConcepts,
|
||||||
|
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
|
||||||
|
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +190,28 @@ func Load() (*Config, error) {
|
|||||||
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
|
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
|
||||||
cfg.ClaudeCodePath = v
|
cfg.ClaudeCodePath = v
|
||||||
}
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" {
|
||||||
|
cfg.EmbeddingModel = v
|
||||||
|
}
|
||||||
|
// Reranking settings
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ENABLED"].(bool); ok {
|
||||||
|
cfg.RerankingEnabled = v
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_CANDIDATES"].(float64); ok && v > 0 {
|
||||||
|
cfg.RerankingCandidates = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_RESULTS"].(float64); ok && v > 0 {
|
||||||
|
cfg.RerankingResults = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_ALPHA"].(float64); ok && v >= 0 && v <= 1 {
|
||||||
|
cfg.RerankingAlpha = v
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_MIN_IMPROVEMENT"].(float64); ok && v >= 0 {
|
||||||
|
cfg.RerankingMinImprovement = v
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_RERANKING_PURE_MODE"].(bool); ok {
|
||||||
|
cfg.RerankingPureMode = v
|
||||||
|
}
|
||||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
|
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
|
||||||
cfg.ContextObservations = int(v)
|
cfg.ContextObservations = int(v)
|
||||||
}
|
}
|
||||||
@@ -181,6 +227,12 @@ func Load() (*Config, error) {
|
|||||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS"].(string); ok && v != "" {
|
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS"].(string); ok && v != "" {
|
||||||
cfg.ContextObsConcepts = splitTrim(v)
|
cfg.ContextObsConcepts = splitTrim(v)
|
||||||
}
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD"].(float64); ok && v >= 0 && v <= 1 {
|
||||||
|
cfg.ContextRelevanceThreshold = v
|
||||||
|
}
|
||||||
|
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 {
|
||||||
|
cfg.ContextMaxPromptResults = int(v)
|
||||||
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
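
A note on the `.(float64)` assertions in `Load()`: when `encoding/json` decodes into an `interface{}` value, JSON numbers arrive as `float64`, which is why integer settings are asserted as `float64` and then converted with `int(v)`. The sketch below assumes the settings file is decoded into a `map[string]interface{}`, which this hunk does not show. The `intSetting` helper is hypothetical, and its `v > 0` guard mirrors the check used for `RERANKING_CANDIDATES`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// intSetting returns settings[key] as an int when it is a positive JSON
// number, and fallback otherwise (missing key, wrong type, or non-positive).
func intSetting(settings map[string]interface{}, key string, fallback int) int {
	if v, ok := settings[key].(float64); ok && v > 0 {
		return int(v)
	}
	return fallback
}

func main() {
	raw := []byte(`{
		"CLAUDE_MNEMONIC_RERANKING_CANDIDATES": 50,
		"CLAUDE_MNEMONIC_RERANKING_RESULTS": "25"
	}`)

	var settings map[string]interface{}
	if err := json.Unmarshal(raw, &settings); err != nil {
		panic(err)
	}

	// A JSON number decodes as float64, so the assertion succeeds.
	fmt.Println(intSetting(settings, "CLAUDE_MNEMONIC_RERANKING_CANDIDATES", 100))
	// A quoted value is a string, so the assertion fails and the default is kept.
	fmt.Println(intSetting(settings, "CLAUDE_MNEMONIC_RERANKING_RESULTS", 10))
}
```

One consequence of this pattern is that a value written as a string in `settings.json` (for example `"25"`) is silently ignored and the default stays in effect.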
@@ -0,0 +1,276 @@
|
|||||||
|
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SupersededRetentionDays is the number of days to keep superseded observations before deletion.
|
||||||
|
const SupersededRetentionDays = 3
|
||||||
|
|
||||||
|
// ConflictStore provides conflict-related database operations.
|
||||||
|
type ConflictStore struct {
|
||||||
|
store *Store
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConflictStore creates a new conflict store.
|
||||||
|
func NewConflictStore(store *Store) *ConflictStore {
|
||||||
|
return &ConflictStore{store: store}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreConflict stores a new observation conflict.
|
||||||
|
func (s *ConflictStore) StoreConflict(ctx context.Context, conflict *models.ObservationConflict) (int64, error) {
|
||||||
|
const query = `
|
||||||
|
INSERT INTO observation_conflicts
|
||||||
|
(newer_obs_id, older_obs_id, conflict_type, resolution, reason, detected_at, detected_at_epoch, resolved, resolved_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err := s.store.ExecContext(ctx, query,
|
||||||
|
conflict.NewerObsID, conflict.OlderObsID,
|
||||||
|
string(conflict.ConflictType), string(conflict.Resolution),
|
||||||
|
conflict.Reason, conflict.DetectedAt, conflict.DetectedAtEpoch,
|
||||||
|
conflict.Resolved, conflict.ResolvedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkObservationSuperseded marks an observation as superseded.
|
||||||
|
func (s *ConflictStore) MarkObservationSuperseded(ctx context.Context, obsID int64) error {
|
||||||
|
const query = `UPDATE observations SET is_superseded = 1 WHERE id = ?`
|
||||||
|
_, err := s.store.ExecContext(ctx, query, obsID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkObservationsSuperseded marks multiple observations as superseded.
|
||||||
|
func (s *ConflictStore) MarkObservationsSuperseded(ctx context.Context, obsIDs []int64) error {
|
||||||
|
if len(obsIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `UPDATE observations SET is_superseded = 1 WHERE id IN (?` + repeatPlaceholders(len(obsIDs)-1) + `)` // #nosec G202 -- uses parameterized placeholders
|
||||||
|
args := int64SliceToInterface(obsIDs)
|
||||||
|
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConflictsByObservationID retrieves all conflicts involving an observation.
|
||||||
|
func (s *ConflictStore) GetConflictsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationConflict, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||||
|
detected_at, detected_at_epoch, resolved, resolved_at
|
||||||
|
FROM observation_conflicts
|
||||||
|
WHERE newer_obs_id = ? OR older_obs_id = ?
|
||||||
|
ORDER BY detected_at_epoch DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanConflictRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnresolvedConflicts retrieves all unresolved conflicts.
|
||||||
|
func (s *ConflictStore) GetUnresolvedConflicts(ctx context.Context, limit int) ([]*models.ObservationConflict, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, newer_obs_id, older_obs_id, conflict_type, resolution, reason,
|
||||||
|
detected_at, detected_at_epoch, resolved, resolved_at
|
||||||
|
FROM observation_conflicts
|
||||||
|
WHERE resolved = 0
|
||||||
|
ORDER BY detected_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanConflictRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSupersededObservationIDs returns IDs of all observations that have been superseded.
|
||||||
|
func (s *ConflictStore) GetSupersededObservationIDs(ctx context.Context, project string) ([]int64, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT DISTINCT older_obs_id
|
||||||
|
FROM observation_conflicts oc
|
||||||
|
JOIN observations o ON o.id = oc.older_obs_id
|
||||||
|
WHERE oc.resolution = 'prefer_newer'
|
||||||
|
AND (o.project = ? OR o.scope = 'global')
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, project)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var ids []int64
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return ids, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveConflict marks a conflict as resolved.
|
||||||
|
func (s *ConflictStore) ResolveConflict(ctx context.Context, conflictID int64, resolution models.ConflictResolution) error {
|
||||||
|
now := time.Now().Format(time.RFC3339)
|
||||||
|
const query = `
|
||||||
|
UPDATE observation_conflicts
|
||||||
|
SET resolved = 1, resolved_at = ?, resolution = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
_, err := s.store.ExecContext(ctx, query, now, string(resolution), conflictID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConflictsByObservationID deletes all conflicts involving an observation.
|
||||||
|
// Called when an observation is deleted.
|
||||||
|
func (s *ConflictStore) DeleteConflictsByObservationID(ctx context.Context, obsID int64) error {
|
||||||
|
const query = `DELETE FROM observation_conflicts WHERE newer_obs_id = ? OR older_obs_id = ?`
|
||||||
|
_, err := s.store.ExecContext(ctx, query, obsID, obsID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConflictWithDetails contains a conflict with its observation details.
|
||||||
|
type ConflictWithDetails struct {
|
||||||
|
Conflict *models.ObservationConflict
|
||||||
|
NewerObsTitle string
|
||||||
|
OlderObsTitle string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupSupersededObservations deletes observations that have been superseded for longer than
// SupersededRetentionDays. Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, project string) ([]int64, error) {
	// Calculate the cutoff time: SupersededRetentionDays ago, as Unix epoch milliseconds
	cutoffEpoch := time.Now().AddDate(0, 0, -SupersededRetentionDays).UnixMilli()

	// First, find the IDs that will be deleted.
	// We delete observations that:
	// 1. Are marked as superseded
	// 2. Have a conflict record where they are the older observation
	// 3. Have a conflict that was detected before the cutoff
	const selectQuery = `
		SELECT DISTINCT o.id FROM observations o
		JOIN observation_conflicts oc ON o.id = oc.older_obs_id
		WHERE o.is_superseded = 1
		  AND o.project = ?
		  AND oc.detected_at_epoch < ?
	`

	rows, err := s.store.QueryContext(ctx, selectQuery, project, cutoffEpoch)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var toDelete []int64
	for rows.Next() {
		var id int64
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		toDelete = append(toDelete, id)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(toDelete) == 0 {
		return nil, nil
	}

	// Delete the conflict records first (due to foreign key constraints)
	for _, obsID := range toDelete {
		if err := s.DeleteConflictsByObservationID(ctx, obsID); err != nil {
			return nil, err
		}
	}

	// Delete the observations
	deleteQuery := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(toDelete)-1) + `)` // #nosec G202 -- uses parameterized placeholders
	args := int64SliceToInterface(toDelete)
	_, err = s.store.db.ExecContext(ctx, deleteQuery, args...)
	if err != nil {
		return nil, err
	}

	return toDelete, nil
}
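The DELETE above builds its `IN (...)` list from two helpers, `repeatPlaceholders` and `int64SliceToInterface`, whose bodies are not part of this diff. The sketch below is one plausible shape for them, assuming `repeatPlaceholders(n)` returns `n` copies of `", ?"`. The helper bodies and `buildDeleteQuery` are illustrative assumptions, not the repository's code.

```go
package main

import (
	"fmt"
	"strings"
)

// repeatPlaceholders returns n copies of ", ?".
// Hypothetical reimplementation: the real helper is not shown in the diff.
func repeatPlaceholders(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat(", ?", n)
}

// int64SliceToInterface converts IDs to []interface{} so they can be
// passed as variadic arguments to a database/sql Exec or Query call.
// Hypothetical reimplementation as well.
func int64SliceToInterface(ids []int64) []interface{} {
	args := make([]interface{}, len(ids))
	for i, id := range ids {
		args[i] = id
	}
	return args
}

// buildDeleteQuery (illustrative name) assembles the same
// DELETE ... IN (...) statement shape used in the cleanup function.
// It expects at least one ID, as the caller above guarantees.
func buildDeleteQuery(ids []int64) (string, []interface{}) {
	query := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)`
	return query, int64SliceToInterface(ids)
}

func main() {
	query, args := buildDeleteQuery([]int64{7, 8, 9})
	fmt.Println(query)
	fmt.Println(len(args))
}
```

Only the number of placeholders depends on the input. The ID values travel as bound parameters, which is what the `#nosec G202` annotation relies on.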

// GetConflictsWithDetails retrieves up to limit conflicts involving the given project,
// newest first, together with the titles of both observations for display.
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
	const query = `
		SELECT oc.id, oc.newer_obs_id, oc.older_obs_id, oc.conflict_type, oc.resolution, oc.reason,
		       oc.detected_at, oc.detected_at_epoch, oc.resolved, oc.resolved_at,
		       COALESCE(newer.title, '') as newer_title,
		       COALESCE(older.title, '') as older_title
		FROM observation_conflicts oc
		JOIN observations newer ON newer.id = oc.newer_obs_id
		JOIN observations older ON older.id = oc.older_obs_id
		WHERE newer.project = ? OR older.project = ?
		ORDER BY oc.detected_at_epoch DESC
		LIMIT ?
	`

	rows, err := s.store.QueryContext(ctx, query, project, project, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []*ConflictWithDetails
	for rows.Next() {
		var c models.ObservationConflict
		var cwd ConflictWithDetails
		if err := rows.Scan(
			&c.ID, &c.NewerObsID, &c.OlderObsID,
			&c.ConflictType, &c.Resolution, &c.Reason,
			&c.DetectedAt, &c.DetectedAtEpoch,
			&c.Resolved, &c.ResolvedAt,
			&cwd.NewerObsTitle, &cwd.OlderObsTitle,
		); err != nil {
			return nil, err
		}
		cwd.Conflict = &c
		results = append(results, &cwd)
	}
	return results, rows.Err()
}

// scanConflictRows scans multiple conflicts from rows.
func (s *ConflictStore) scanConflictRows(rows interface {
	Next() bool
	Scan(...interface{}) error
	Err() error
}) ([]*models.ObservationConflict, error) {
	var conflicts []*models.ObservationConflict
	for rows.Next() {
		var c models.ObservationConflict
		if err := rows.Scan(
			&c.ID, &c.NewerObsID, &c.OlderObsID,
			&c.ConflictType, &c.Resolution, &c.Reason,
			&c.DetectedAt, &c.DetectedAtEpoch,
			&c.Resolved, &c.ResolvedAt,
		); err != nil {
			return nil, err
		}
		conflicts = append(conflicts, &c)
	}
	return conflicts, rows.Err()
}
|
||||||
@@ -283,6 +283,213 @@ var Migrations = []Migration{
|
|||||||
ON user_prompts(claude_session_id, prompt_number);
|
ON user_prompts(claude_session_id, prompt_number);
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Version: 19,
|
||||||
|
Name: "vectors_with_model_version",
|
||||||
|
SQL: `
|
||||||
|
-- Drop old vectors table (virtual tables cannot be altered)
|
||||||
|
DROP TABLE IF EXISTS vectors;
|
||||||
|
|
||||||
|
-- Recreate vectors table with model_version column
|
||||||
|
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||||
|
doc_id TEXT PRIMARY KEY,
|
||||||
|
embedding float[384],
|
||||||
|
sqlite_id INTEGER,
|
||||||
|
doc_type TEXT,
|
||||||
|
field_type TEXT,
|
||||||
|
project TEXT,
|
||||||
|
scope TEXT,
|
||||||
|
model_version TEXT
|
||||||
|
);
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: 20,
|
||||||
|
Name: "importance_scoring",
|
||||||
|
SQL: `
|
||||||
|
-- Importance scoring system for observations
|
||||||
|
-- Implements multi-factor scoring: type weight, recency decay, user feedback, concept weights, retrieval boost
|
||||||
|
|
||||||
|
-- Cached importance score (recalculated periodically)
|
||||||
|
ALTER TABLE observations ADD COLUMN importance_score REAL DEFAULT 1.0;
|
||||||
|
|
||||||
|
-- User feedback: -1 = thumbs down, 0 = neutral, 1 = thumbs up
|
||||||
|
ALTER TABLE observations ADD COLUMN user_feedback INTEGER DEFAULT 0;
|
||||||
|
|
||||||
|
-- Retrieval tracking: how many times this observation was returned in searches
|
||||||
|
ALTER TABLE observations ADD COLUMN retrieval_count INTEGER DEFAULT 0;
|
||||||
|
|
||||||
|
-- Last time this observation was retrieved (for analytics)
|
||||||
|
ALTER TABLE observations ADD COLUMN last_retrieved_at_epoch INTEGER;
|
||||||
|
|
||||||
|
-- Timestamp of last score recalculation
|
||||||
|
ALTER TABLE observations ADD COLUMN score_updated_at_epoch INTEGER;
|
||||||
|
|
||||||
|
-- Index for importance-based sorting (primary ordering strategy)
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_observations_importance
|
||||||
|
ON observations(importance_score DESC, created_at_epoch DESC);
|
||||||
|
|
||||||
|
-- Index for finding observations needing score recalculation
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_observations_score_updated
|
||||||
|
ON observations(score_updated_at_epoch);
|
||||||
|
|
||||||
|
-- Configurable concept weights table
|
||||||
|
-- Allows runtime tuning of how much each concept contributes to importance
|
||||||
|
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||||
|
concept TEXT PRIMARY KEY,
|
||||||
|
weight REAL NOT NULL DEFAULT 0.1,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Seed default concept weights (security highest, tooling lowest)
|
||||||
|
INSERT OR IGNORE INTO concept_weights (concept, weight, updated_at) VALUES
|
||||||
|
('security', 0.30, datetime('now')),
|
||||||
|
('gotcha', 0.25, datetime('now')),
|
||||||
|
('best-practice', 0.20, datetime('now')),
|
||||||
|
('anti-pattern', 0.20, datetime('now')),
|
||||||
|
('architecture', 0.15, datetime('now')),
|
||||||
|
('performance', 0.15, datetime('now')),
|
||||||
|
('error-handling', 0.15, datetime('now')),
|
||||||
|
('pattern', 0.10, datetime('now')),
|
||||||
|
('testing', 0.10, datetime('now')),
|
||||||
|
('debugging', 0.10, datetime('now')),
|
||||||
|
('workflow', 0.05, datetime('now')),
|
||||||
|
('tooling', 0.05, datetime('now'));
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: 21,
|
||||||
|
Name: "observation_conflicts",
|
||||||
|
SQL: `
|
||||||
|
-- Observation conflicts table for tracking contradictions and superseded observations
|
||||||
|
-- Implements Issue #5: Contradiction & Obsolescence Detection
|
||||||
|
CREATE TABLE IF NOT EXISTS observation_conflicts (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
newer_obs_id INTEGER NOT NULL,
|
||||||
|
older_obs_id INTEGER NOT NULL,
|
||||||
|
conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
|
||||||
|
resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
|
||||||
|
reason TEXT,
|
||||||
|
detected_at TEXT NOT NULL,
|
||||||
|
detected_at_epoch INTEGER NOT NULL,
|
||||||
|
resolved INTEGER DEFAULT 0,
|
||||||
|
resolved_at TEXT,
|
||||||
|
FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Index for looking up conflicts by observation ID
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conflicts_newer ON observation_conflicts(newer_obs_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conflicts_older ON observation_conflicts(older_obs_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conflicts_unresolved ON observation_conflicts(resolved, detected_at_epoch DESC);
|
||||||
|
|
||||||
|
-- Add is_superseded column to observations for quick filtering
|
||||||
|
-- Set to 1 when this observation has been superseded by a newer one
|
||||||
|
ALTER TABLE observations ADD COLUMN is_superseded INTEGER DEFAULT 0;
|
||||||
|
|
||||||
|
-- Index for filtering out superseded observations in queries
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_observations_superseded ON observations(is_superseded, importance_score DESC);
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: 22,
|
||||||
|
Name: "patterns_table",
|
||||||
|
SQL: `
|
||||||
|
-- Pattern Recognition Engine (Issue #7)
|
||||||
|
-- Tracks recurring patterns detected across observations
|
||||||
|
-- Enables Claude to reference historical insights: "I've encountered this pattern 12 times."
|
||||||
|
CREATE TABLE IF NOT EXISTS patterns (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
|
||||||
|
description TEXT,
|
||||||
|
signature TEXT, -- JSON array of keywords/concepts for detection
|
||||||
|
recommendation TEXT, -- What works for this pattern
|
||||||
|
frequency INTEGER DEFAULT 1, -- How many times encountered
|
||||||
|
projects TEXT, -- JSON array of projects where seen
|
||||||
|
observation_ids TEXT, -- JSON array of source observation IDs
|
||||||
|
status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
|
||||||
|
merged_into_id INTEGER, -- If status is 'merged', which pattern it merged into
|
||||||
|
confidence REAL DEFAULT 0.5, -- Detection confidence (0.0-1.0)
|
||||||
|
last_seen_at TEXT NOT NULL,
|
||||||
|
last_seen_at_epoch INTEGER NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
created_at_epoch INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Indexes for efficient pattern queries
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_type ON patterns(type);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_status ON patterns(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_frequency ON patterns(frequency DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_confidence ON patterns(confidence DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_patterns_last_seen ON patterns(last_seen_at_epoch DESC);
|
||||||
|
|
||||||
|
-- FTS5 virtual table for pattern search
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS patterns_fts USING fts5(
|
||||||
|
name, description, recommendation,
|
||||||
|
content='patterns',
|
||||||
|
content_rowid='id'
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Triggers for FTS5 sync
|
||||||
|
CREATE TRIGGER IF NOT EXISTS patterns_ai AFTER INSERT ON patterns BEGIN
|
||||||
|
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||||
|
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS patterns_ad AFTER DELETE ON patterns BEGIN
|
||||||
|
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||||
|
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS patterns_au AFTER UPDATE ON patterns BEGIN
|
||||||
|
INSERT INTO patterns_fts(patterns_fts, rowid, name, description, recommendation)
|
||||||
|
VALUES('delete', old.id, old.name, old.description, old.recommendation);
|
||||||
|
INSERT INTO patterns_fts(rowid, name, description, recommendation)
|
||||||
|
VALUES (new.id, new.name, new.description, new.recommendation);
|
||||||
|
END;
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Version: 23,
|
||||||
|
Name: "observation_relations",
|
||||||
|
SQL: `
|
||||||
|
-- Knowledge Graph: Observation Relations (Issue #4)
|
||||||
|
-- Tracks explicit relationships between observations for knowledge graph navigation.
|
||||||
|
-- Enables queries like "What caused this bug?" or "What depends on this decision?"
|
||||||
|
CREATE TABLE IF NOT EXISTS observation_relations (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source_id INTEGER NOT NULL,
|
||||||
|
target_id INTEGER NOT NULL,
|
||||||
|
relation_type TEXT NOT NULL CHECK(relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from')),
|
||||||
|
confidence REAL NOT NULL DEFAULT 0.5,
|
||||||
|
detection_source TEXT NOT NULL CHECK(detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression')),
|
||||||
|
reason TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
created_at_epoch INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(source_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY(target_id) REFERENCES observations(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(source_id, target_id, relation_type)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Index for finding relations by source observation
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_relations_source ON observation_relations(source_id);
|
||||||
|
|
||||||
|
-- Index for finding relations by target observation
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_relations_target ON observation_relations(target_id);
|
||||||
|
|
||||||
|
-- Index for relation type queries
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_relations_type ON observation_relations(relation_type);
|
||||||
|
|
||||||
|
-- Index for confidence-based filtering
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_relations_confidence ON observation_relations(confidence DESC);
|
||||||
|
|
||||||
|
-- Composite index for looking up the relations between a specific source/target pair
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_relations_both ON observation_relations(source_id, target_id);
|
||||||
|
`,
|
||||||
|
},
|
||||||
}
|
}
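Migration 20 above names five scoring factors (type weight, recency decay, user feedback, concept weights, retrieval boost) but the diff does not show how they are combined. The sketch below is one hypothetical combination, included only to make the column semantics concrete. The half-life, the feedback and retrieval coefficients, and the function name `importanceScore` are all assumptions; only the concept weights are taken from the seeded rows.

```go
package main

import (
	"fmt"
	"math"
)

// conceptWeights mirrors a few of the rows seeded by migration 20.
var conceptWeights = map[string]float64{
	"security":      0.30,
	"gotcha":        0.25,
	"best-practice": 0.20,
	"tooling":       0.05,
}

// importanceScore is a hypothetical scoring function, not the project's formula.
//   typeWeight: base weight for the observation type (assumed input)
//   ageDays:    age of the observation in days
//   feedback:   user_feedback column (-1, 0, or 1)
//   retrievals: retrieval_count column
//   concepts:   concept tags attached to the observation
func importanceScore(typeWeight, ageDays float64, feedback, retrievals int, concepts []string) float64 {
	const halfLifeDays = 30.0 // assumed half-life for recency decay
	const feedbackStep = 0.2  // assumed effect of one thumbs up/down
	const retrievalStep = 0.05 // assumed scale of the retrieval boost

	// Exponential decay: the base weight halves every halfLifeDays.
	score := typeWeight * math.Pow(0.5, ageDays/halfLifeDays)
	score += feedbackStep * float64(feedback)
	for _, c := range concepts {
		score += conceptWeights[c] // unknown concepts contribute 0
	}
	// Logarithmic boost so heavy retrieval has diminishing returns.
	score += retrievalStep * math.Log1p(float64(retrievals))
	if score < 0 {
		score = 0
	}
	return score
}

func main() {
	fmt.Printf("%.3f\n", importanceScore(1.0, 0, 0, 0, nil))
	fmt.Printf("%.3f\n", importanceScore(1.0, 30, 1, 0, []string{"security"}))
}
```

A cached `importance_score` column plus `score_updated_at_epoch` suits a function like this, because the decay term changes over time and has to be recomputed periodically rather than only at write time.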
|
||||||
|
|
||||||
// MigrationManager handles database schema migrations.
|
// MigrationManager handles database schema migrations.
|
||||||
|
|||||||
@@ -11,14 +11,27 @@ import (
|
|||||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// observationColumns is the standard list of columns to select for observations.
|
||||||
|
// This ensures consistency across all observation queries and includes importance scoring fields.
|
||||||
|
const observationColumns = `id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
|
||||||
|
title, subtitle, facts, narrative, concepts, files_read, files_modified, file_mtimes,
|
||||||
|
prompt_number, discovery_tokens, created_at, created_at_epoch,
|
||||||
|
COALESCE(importance_score, 1.0) as importance_score,
|
||||||
|
COALESCE(user_feedback, 0) as user_feedback,
|
||||||
|
COALESCE(retrieval_count, 0) as retrieval_count,
|
||||||
|
last_retrieved_at_epoch, score_updated_at_epoch,
|
||||||
|
COALESCE(is_superseded, 0) as is_superseded`
|
||||||
|
|
||||||
// CleanupFunc is a callback for when observations are cleaned up.
|
// CleanupFunc is a callback for when observations are cleaned up.
|
||||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||||
|
|
||||||
// ObservationStore provides observation-related database operations.
|
// ObservationStore provides observation-related database operations.
|
||||||
type ObservationStore struct {
|
type ObservationStore struct {
|
||||||
store *Store
|
store *Store
|
||||||
cleanupFunc CleanupFunc
|
cleanupFunc CleanupFunc
|
||||||
|
conflictStore *ConflictStore
|
||||||
|
relationStore *RelationStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewObservationStore creates a new observation store.
|
// NewObservationStore creates a new observation store.
|
||||||
@@ -31,6 +44,16 @@ func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
|
|||||||
s.cleanupFunc = fn
|
s.cleanupFunc = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConflictStore sets the conflict store for conflict detection.
|
||||||
|
func (s *ObservationStore) SetConflictStore(conflictStore *ConflictStore) {
|
||||||
|
s.conflictStore = conflictStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRelationStore sets the relation store for relationship detection.
|
||||||
|
func (s *ObservationStore) SetRelationStore(relationStore *RelationStore) {
|
||||||
|
s.relationStore = relationStore
|
||||||
|
}
|
||||||
|
|
||||||
// StoreObservation stores a new observation.
|
// StoreObservation stores a new observation.
|
||||||
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -86,9 +109,112 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
|||||||
}(project)
|
}(project)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Detect conflicts with existing observations (asynchronously, so the handler is not blocked)
|
||||||
|
if s.conflictStore != nil && project != "" {
|
||||||
|
go func(newObsID int64, proj string, parsedObs *models.ParsedObservation) {
|
||||||
|
conflictCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
s.detectAndStoreConflicts(conflictCtx, newObsID, proj, parsedObs)
|
||||||
|
}(id, project, obs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect relationships with existing observations (asynchronously, so the handler is not blocked)
|
||||||
|
if s.relationStore != nil && project != "" {
|
||||||
|
go func(newObsID int64, proj string) {
|
||||||
|
relationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
s.detectAndStoreRelations(relationCtx, newObsID, proj)
|
||||||
|
}(id, project)
|
||||||
|
}
|
||||||
|
|
||||||
return id, nowEpoch, nil
|
return id, nowEpoch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectAndStoreConflicts detects conflicts between a new observation and existing ones.
|
||||||
|
func (s *ObservationStore) detectAndStoreConflicts(ctx context.Context, newObsID int64, project string, parsedObs *models.ParsedObservation) {
|
||||||
|
// Fetch the newly stored observation
|
||||||
|
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||||
|
if err != nil || newObs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch recent observations from the same project to check for conflicts
|
||||||
|
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect conflicts
|
||||||
|
conflicts := models.DetectConflictsWithExisting(newObs, existing)
|
||||||
|
|
||||||
|
// Store conflicts and mark superseded observations
|
||||||
|
var supersededIDs []int64
|
||||||
|
for _, result := range conflicts {
|
||||||
|
for _, olderID := range result.OlderObsIDs {
|
||||||
|
conflict := models.NewObservationConflict(
|
||||||
|
newObsID, olderID,
|
||||||
|
result.Type, result.Resolution, result.Reason,
|
||||||
|
)
|
||||||
|
if _, err := s.conflictStore.StoreConflict(ctx, conflict); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If resolution is to prefer newer, mark older as superseded
|
||||||
|
if result.Resolution == models.ResolutionPreferNewer {
|
||||||
|
supersededIDs = append(supersededIDs, olderID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark superseded observations
|
||||||
|
if len(supersededIDs) > 0 {
|
||||||
|
_ = s.conflictStore.MarkObservationsSuperseded(ctx, supersededIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up superseded observations past the retention window (SupersededRetentionDays)
|
||||||
|
deletedIDs, _ := s.conflictStore.CleanupSupersededObservations(ctx, project)
|
||||||
|
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||||
|
s.cleanupFunc(ctx, deletedIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinRelationConfidence is the minimum confidence threshold for storing relations.
|
||||||
|
const MinRelationConfidence = 0.4
|
||||||
|
|
||||||
|
// detectAndStoreRelations detects relationships between a new observation and existing ones.
|
||||||
|
func (s *ObservationStore) detectAndStoreRelations(ctx context.Context, newObsID int64, project string) {
|
||||||
|
// Fetch the newly stored observation
|
||||||
|
newObs, err := s.GetObservationByID(ctx, newObsID)
|
||||||
|
if err != nil || newObs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch recent observations from the same project to check for relations
|
||||||
|
existing, err := s.GetRecentObservations(ctx, project, 50)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect relationships using the models package detection logic
|
||||||
|
results := models.DetectRelationsWithExisting(newObs, existing, MinRelationConfidence)
|
||||||
|
if len(results) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert detection results to relation objects
|
||||||
|
relations := make([]*models.ObservationRelation, len(results))
|
||||||
|
for i, r := range results {
|
||||||
|
relations[i] = models.NewObservationRelation(
|
||||||
|
r.SourceID, r.TargetID,
|
||||||
|
r.RelationType, r.Confidence,
|
||||||
|
r.DetectionSource, r.Reason,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store all relations
|
||||||
|
_ = s.relationStore.StoreRelations(ctx, relations)
|
||||||
|
}
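`detectAndStoreRelations` passes `MinRelationConfidence` to `models.DetectRelationsWithExisting`, whose body is not part of this diff. The sketch below shows a plain threshold filter of the kind that parameter suggests. The `relationCandidate` type, the `filterByConfidence` function, and the inclusive `>=` comparison are assumptions for illustration.

```go
package main

import "fmt"

// MinRelationConfidence matches the constant declared above.
const MinRelationConfidence = 0.4

// relationCandidate is a hypothetical stand-in for a detection result.
type relationCandidate struct {
	SourceID     int64
	TargetID     int64
	RelationType string
	Confidence   float64
}

// filterByConfidence keeps candidates whose confidence is at least min.
// Whether the real detector treats the threshold as inclusive is an assumption.
func filterByConfidence(candidates []relationCandidate, min float64) []relationCandidate {
	kept := make([]relationCandidate, 0, len(candidates))
	for _, c := range candidates {
		if c.Confidence >= min {
			kept = append(kept, c)
		}
	}
	return kept
}

func main() {
	candidates := []relationCandidate{
		{SourceID: 10, TargetID: 3, RelationType: "fixes", Confidence: 0.8},
		{SourceID: 10, TargetID: 4, RelationType: "relates_to", Confidence: 0.2},
		{SourceID: 10, TargetID: 5, RelationType: "depends_on", Confidence: 0.4},
	}
	for _, c := range filterByConfidence(candidates, MinRelationConfidence) {
		fmt.Println(c.TargetID, c.RelationType)
	}
}
```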
|
||||||
|
|
||||||
// ensureSessionExists creates a session if it doesn't exist.
|
// ensureSessionExists creates a session if it doesn't exist.
|
||||||
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
|
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
|
||||||
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
|
return EnsureSessionExists(ctx, s.store, sdkSessionID, project)
|
||||||
@@ -96,13 +222,7 @@ func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID
|
|||||||
|
|
||||||
// GetObservationByID retrieves an observation by ID.
|
// GetObservationByID retrieves an observation by ID.
|
||||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||||
const query = `
|
query := `SELECT ` + observationColumns + ` FROM observations WHERE id = ?`
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
|
||||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
|
||||||
created_at, created_at_epoch
|
|
||||||
FROM observations
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
|
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@@ -112,6 +232,7 @@ func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetObservationsByIDs retrieves observations by a list of IDs.
|
// GetObservationsByIDs retrieves observations by a list of IDs.
|
||||||
|
// Results are ordered by importance_score DESC by default, with created_at_epoch as secondary sort.
|
||||||
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -119,18 +240,22 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
|||||||
|
|
||||||
// Build query with placeholders
|
// Build query with placeholders
|
||||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||||
query := `
|
query := `SELECT ` + observationColumns + `
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
|
||||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
|
||||||
created_at, created_at_epoch
|
|
||||||
FROM observations
|
FROM observations
|
||||||
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||||
ORDER BY created_at_epoch `
|
ORDER BY `
|
||||||
|
|
||||||
if orderBy == "date_asc" {
|
// Default to importance-based ordering
|
||||||
query += "ASC"
|
switch orderBy {
|
||||||
} else {
|
case "date_asc":
|
||||||
query += "DESC"
|
query += "created_at_epoch ASC"
|
||||||
|
case "date_desc":
|
||||||
|
query += "created_at_epoch DESC"
|
||||||
|
case "importance":
|
||||||
|
query += "importance_score DESC, created_at_epoch DESC"
|
||||||
|
default:
|
||||||
|
// Default: importance first, then recency
|
||||||
|
query += "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
|
||||||
}
|
}
|
||||||
|
|
||||||
if limit > 0 {
|
if limit > 0 {
|
||||||
@@ -154,14 +279,56 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
|||||||
|
|
||||||
// GetRecentObservations retrieves recent observations for a project.
|
// GetRecentObservations retrieves recent observations for a project.
|
||||||
// This includes project-scoped observations for the specified project AND global observations.
|
// This includes project-scoped observations for the specified project AND global observations.
|
||||||
|
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||||
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
const query = `
|
query := `SELECT ` + observationColumns + `
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
|
||||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
|
||||||
created_at, created_at_epoch
|
|
||||||
FROM observations
|
FROM observations
|
||||||
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
|
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
|
||||||
OR scope = 'global'
|
OR scope = 'global'
|
||||||
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||||
|
// This excludes observations that have been marked as superseded by newer ones.
|
||||||
|
// Use this for context injection where you want to avoid outdated/contradicted advice.
|
||||||
|
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||||
|
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
|
query := `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE ((project = ? AND (scope IS NULL OR scope = 'project'))
|
||||||
|
OR scope = 'global')
|
||||||
|
AND COALESCE(is_superseded, 0) = 0
|
||||||
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
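The query in `GetActiveObservations` drops superseded rows and then orders by importance, with recency as the tie-breaker. The in-memory sketch below mirrors that selection logic to make the ordering explicit. The `observation` struct and `activeObservations` function are hypothetical and only model the three columns involved; project and scope filtering are left out.

```go
package main

import (
	"fmt"
	"sort"
)

// observation models only the columns relevant to ordering (hypothetical type).
type observation struct {
	ID              int64
	ImportanceScore float64
	CreatedAtEpoch  int64
	IsSuperseded    bool
}

// activeObservations filters out superseded entries, then sorts by
// importance DESC and creation time DESC. In this sketch a limit <= 0
// means "no cap", which is a simplification of SQL LIMIT semantics.
func activeObservations(all []observation, limit int) []observation {
	active := make([]observation, 0, len(all))
	for _, o := range all {
		if !o.IsSuperseded {
			active = append(active, o)
		}
	}
	sort.SliceStable(active, func(i, j int) bool {
		if active[i].ImportanceScore != active[j].ImportanceScore {
			return active[i].ImportanceScore > active[j].ImportanceScore
		}
		return active[i].CreatedAtEpoch > active[j].CreatedAtEpoch
	})
	if limit > 0 && len(active) > limit {
		active = active[:limit]
	}
	return active
}

func main() {
	observations := []observation{
		{ID: 1, ImportanceScore: 1.0, CreatedAtEpoch: 100},
		{ID: 2, ImportanceScore: 1.5, CreatedAtEpoch: 50, IsSuperseded: true},
		{ID: 3, ImportanceScore: 1.0, CreatedAtEpoch: 200},
		{ID: 4, ImportanceScore: 1.2, CreatedAtEpoch: 10},
	}
	for _, o := range activeObservations(observations, 2) {
		fmt.Println(o.ID)
	}
}
```

Note that the highest-scored entry (ID 2) is excluded because it is superseded, which is the point of using this query for context injection.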
|
||||||
|
|
||||||
|
// GetSupersededObservations retrieves observations that have been superseded by newer ones.
|
||||||
|
// Use this for verification/debugging to see which observations were marked as outdated.
|
||||||
|
// Results are ordered by created_at_epoch DESC.
|
||||||
|
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
|
query := `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE project = ?
|
||||||
|
AND COALESCE(is_superseded, 0) = 1
|
||||||
ORDER BY created_at_epoch DESC
|
ORDER BY created_at_epoch DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
@@ -178,14 +345,12 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
|
|||||||
// GetObservationsByProjectStrict retrieves observations strictly for a specific project.
|
// GetObservationsByProjectStrict retrieves observations strictly for a specific project.
|
||||||
// Unlike GetRecentObservations, this does NOT include global observations from other projects.
|
// Unlike GetRecentObservations, this does NOT include global observations from other projects.
|
||||||
// Use this for dashboard filtering where the user expects to see only that project's data.
|
// Use this for dashboard filtering where the user expects to see only that project's data.
|
||||||
|
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||||
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
const query = `
|
query := `SELECT ` + observationColumns + `
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
|
||||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
|
||||||
created_at, created_at_epoch
|
|
||||||
FROM observations
|
FROM observations
|
||||||
WHERE project = ?
|
WHERE project = ?
|
||||||
ORDER BY created_at_epoch DESC
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -210,13 +375,11 @@ func (s *ObservationStore) GetObservationCount(ctx context.Context, project stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRecentObservations retrieves recent observations across all projects.
|
// GetAllRecentObservations retrieves recent observations across all projects.
|
||||||
|
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||||
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
|
||||||
const query = `
|
query := `SELECT ` + observationColumns + `
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
|
||||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
|
||||||
created_at, created_at_epoch
|
|
||||||
FROM observations
|
FROM observations
|
||||||
ORDER BY created_at_epoch DESC
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -229,7 +392,24 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
|
|||||||
return scanObservationRows(rows)
|
return scanObservationRows(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||||
|
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||||
|
query := `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
ORDER BY id
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
// SearchObservationsFTS performs full-text search on observations.
|
// SearchObservationsFTS performs full-text search on observations.
|
||||||
|
// Results are ordered by FTS rank (relevance), then by importance_score.
|
||||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 10
|
limit = 10
|
||||||
@@ -245,15 +425,21 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
|||||||
ftsTerms := strings.Join(keywords, " OR ")
|
ftsTerms := strings.Join(keywords, " OR ")
|
||||||
|
|
||||||
// Use FTS5 to search title, subtitle, and narrative
|
// Use FTS5 to search title, subtitle, and narrative
|
||||||
const ftsQuery = `
|
// Include importance scoring columns and order by rank then importance
|
||||||
|
ftsQuery := `
|
||||||
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
|
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
|
||||||
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
|
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
|
||||||
o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch
|
o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
|
||||||
|
COALESCE(o.importance_score, 1.0) as importance_score,
|
||||||
|
COALESCE(o.user_feedback, 0) as user_feedback,
|
||||||
|
COALESCE(o.retrieval_count, 0) as retrieval_count,
|
||||||
|
o.last_retrieved_at_epoch, o.score_updated_at_epoch,
|
||||||
|
COALESCE(o.is_superseded, 0) as is_superseded
|
||||||
FROM observations o
|
FROM observations o
|
||||||
JOIN observations_fts fts ON o.id = fts.rowid
|
JOIN observations_fts fts ON o.id = fts.rowid
|
||||||
WHERE observations_fts MATCH ?
|
WHERE observations_fts MATCH ?
|
||||||
AND (o.project = ? OR o.scope = 'global')
|
AND (o.project = ? OR o.scope = 'global')
|
||||||
ORDER BY rank
|
ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -278,6 +464,7 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// searchObservationsLike performs fallback LIKE search on observations.
|
// searchObservationsLike performs fallback LIKE search on observations.
|
||||||
|
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||||
if len(keywords) == 0 {
|
if len(keywords) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -294,14 +481,11 @@ func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords
|
|||||||
}
|
}
|
||||||
|
|
||||||
// #nosec G202 -- query uses parameterized placeholders, not user input
|
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||||
query := `
|
query := `SELECT ` + observationColumns + `
|
||||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
|
|
||||||
title, subtitle, facts, narrative, concepts, files_read, files_modified,
|
|
||||||
file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
|
||||||
FROM observations
|
FROM observations
|
||||||
WHERE (` + strings.Join(conditions, " OR ") + `)
|
WHERE (` + strings.Join(conditions, " OR ") + `)
|
||||||
AND (project = ? OR scope = 'global')
|
AND (project = ? OR scope = 'global')
|
||||||
ORDER BY created_at_epoch DESC
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
`
|
`
|
||||||
args = append(args, project, limit)
|
args = append(args, project, limit)
|
||||||
@@ -445,6 +629,11 @@ func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.O
|
|||||||
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
|
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
|
||||||
&obs.PromptNumber, &obs.DiscoveryTokens,
|
&obs.PromptNumber, &obs.DiscoveryTokens,
|
||||||
&obs.CreatedAt, &obs.CreatedAtEpoch,
|
&obs.CreatedAt, &obs.CreatedAtEpoch,
|
||||||
|
// Importance scoring fields
|
||||||
|
&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
|
||||||
|
&obs.LastRetrievedAt, &obs.ScoreUpdatedAt,
|
||||||
|
// Conflict detection fields
|
||||||
|
&obs.IsSuperseded,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,370 @@
|
|||||||
|
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// patternColumns is the standard list of columns to select for patterns.
|
||||||
|
const patternColumns = `id, name, type, description, signature, recommendation,
|
||||||
|
frequency, projects, observation_ids, status, merged_into_id, confidence,
|
||||||
|
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch`
|
||||||
|
|
||||||
|
// PatternCleanupFunc is a callback for when patterns are deleted.
|
||||||
|
type PatternCleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||||
|
|
||||||
|
// PatternStore provides pattern-related database operations.
|
||||||
|
type PatternStore struct {
|
||||||
|
store *Store
|
||||||
|
cleanupFunc PatternCleanupFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPatternStore creates a new pattern store.
|
||||||
|
func NewPatternStore(store *Store) *PatternStore {
|
||||||
|
return &PatternStore{store: store}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCleanupFunc sets the callback for when patterns are deleted.
|
||||||
|
func (s *PatternStore) SetCleanupFunc(fn PatternCleanupFunc) {
|
||||||
|
s.cleanupFunc = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorePattern stores a new pattern.
|
||||||
|
func (s *PatternStore) StorePattern(ctx context.Context, pattern *models.Pattern) (int64, error) {
|
||||||
|
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||||
|
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||||
|
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||||
|
|
||||||
|
const query = `
|
||||||
|
INSERT INTO patterns
|
||||||
|
(name, type, description, signature, recommendation, frequency, projects,
|
||||||
|
observation_ids, status, merged_into_id, confidence,
|
||||||
|
last_seen_at, last_seen_at_epoch, created_at, created_at_epoch)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err := s.store.ExecContext(ctx, query,
|
||||||
|
pattern.Name, string(pattern.Type),
|
||||||
|
nullString(pattern.Description.String), string(signatureJSON),
|
||||||
|
nullString(pattern.Recommendation.String),
|
||||||
|
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||||
|
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||||
|
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||||
|
pattern.CreatedAt, pattern.CreatedAtEpoch,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePattern updates an existing pattern.
|
||||||
|
func (s *PatternStore) UpdatePattern(ctx context.Context, pattern *models.Pattern) error {
|
||||||
|
signatureJSON, _ := json.Marshal(pattern.Signature)
|
||||||
|
projectsJSON, _ := json.Marshal(pattern.Projects)
|
||||||
|
obsIDsJSON, _ := json.Marshal(pattern.ObservationIDs)
|
||||||
|
|
||||||
|
const query = `
|
||||||
|
UPDATE patterns SET
|
||||||
|
name = ?, type = ?, description = ?, signature = ?, recommendation = ?,
|
||||||
|
frequency = ?, projects = ?, observation_ids = ?, status = ?,
|
||||||
|
merged_into_id = ?, confidence = ?, last_seen_at = ?, last_seen_at_epoch = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := s.store.ExecContext(ctx, query,
|
||||||
|
pattern.Name, string(pattern.Type),
|
||||||
|
nullString(pattern.Description.String), string(signatureJSON),
|
||||||
|
nullString(pattern.Recommendation.String),
|
||||||
|
pattern.Frequency, string(projectsJSON), string(obsIDsJSON),
|
||||||
|
string(pattern.Status), nullInt64(pattern.MergedIntoID),
|
||||||
|
pattern.Confidence, pattern.LastSeenAt, pattern.LastSeenEpoch,
|
||||||
|
pattern.ID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternByID retrieves a pattern by ID; it returns sql.ErrNoRows if no pattern has that ID.
|
||||||
|
func (s *PatternStore) GetPatternByID(ctx context.Context, id int64) (*models.Pattern, error) {
|
||||||
|
query := `SELECT ` + patternColumns + ` FROM patterns WHERE id = ?`
|
||||||
|
|
||||||
|
row := s.store.QueryRowContext(ctx, query, id)
|
||||||
|
return scanPattern(row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternByName retrieves an active pattern by name; it returns (nil, nil) if none is found.
|
||||||
|
func (s *PatternStore) GetPatternByName(ctx context.Context, name string) (*models.Pattern, error) {
|
||||||
|
query := `SELECT ` + patternColumns + ` FROM patterns WHERE name = ? AND status = 'active'`
|
||||||
|
|
||||||
|
row := s.store.QueryRowContext(ctx, query, name)
|
||||||
|
pattern, err := scanPattern(row)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return pattern, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActivePatterns retrieves up to limit active patterns, ordered by frequency DESC, then confidence DESC.
|
||||||
|
func (s *PatternStore) GetActivePatterns(ctx context.Context, limit int) ([]*models.Pattern, error) {
|
||||||
|
query := `SELECT ` + patternColumns + `
|
||||||
|
FROM patterns
|
||||||
|
WHERE status = 'active'
|
||||||
|
ORDER BY frequency DESC, confidence DESC
|
||||||
|
LIMIT ?`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanPatternRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternsByType retrieves patterns of a specific type.
|
||||||
|
func (s *PatternStore) GetPatternsByType(ctx context.Context, patternType models.PatternType, limit int) ([]*models.Pattern, error) {
|
||||||
|
query := `SELECT ` + patternColumns + `
|
||||||
|
FROM patterns
|
||||||
|
WHERE type = ? AND status = 'active'
|
||||||
|
ORDER BY frequency DESC, confidence DESC
|
||||||
|
LIMIT ?`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, string(patternType), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanPatternRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternsByProject retrieves patterns that have been observed in a specific project.
|
||||||
|
func (s *PatternStore) GetPatternsByProject(ctx context.Context, project string, limit int) ([]*models.Pattern, error) {
|
||||||
|
// Use json_each to search within the projects JSON array
|
||||||
|
query := `SELECT ` + patternColumns + `
|
||||||
|
FROM patterns
|
||||||
|
WHERE status = 'active'
|
||||||
|
AND EXISTS (
|
||||||
|
SELECT 1 FROM json_each(projects)
|
||||||
|
WHERE json_each.value = ?
|
||||||
|
)
|
||||||
|
ORDER BY frequency DESC, confidence DESC
|
||||||
|
LIMIT ?`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, project, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanPatternRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindMatchingPatterns searches for patterns that match a given signature.
|
||||||
|
func (s *PatternStore) FindMatchingPatterns(ctx context.Context, signature []string, minScore float64) ([]*models.Pattern, error) {
|
||||||
|
// Get up to 100 active patterns (highest frequency first) and filter by signature match in Go
|
||||||
|
// This is simpler than complex SQL for JSON array matching
|
||||||
|
patterns, err := s.GetActivePatterns(ctx, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var matches []*models.Pattern
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
score := models.CalculateMatchScore(signature, pattern.Signature)
|
||||||
|
if score >= minScore {
|
||||||
|
matches = append(matches, pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return matches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkPatternDeprecated marks a pattern as deprecated.
|
||||||
|
func (s *PatternStore) MarkPatternDeprecated(ctx context.Context, id int64) error {
|
||||||
|
const query = `UPDATE patterns SET status = 'deprecated' WHERE id = ?`
|
||||||
|
_, err := s.store.ExecContext(ctx, query, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergePatterns merges a source pattern into a target pattern.
|
||||||
|
func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int64) error {
|
||||||
|
// Get both patterns
|
||||||
|
source, err := s.GetPatternByID(ctx, sourceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
target, err := s.GetPatternByID(ctx, targetID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge source into target
|
||||||
|
target.Frequency += source.Frequency
|
||||||
|
for _, proj := range source.Projects {
|
||||||
|
found := false
|
||||||
|
for _, existing := range target.Projects {
|
||||||
|
if existing == proj {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
target.Projects = append(target.Projects, proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, obsID := range source.ObservationIDs {
|
||||||
|
found := false
|
||||||
|
for _, existing := range target.ObservationIDs {
|
||||||
|
if existing == obsID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
target.ObservationIDs = append(target.ObservationIDs, obsID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update target
|
||||||
|
if err := s.UpdatePattern(ctx, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark source as merged
|
||||||
|
source.Status = models.PatternStatusMerged
|
||||||
|
source.MergedIntoID = sql.NullInt64{Int64: targetID, Valid: true}
|
||||||
|
return s.UpdatePattern(ctx, source)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePattern deletes a pattern by ID.
|
||||||
|
func (s *PatternStore) DeletePattern(ctx context.Context, id int64) error {
|
||||||
|
const query = `DELETE FROM patterns WHERE id = ?`
|
||||||
|
_, err := s.store.ExecContext(ctx, query, id)
|
||||||
|
if err == nil && s.cleanupFunc != nil {
|
||||||
|
s.cleanupFunc(ctx, []int64{id})
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchPatternsFTS performs full-text search on patterns.
|
||||||
|
func (s *PatternStore) SearchPatternsFTS(ctx context.Context, searchQuery string, limit int) ([]*models.Pattern, error) {
|
||||||
|
query := `SELECT p.` + patternColumns + `
|
||||||
|
FROM patterns p
|
||||||
|
JOIN patterns_fts fts ON p.id = fts.rowid
|
||||||
|
WHERE patterns_fts MATCH ?
|
||||||
|
AND p.status = 'active'
|
||||||
|
ORDER BY rank
|
||||||
|
LIMIT ?`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, searchQuery, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanPatternRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternStats returns statistics about patterns.
|
||||||
|
func (s *PatternStore) GetPatternStats(ctx context.Context) (*PatternStats, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
COUNT(CASE WHEN status = 'active' THEN 1 END) as active,
|
||||||
|
COUNT(CASE WHEN status = 'deprecated' THEN 1 END) as deprecated,
|
||||||
|
COUNT(CASE WHEN status = 'merged' THEN 1 END) as merged,
|
||||||
|
COALESCE(SUM(frequency), 0) as total_occurrences,
|
||||||
|
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||||
|
COUNT(CASE WHEN type = 'bug' THEN 1 END) as bugs,
|
||||||
|
COUNT(CASE WHEN type = 'refactor' THEN 1 END) as refactors,
|
||||||
|
COUNT(CASE WHEN type = 'architecture' THEN 1 END) as architectures,
|
||||||
|
COUNT(CASE WHEN type = 'anti-pattern' THEN 1 END) as anti_patterns,
|
||||||
|
COUNT(CASE WHEN type = 'best-practice' THEN 1 END) as best_practices
|
||||||
|
FROM patterns
|
||||||
|
`
|
||||||
|
|
||||||
|
var stats PatternStats
|
||||||
|
err := s.store.QueryRowContext(ctx, query).Scan(
|
||||||
|
&stats.Total, &stats.Active, &stats.Deprecated, &stats.Merged,
|
||||||
|
&stats.TotalOccurrences, &stats.AvgConfidence,
|
||||||
|
&stats.Bugs, &stats.Refactors, &stats.Architectures,
|
||||||
|
&stats.AntiPatterns, &stats.BestPractices,
|
||||||
|
)
|
||||||
|
return &stats, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatternStats contains aggregate statistics about patterns.
|
||||||
|
type PatternStats struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
Active int `json:"active"`
|
||||||
|
Deprecated int `json:"deprecated"`
|
||||||
|
Merged int `json:"merged"`
|
||||||
|
TotalOccurrences int `json:"total_occurrences"`
|
||||||
|
AvgConfidence float64 `json:"avg_confidence"`
|
||||||
|
Bugs int `json:"bugs"`
|
||||||
|
Refactors int `json:"refactors"`
|
||||||
|
Architectures int `json:"architectures"`
|
||||||
|
AntiPatterns int `json:"anti_patterns"`
|
||||||
|
BestPractices int `json:"best_practices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanPattern scans a single pattern from a row scanner.
|
||||||
|
func scanPattern(scanner interface{ Scan(...interface{}) error }) (*models.Pattern, error) {
|
||||||
|
var pattern models.Pattern
|
||||||
|
if err := scanner.Scan(
|
||||||
|
&pattern.ID, &pattern.Name, &pattern.Type,
|
||||||
|
&pattern.Description, &pattern.Signature, &pattern.Recommendation,
|
||||||
|
&pattern.Frequency, &pattern.Projects, &pattern.ObservationIDs,
|
||||||
|
&pattern.Status, &pattern.MergedIntoID, &pattern.Confidence,
|
||||||
|
&pattern.LastSeenAt, &pattern.LastSeenEpoch,
|
||||||
|
&pattern.CreatedAt, &pattern.CreatedAtEpoch,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &pattern, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanPatternRows scans multiple patterns from rows.
|
||||||
|
func scanPatternRows(rows *sql.Rows) ([]*models.Pattern, error) {
|
||||||
|
var patterns []*models.Pattern
|
||||||
|
for rows.Next() {
|
||||||
|
pattern, err := scanPattern(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
patterns = append(patterns, pattern)
|
||||||
|
}
|
||||||
|
return patterns, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullInt64 converts sql.NullInt64 to the value needed for database insertion.
|
||||||
|
func nullInt64(n sql.NullInt64) interface{} {
|
||||||
|
if n.Valid {
|
||||||
|
return n.Int64
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementPatternFrequency records a new occurrence of a pattern and updates last_seen via a read-modify-write (not atomic).
|
||||||
|
func (s *PatternStore) IncrementPatternFrequency(ctx context.Context, id int64, project string, observationID int64) error {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Get current pattern
|
||||||
|
pattern, err := s.GetPatternByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add occurrence
|
||||||
|
pattern.AddOccurrence(project, observationID)
|
||||||
|
pattern.LastSeenAt = now.Format(time.RFC3339)
|
||||||
|
pattern.LastSeenEpoch = now.UnixMilli()
|
||||||
|
|
||||||
|
return s.UpdatePattern(ctx, pattern)
|
||||||
|
}
|
||||||
@@ -0,0 +1,507 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupPatternTestStore creates a test store with the patterns table.
|
||||||
|
func setupPatternTestStore(t *testing.T) *Store {
|
||||||
|
t.Helper()
|
||||||
|
db, _, cleanup := testDB(t)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
createBaseTables(t, db)
|
||||||
|
return newStoreFromDB(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_StoreAndGet(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test pattern
|
||||||
|
pattern := &models.Pattern{
|
||||||
|
Name: "Test Pattern",
|
||||||
|
Type: models.PatternTypeBug,
|
||||||
|
Description: sql.NullString{String: "A test pattern", Valid: true},
|
||||||
|
Signature: []string{"nil", "error"},
|
||||||
|
Recommendation: sql.NullString{String: "Always check for nil", Valid: true},
|
||||||
|
Frequency: 1,
|
||||||
|
Projects: []string{"project1"},
|
||||||
|
ObservationIDs: []int64{1, 2},
|
||||||
|
Status: models.PatternStatusActive,
|
||||||
|
Confidence: 0.5,
|
||||||
|
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
CreatedAt: time.Now().Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store pattern
|
||||||
|
id, err := patternStore.StorePattern(ctx, pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StorePattern() error = %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Errorf("Expected positive ID, got %d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get pattern by ID
|
||||||
|
retrieved, err := patternStore.GetPatternByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternByID() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.Name != pattern.Name {
|
||||||
|
t.Errorf("Expected name %s, got %s", pattern.Name, retrieved.Name)
|
||||||
|
}
|
||||||
|
if retrieved.Type != pattern.Type {
|
||||||
|
t.Errorf("Expected type %s, got %s", pattern.Type, retrieved.Type)
|
||||||
|
}
|
||||||
|
if len(retrieved.Signature) != len(pattern.Signature) {
|
||||||
|
t.Errorf("Expected %d signature elements, got %d",
|
||||||
|
len(pattern.Signature), len(retrieved.Signature))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_GetByName(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
pattern := createTestPattern("Unique Name Pattern")
|
||||||
|
_, err := patternStore.StorePattern(ctx, pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StorePattern() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get by name
|
||||||
|
retrieved, err := patternStore.GetPatternByName(ctx, "Unique Name Pattern")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternByName() error = %v", err)
|
||||||
|
}
|
||||||
|
if retrieved == nil {
|
||||||
|
t.Fatal("Expected pattern, got nil")
|
||||||
|
}
|
||||||
|
if retrieved.Name != "Unique Name Pattern" {
|
||||||
|
t.Errorf("Expected name 'Unique Name Pattern', got '%s'", retrieved.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get non-existent pattern
|
||||||
|
nonExistent, err := patternStore.GetPatternByName(ctx, "Non Existent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternByName() error = %v", err)
|
||||||
|
}
|
||||||
|
if nonExistent != nil {
|
||||||
|
t.Errorf("Expected nil for non-existent pattern, got %v", nonExistent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_GetActivePatterns(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create multiple patterns with different statuses
|
||||||
|
active1 := createTestPattern("Active 1")
|
||||||
|
active1.Frequency = 5
|
||||||
|
active2 := createTestPattern("Active 2")
|
||||||
|
active2.Frequency = 3
|
||||||
|
deprecated := createTestPattern("Deprecated")
|
||||||
|
deprecated.Status = models.PatternStatusDeprecated
|
||||||
|
|
||||||
|
patternStore.StorePattern(ctx, active1)
|
||||||
|
patternStore.StorePattern(ctx, active2)
|
||||||
|
patternStore.StorePattern(ctx, deprecated)
|
||||||
|
|
||||||
|
// Get active patterns
|
||||||
|
patterns, err := patternStore.GetActivePatterns(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetActivePatterns() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(patterns) != 2 {
|
||||||
|
t.Errorf("Expected 2 active patterns, got %d", len(patterns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check order (should be by frequency descending)
|
||||||
|
if len(patterns) >= 2 {
|
||||||
|
if patterns[0].Frequency < patterns[1].Frequency {
|
||||||
|
t.Errorf("Patterns not ordered by frequency descending")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_GetPatternsByType(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create patterns of different types
|
||||||
|
bugPattern := createTestPattern("Bug Pattern")
|
||||||
|
bugPattern.Type = models.PatternTypeBug
|
||||||
|
|
||||||
|
refactorPattern := createTestPattern("Refactor Pattern")
|
||||||
|
refactorPattern.Type = models.PatternTypeRefactor
|
||||||
|
|
||||||
|
patternStore.StorePattern(ctx, bugPattern)
|
||||||
|
patternStore.StorePattern(ctx, refactorPattern)
|
||||||
|
|
||||||
|
// Get by type
|
||||||
|
bugs, err := patternStore.GetPatternsByType(ctx, models.PatternTypeBug, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(bugs) != 1 {
|
||||||
|
t.Errorf("Expected 1 bug pattern, got %d", len(bugs))
|
||||||
|
}
|
||||||
|
|
||||||
|
refactors, err := patternStore.GetPatternsByType(ctx, models.PatternTypeRefactor, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternsByType() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(refactors) != 1 {
|
||||||
|
t.Errorf("Expected 1 refactor pattern, got %d", len(refactors))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_GetPatternsByProject(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create patterns with different projects
|
||||||
|
pattern1 := createTestPattern("Pattern 1")
|
||||||
|
pattern1.Projects = []string{"project-a", "project-b"}
|
||||||
|
|
||||||
|
pattern2 := createTestPattern("Pattern 2")
|
||||||
|
pattern2.Projects = []string{"project-b", "project-c"}
|
||||||
|
|
||||||
|
patternStore.StorePattern(ctx, pattern1)
|
||||||
|
patternStore.StorePattern(ctx, pattern2)
|
||||||
|
|
||||||
|
// Get by project
|
||||||
|
projectA, err := patternStore.GetPatternsByProject(ctx, "project-a", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(projectA) != 1 {
|
||||||
|
t.Errorf("Expected 1 pattern for project-a, got %d", len(projectA))
|
||||||
|
}
|
||||||
|
|
||||||
|
projectB, err := patternStore.GetPatternsByProject(ctx, "project-b", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternsByProject() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(projectB) != 2 {
|
||||||
|
t.Errorf("Expected 2 patterns for project-b, got %d", len(projectB))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_UpdatePattern(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create and store pattern
|
||||||
|
pattern := createTestPattern("Original Name")
|
||||||
|
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||||
|
|
||||||
|
// Update pattern
|
||||||
|
pattern.ID = id
|
||||||
|
pattern.Name = "Updated Name"
|
||||||
|
pattern.Frequency = 10
|
||||||
|
pattern.Confidence = 0.9
|
||||||
|
|
||||||
|
err := patternStore.UpdatePattern(ctx, pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdatePattern() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify update
|
||||||
|
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||||
|
if updated.Name != "Updated Name" {
|
||||||
|
t.Errorf("Expected name 'Updated Name', got '%s'", updated.Name)
|
||||||
|
}
|
||||||
|
if updated.Frequency != 10 {
|
||||||
|
t.Errorf("Expected frequency 10, got %d", updated.Frequency)
|
||||||
|
}
|
||||||
|
if updated.Confidence != 0.9 {
|
||||||
|
t.Errorf("Expected confidence 0.9, got %f", updated.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_DeletePattern(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create and store pattern
|
||||||
|
pattern := createTestPattern("To Delete")
|
||||||
|
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||||
|
|
||||||
|
// Delete pattern
|
||||||
|
err := patternStore.DeletePattern(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeletePattern() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deletion
|
||||||
|
deleted, err := patternStore.GetPatternByID(ctx, id)
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
t.Errorf("Expected ErrNoRows, got %v", err)
|
||||||
|
}
|
||||||
|
if deleted != nil {
|
||||||
|
t.Errorf("Expected nil for deleted pattern")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_MarkPatternDeprecated(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create and store pattern
|
||||||
|
pattern := createTestPattern("To Deprecate")
|
||||||
|
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||||
|
|
||||||
|
// Mark as deprecated
|
||||||
|
err := patternStore.MarkPatternDeprecated(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarkPatternDeprecated() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify status
|
||||||
|
	deprecated, err := patternStore.GetPatternByID(ctx, id)
	if err != nil {
		t.Fatalf("GetPatternByID() error = %v", err)
	}
|
||||||
|
if deprecated.Status != models.PatternStatusDeprecated {
|
||||||
|
t.Errorf("Expected status 'deprecated', got '%s'", deprecated.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_MergePatterns(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create source and target patterns
|
||||||
|
source := createTestPattern("Source Pattern")
|
||||||
|
source.Frequency = 3
|
||||||
|
source.Projects = []string{"proj1", "proj2"}
|
||||||
|
source.ObservationIDs = []int64{1, 2, 3}
|
||||||
|
|
||||||
|
target := createTestPattern("Target Pattern")
|
||||||
|
target.Frequency = 2
|
||||||
|
target.Projects = []string{"proj2", "proj3"}
|
||||||
|
target.ObservationIDs = []int64{4, 5}
|
||||||
|
|
||||||
|
sourceID, _ := patternStore.StorePattern(ctx, source)
|
||||||
|
targetID, _ := patternStore.StorePattern(ctx, target)
|
||||||
|
|
||||||
|
// Merge
|
||||||
|
err := patternStore.MergePatterns(ctx, sourceID, targetID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MergePatterns() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify source is marked as merged
|
||||||
|
	mergedSource, err := patternStore.GetPatternByID(ctx, sourceID)
	if err != nil {
		t.Fatalf("GetPatternByID(source) error = %v", err)
	}
|
||||||
|
if mergedSource.Status != models.PatternStatusMerged {
|
||||||
|
t.Errorf("Expected source status 'merged', got '%s'", mergedSource.Status)
|
||||||
|
}
|
||||||
|
if !mergedSource.MergedIntoID.Valid || mergedSource.MergedIntoID.Int64 != targetID {
|
||||||
|
t.Errorf("Expected source merged_into_id to be %d", targetID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify target has combined data
|
||||||
|
	mergedTarget, err := patternStore.GetPatternByID(ctx, targetID)
	if err != nil {
		t.Fatalf("GetPatternByID(target) error = %v", err)
	}
|
||||||
|
expectedFrequency := 5 // 3 + 2
|
||||||
|
if mergedTarget.Frequency != expectedFrequency {
|
||||||
|
t.Errorf("Expected merged frequency %d, got %d", expectedFrequency, mergedTarget.Frequency)
|
||||||
|
}
|
||||||
|
// Should have 3 unique projects: proj1, proj2, proj3
|
||||||
|
if len(mergedTarget.Projects) != 3 {
|
||||||
|
t.Errorf("Expected 3 projects after merge, got %d", len(mergedTarget.Projects))
|
||||||
|
}
|
||||||
|
// Should have 5 observation IDs
|
||||||
|
if len(mergedTarget.ObservationIDs) != 5 {
|
||||||
|
t.Errorf("Expected 5 observation IDs after merge, got %d", len(mergedTarget.ObservationIDs))
|
||||||
|
}
|
||||||
|
}
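The assertions in TestPatternStore_MergePatterns imply a particular merge rule: frequencies are summed, project lists are unioned without duplicates, and observation IDs are combined. The sketch below shows one way to get that behaviour in memory. The `pattern` type, `union`, and `mergeInto` are hypothetical names for illustration only; the real `MergePatterns` works against the database and its implementation is not shown in this part of the diff.

```go
package main

import "fmt"

// pattern is a hypothetical, trimmed-down stand-in for models.Pattern.
type pattern struct {
	Frequency      int
	Projects       []string
	ObservationIDs []int64
}

// union returns the values of a followed by the values of b,
// keeping only the first occurrence of each value.
func union[T comparable](a, b []T) []T {
	seen := make(map[T]bool, len(a)+len(b))
	out := make([]T, 0, len(a)+len(b))
	for _, list := range [][]T{a, b} {
		for _, v := range list {
			if !seen[v] {
				seen[v] = true
				out = append(out, v)
			}
		}
	}
	return out
}

// mergeInto folds source into target: frequencies add up,
// projects and observation IDs are unioned.
func mergeInto(target, source *pattern) {
	target.Frequency += source.Frequency
	target.Projects = union(target.Projects, source.Projects)
	target.ObservationIDs = union(target.ObservationIDs, source.ObservationIDs)
}

func main() {
	source := &pattern{Frequency: 3, Projects: []string{"proj1", "proj2"}, ObservationIDs: []int64{1, 2, 3}}
	target := &pattern{Frequency: 2, Projects: []string{"proj2", "proj3"}, ObservationIDs: []int64{4, 5}}
	mergeInto(target, source)
	// Prints: 5 3 5
	fmt.Println(target.Frequency, len(target.Projects), len(target.ObservationIDs))
}
```

The inputs mirror the test fixture, so the printed counts line up with the three values the test expects on the merged target.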
|
||||||
|
|
||||||
|
func TestPatternStore_FindMatchingPatterns(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create patterns with known signatures
|
||||||
|
pattern1 := createTestPattern("Pattern 1")
|
||||||
|
pattern1.Signature = []string{"nil", "error", "handling"}
|
||||||
|
|
||||||
|
pattern2 := createTestPattern("Pattern 2")
|
||||||
|
pattern2.Signature = []string{"nil", "pointer", "check"}
|
||||||
|
|
||||||
|
pattern3 := createTestPattern("Pattern 3")
|
||||||
|
pattern3.Signature = []string{"refactor", "extract", "method"}
|
||||||
|
|
||||||
|
patternStore.StorePattern(ctx, pattern1)
|
||||||
|
patternStore.StorePattern(ctx, pattern2)
|
||||||
|
patternStore.StorePattern(ctx, pattern3)
|
||||||
|
|
||||||
|
// Find patterns matching "nil" related signature
|
||||||
|
matches, err := patternStore.FindMatchingPatterns(ctx, []string{"nil", "error"}, 0.3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) < 1 {
|
||||||
|
t.Errorf("Expected at least 1 match, got %d", len(matches))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no match for unrelated signature
|
||||||
|
noMatches, err := patternStore.FindMatchingPatterns(ctx, []string{"completely", "different"}, 0.5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindMatchingPatterns() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(noMatches) != 0 {
|
||||||
|
t.Errorf("Expected 0 matches for unrelated signature, got %d", len(noMatches))
|
||||||
|
}
|
||||||
|
}
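The similarity measure behind `FindMatchingPatterns` is not visible in this part of the diff. As an assumption, the sketch below uses Jaccard similarity over signature tokens, which is one common choice and is at least consistent with the thresholds used in the test above. `jaccard` is a hypothetical helper, not a function from the repository.

```go
package main

import "fmt"

// jaccard returns |A ∩ B| / |A ∪ B| for two token lists treated as sets.
// Assumption: this stands in for whatever metric FindMatchingPatterns uses.
func jaccard(a, b []string) float64 {
	setA := make(map[string]bool, len(a))
	for _, t := range a {
		setA[t] = true
	}
	setB := make(map[string]bool, len(b))
	for _, t := range b {
		setB[t] = true
	}
	if len(setA) == 0 && len(setB) == 0 {
		return 0
	}
	inter := 0
	for t := range setA {
		if setB[t] {
			inter++
		}
	}
	unionSize := len(setA) + len(setB) - inter
	return float64(inter) / float64(unionSize)
}

func main() {
	query := []string{"nil", "error"}
	fmt.Printf("%.2f\n", jaccard(query, []string{"nil", "error", "handling"}))   // 0.67
	fmt.Printf("%.2f\n", jaccard(query, []string{"nil", "pointer", "check"}))    // 0.25
	fmt.Printf("%.2f\n", jaccard(query, []string{"refactor", "extract", "method"})) // 0.00
}
```

Under this assumed metric only the first signature clears the 0.3 threshold, which fits the test's "at least 1 match" expectation without pinning down whether the second pattern should match.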
|
||||||
|
|
||||||
|
func TestPatternStore_IncrementPatternFrequency(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create pattern
|
||||||
|
pattern := createTestPattern("Frequency Test")
|
||||||
|
pattern.Frequency = 1
|
||||||
|
pattern.Projects = []string{"proj1"}
|
||||||
|
pattern.ObservationIDs = []int64{1}
|
||||||
|
|
||||||
|
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||||
|
|
||||||
|
// Increment frequency
|
||||||
|
err := patternStore.IncrementPatternFrequency(ctx, id, "proj2", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IncrementPatternFrequency() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
updated, _ := patternStore.GetPatternByID(ctx, id)
|
||||||
|
if updated.Frequency != 2 {
|
||||||
|
t.Errorf("Expected frequency 2, got %d", updated.Frequency)
|
||||||
|
}
|
||||||
|
if len(updated.Projects) != 2 {
|
||||||
|
t.Errorf("Expected 2 projects, got %d", len(updated.Projects))
|
||||||
|
}
|
||||||
|
if len(updated.ObservationIDs) != 2 {
|
||||||
|
t.Errorf("Expected 2 observation IDs, got %d", len(updated.ObservationIDs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_GetPatternStats(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create patterns with different types and statuses
|
||||||
|
bug := createTestPattern("Bug")
|
||||||
|
bug.Type = models.PatternTypeBug
|
||||||
|
bug.Frequency = 5
|
||||||
|
|
||||||
|
refactor := createTestPattern("Refactor")
|
||||||
|
refactor.Type = models.PatternTypeRefactor
|
||||||
|
refactor.Frequency = 3
|
||||||
|
|
||||||
|
deprecated := createTestPattern("Deprecated")
|
||||||
|
deprecated.Type = models.PatternTypeArchitecture
|
||||||
|
deprecated.Status = models.PatternStatusDeprecated
|
||||||
|
|
||||||
|
patternStore.StorePattern(ctx, bug)
|
||||||
|
patternStore.StorePattern(ctx, refactor)
|
||||||
|
patternStore.StorePattern(ctx, deprecated)
|
||||||
|
|
||||||
|
// Get stats
|
||||||
|
stats, err := patternStore.GetPatternStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternStats() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.Total != 3 {
|
||||||
|
t.Errorf("Expected total 3, got %d", stats.Total)
|
||||||
|
}
|
||||||
|
if stats.Active != 2 {
|
||||||
|
t.Errorf("Expected 2 active, got %d", stats.Active)
|
||||||
|
}
|
||||||
|
if stats.Deprecated != 1 {
|
||||||
|
t.Errorf("Expected 1 deprecated, got %d", stats.Deprecated)
|
||||||
|
}
|
||||||
|
if stats.Bugs != 1 {
|
||||||
|
t.Errorf("Expected 1 bug, got %d", stats.Bugs)
|
||||||
|
}
|
||||||
|
if stats.Refactors != 1 {
|
||||||
|
t.Errorf("Expected 1 refactor, got %d", stats.Refactors)
|
||||||
|
}
|
||||||
|
if stats.TotalOccurrences != 9 { // 5 + 3 + 1
|
||||||
|
t.Errorf("Expected 9 total occurrences, got %d", stats.TotalOccurrences)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternStore_CleanupCallback(t *testing.T) {
|
||||||
|
store := setupPatternTestStore(t)
|
||||||
|
|
||||||
|
patternStore := NewPatternStore(store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var deletedIDs []int64
|
||||||
|
patternStore.SetCleanupFunc(func(ctx context.Context, ids []int64) {
|
||||||
|
deletedIDs = ids
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create and delete pattern
|
||||||
|
pattern := createTestPattern("Cleanup Test")
|
||||||
|
id, _ := patternStore.StorePattern(ctx, pattern)
|
||||||
|
|
||||||
|
patternStore.DeletePattern(ctx, id)
|
||||||
|
|
||||||
|
if len(deletedIDs) != 1 || deletedIDs[0] != id {
|
||||||
|
t.Errorf("Expected cleanup callback with ID %d, got %v", id, deletedIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a test pattern
func createTestPattern(name string) *models.Pattern {
	now := time.Now()
	return &models.Pattern{
		Name:           name,
		Type:           models.PatternTypeBug,
		Description:    sql.NullString{String: "Test description", Valid: true},
		Signature:      []string{"test", "pattern"},
		Recommendation: sql.NullString{String: "Test recommendation", Valid: true},
		Frequency:      1,
		Projects:       []string{"test-project"},
		ObservationIDs: []int64{1},
		Status:         models.PatternStatusActive,
		Confidence:     0.5,
		LastSeenAt:     now.Format(time.RFC3339),
		LastSeenEpoch:  now.UnixMilli(),
		CreatedAt:      now.Format(time.RFC3339),
		CreatedAtEpoch: now.UnixMilli(),
	}
}
|
||||||
@@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
|
|||||||
return scanPromptWithSessionRows(rows)
|
||||||
}
|
||||||
|
|
||||||
|
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||||
|
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||||
|
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||||
|
up.created_at, up.created_at_epoch,
|
||||||
|
COALESCE(s.project, '') as project,
|
||||||
|
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||||
|
FROM user_prompts up
|
||||||
|
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||||
|
ORDER BY up.id
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanPromptWithSessionRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
|
||||||
// This is used to detect duplicate hook invocations.
|
||||||
// Returns (promptID, promptNumber, found).
|
||||||
|
|||||||
@@ -0,0 +1,377 @@
|
|||||||
|
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RelationStore provides relation-related database operations.
type RelationStore struct {
	store *Store
}

// NewRelationStore creates a new relation store.
func NewRelationStore(store *Store) *RelationStore {
	return &RelationStore{store: store}
}
|
||||||
|
|
||||||
|
// StoreRelation stores a new observation relation.
|
||||||
|
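// Uses INSERT OR IGNORE to skip duplicate (source_id, target_id, relation_type) combinations.
// When a duplicate is ignored no row is inserted, so the returned ID should not be relied on in that case.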
|
||||||
|
func (s *RelationStore) StoreRelation(ctx context.Context, relation *models.ObservationRelation) (int64, error) {
|
||||||
|
const query = `
|
||||||
|
INSERT OR IGNORE INTO observation_relations
|
||||||
|
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
result, err := s.store.ExecContext(ctx, query,
|
||||||
|
relation.SourceID, relation.TargetID,
|
||||||
|
string(relation.RelationType), relation.Confidence,
|
||||||
|
string(relation.DetectionSource), relation.Reason,
|
||||||
|
relation.CreatedAt, relation.CreatedAtEpoch,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreRelations stores multiple relations in a single transaction.
|
||||||
|
func (s *RelationStore) StoreRelations(ctx context.Context, relations []*models.ObservationRelation) error {
|
||||||
|
if len(relations) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
const query = `
|
||||||
|
INSERT OR IGNORE INTO observation_relations
|
||||||
|
(source_id, target_id, relation_type, confidence, detection_source, reason, created_at, created_at_epoch)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for _, rel := range relations {
|
||||||
|
_, err = stmt.ExecContext(ctx,
|
||||||
|
rel.SourceID, rel.TargetID,
|
||||||
|
string(rel.RelationType), rel.Confidence,
|
||||||
|
string(rel.DetectionSource), rel.Reason,
|
||||||
|
rel.CreatedAt, rel.CreatedAtEpoch,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationsByObservationID retrieves all relations involving an observation (as source or target).
|
||||||
|
func (s *RelationStore) GetRelationsByObservationID(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||||
|
created_at, created_at_epoch
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE source_id = ? OR target_id = ?
|
||||||
|
ORDER BY confidence DESC, created_at_epoch DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanRelationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOutgoingRelations retrieves relations where the observation is the source.
|
||||||
|
func (s *RelationStore) GetOutgoingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||||
|
created_at, created_at_epoch
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE source_id = ?
|
||||||
|
ORDER BY confidence DESC, created_at_epoch DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanRelationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIncomingRelations retrieves relations where the observation is the target.
|
||||||
|
func (s *RelationStore) GetIncomingRelations(ctx context.Context, obsID int64) ([]*models.ObservationRelation, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||||
|
created_at, created_at_epoch
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE target_id = ?
|
||||||
|
ORDER BY confidence DESC, created_at_epoch DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanRelationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationsByType retrieves all relations of a specific type.
|
||||||
|
func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType models.RelationType, limit int) ([]*models.ObservationRelation, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||||
|
created_at, created_at_epoch
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE relation_type = ?
|
||||||
|
ORDER BY confidence DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, string(relationType), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanRelationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationsWithDetails retrieves relations with observation titles for display.
|
||||||
|
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT r.id, r.source_id, r.target_id, r.relation_type, r.confidence, r.detection_source, r.reason,
|
||||||
|
r.created_at, r.created_at_epoch,
|
||||||
|
COALESCE(src.title, '') as source_title,
|
||||||
|
COALESCE(tgt.title, '') as target_title,
|
||||||
|
src.type as source_type,
|
||||||
|
tgt.type as target_type
|
||||||
|
FROM observation_relations r
|
||||||
|
JOIN observations src ON src.id = r.source_id
|
||||||
|
JOIN observations tgt ON tgt.id = r.target_id
|
||||||
|
WHERE r.source_id = ? OR r.target_id = ?
|
||||||
|
ORDER BY r.confidence DESC, r.created_at_epoch DESC
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID, obsID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var results []*models.RelationWithDetails
|
||||||
|
for rows.Next() {
|
||||||
|
var r models.ObservationRelation
|
||||||
|
var rwd models.RelationWithDetails
|
||||||
|
var reason sql.NullString
|
||||||
|
if err := rows.Scan(
|
||||||
|
&r.ID, &r.SourceID, &r.TargetID,
|
||||||
|
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||||
|
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||||
|
&rwd.SourceTitle, &rwd.TargetTitle,
|
||||||
|
&rwd.SourceType, &rwd.TargetType,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if reason.Valid {
|
||||||
|
r.Reason = reason.String
|
||||||
|
}
|
||||||
|
rwd.Relation = &r
|
||||||
|
results = append(results, &rwd)
|
||||||
|
}
|
||||||
|
return results, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationGraph retrieves a relation graph centered on an observation.
|
||||||
|
// It returns all relations found within maxDepth hops of the center observation.
|
||||||
|
func (s *RelationStore) GetRelationGraph(ctx context.Context, centerID int64, maxDepth int) (*models.RelationGraph, error) {
|
||||||
|
// Get all relations involving the center observation
|
||||||
|
relations, err := s.GetRelationsWithDetails(ctx, centerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
graph := &models.RelationGraph{
|
||||||
|
CenterID: centerID,
|
||||||
|
Relations: relations,
|
||||||
|
}
|
||||||
|
|
||||||
|
// If depth > 1, recursively get relations for connected observations
|
||||||
|
if maxDepth > 1 {
|
||||||
|
visited := map[int64]bool{centerID: true}
|
||||||
|
toVisit := make([]int64, 0)
|
||||||
|
|
||||||
|
// Collect IDs of directly connected observations
|
||||||
|
for _, r := range relations {
|
||||||
|
if !visited[r.Relation.SourceID] {
|
||||||
|
toVisit = append(toVisit, r.Relation.SourceID)
|
||||||
|
visited[r.Relation.SourceID] = true
|
||||||
|
}
|
||||||
|
if !visited[r.Relation.TargetID] {
|
||||||
|
toVisit = append(toVisit, r.Relation.TargetID)
|
||||||
|
visited[r.Relation.TargetID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Breadth-first expansion: each pass loads relations for the current frontier, up to maxDepth hops in total
|
||||||
|
for depth := 1; depth < maxDepth && len(toVisit) > 0; depth++ {
|
||||||
|
nextLevel := make([]int64, 0)
|
||||||
|
for _, obsID := range toVisit {
|
||||||
|
moreRelations, err := s.GetRelationsWithDetails(ctx, obsID)
|
||||||
|
if err != nil {
|
||||||
|
continue // best effort: skip observations whose relations cannot be loaded
|
||||||
|
}
|
||||||
|
for _, r := range moreRelations {
|
||||||
|
// Avoid duplicates
|
||||||
|
exists := false
|
||||||
|
for _, existing := range graph.Relations {
|
||||||
|
if existing.Relation.ID == r.Relation.ID {
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
graph.Relations = append(graph.Relations, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queue next level
|
||||||
|
if !visited[r.Relation.SourceID] {
|
||||||
|
nextLevel = append(nextLevel, r.Relation.SourceID)
|
||||||
|
visited[r.Relation.SourceID] = true
|
||||||
|
}
|
||||||
|
if !visited[r.Relation.TargetID] {
|
||||||
|
nextLevel = append(nextLevel, r.Relation.TargetID)
|
||||||
|
visited[r.Relation.TargetID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toVisit = nextLevel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph, nil
|
||||||
|
}
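The traversal in GetRelationGraph is a depth-limited breadth-first search. The sketch below shows the same idea over in-memory data, with one change: relation IDs are tracked in a set instead of rescanning `graph.Relations` for every candidate, which avoids the quadratic duplicate check. `edge`, `lookup`, and `neighborhood` are hypothetical names, and the sketch treats `maxDepth < 1` as "return nothing", whereas the function above always includes the center's direct relations.

```go
package main

import "fmt"

// edge is a hypothetical stand-in for a stored relation row.
type edge struct {
	ID, SourceID, TargetID int64
}

// lookup returns a function that lists every edge touching a given node,
// playing the role of a per-observation query.
func lookup(all []edge) func(int64) []edge {
	return func(id int64) []edge {
		var out []edge
		for _, e := range all {
			if e.SourceID == id || e.TargetID == id {
				out = append(out, e)
			}
		}
		return out
	}
}

// neighborhood collects every edge reachable within maxDepth hops of center.
func neighborhood(center int64, maxDepth int, edgesOf func(int64) []edge) []edge {
	visited := map[int64]bool{center: true} // nodes already queued
	seenEdge := map[int64]bool{}            // edge IDs already collected
	var out []edge
	frontier := []int64{center}
	for depth := 0; depth < maxDepth && len(frontier) > 0; depth++ {
		var next []int64
		for _, id := range frontier {
			for _, e := range edgesOf(id) {
				if !seenEdge[e.ID] {
					seenEdge[e.ID] = true
					out = append(out, e)
				}
				for _, n := range []int64{e.SourceID, e.TargetID} {
					if !visited[n] {
						visited[n] = true
						next = append(next, n)
					}
				}
			}
		}
		frontier = next
	}
	return out
}

func main() {
	// A simple chain: 1 - 2 - 3 - 4
	all := []edge{{10, 1, 2}, {11, 2, 3}, {12, 3, 4}}
	for _, d := range []int{1, 2, 3} {
		fmt.Println(d, len(neighborhood(1, d, lookup(all))))
	}
}
```

With the chain above, each extra hop of depth adds exactly one more edge to the result.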
|
||||||
|
|
||||||
|
// DeleteRelationsByObservationID deletes all relations involving an observation.
// Called when an observation is deleted.
func (s *RelationStore) DeleteRelationsByObservationID(ctx context.Context, obsID int64) error {
	const query = `DELETE FROM observation_relations WHERE source_id = ? OR target_id = ?`
	_, err := s.store.ExecContext(ctx, query, obsID, obsID)
	return err
}
|
||||||
|
|
||||||
|
// GetRelationCount returns the count of relations for an observation.
func (s *RelationStore) GetRelationCount(ctx context.Context, obsID int64) (int, error) {
	const query = `
		SELECT COUNT(*) FROM observation_relations
		WHERE source_id = ? OR target_id = ?
	`
	var count int
	err := s.store.QueryRowContext(ctx, query, obsID, obsID).Scan(&count)
	return count, err
}
|
||||||
|
|
||||||
|
// GetTotalRelationCount returns the total count of all relations.
func (s *RelationStore) GetTotalRelationCount(ctx context.Context) (int, error) {
	const query = `SELECT COUNT(*) FROM observation_relations`
	var count int
	err := s.store.QueryRowContext(ctx, query).Scan(&count)
	return count, err
}
|
||||||
|
|
||||||
|
// GetHighConfidenceRelations retrieves relations with confidence above threshold.
|
||||||
|
func (s *RelationStore) GetHighConfidenceRelations(ctx context.Context, minConfidence float64, limit int) ([]*models.ObservationRelation, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT id, source_id, target_id, relation_type, confidence, detection_source, reason,
|
||||||
|
created_at, created_at_epoch
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE confidence >= ?
|
||||||
|
ORDER BY confidence DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, minConfidence, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return s.scanRelationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRelationConfidence updates the confidence of a relation.
func (s *RelationStore) UpdateRelationConfidence(ctx context.Context, relationID int64, newConfidence float64) error {
	const query = `UPDATE observation_relations SET confidence = ? WHERE id = ?`
	_, err := s.store.ExecContext(ctx, query, newConfidence, relationID)
	return err
}
|
||||||
|
|
||||||
|
// GetRelatedObservationIDs returns IDs of observations related to the given one.
|
||||||
|
// This is useful for expanding search results.
|
||||||
|
func (s *RelationStore) GetRelatedObservationIDs(ctx context.Context, obsID int64, minConfidence float64) ([]int64, error) {
|
||||||
|
const query = `
|
||||||
|
SELECT DISTINCT CASE WHEN source_id = ? THEN target_id ELSE source_id END as related_id
|
||||||
|
FROM observation_relations
|
||||||
|
WHERE (source_id = ? OR target_id = ?) AND confidence >= ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, obsID, obsID, obsID, minConfidence)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var ids []int64
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return ids, rows.Err()
|
||||||
|
}
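The `CASE WHEN source_id = ? THEN target_id ELSE source_id END` expression in GetRelatedObservationIDs picks "the other endpoint" of each relation, and `DISTINCT` removes repeats. The same selection can be sketched in plain Go; `relation` and `relatedIDs` are hypothetical names used only to illustrate the query's logic, not code from the repository.

```go
package main

import "fmt"

// relation is a hypothetical, trimmed-down relation row.
type relation struct {
	SourceID, TargetID int64
	Confidence         float64
}

// relatedIDs returns the distinct IDs found on the opposite end of every
// relation that touches obsID and meets the confidence threshold.
func relatedIDs(rels []relation, obsID int64, minConfidence float64) []int64 {
	seen := map[int64]bool{}
	var ids []int64
	for _, r := range rels {
		if r.Confidence < minConfidence {
			continue // below threshold
		}
		if r.SourceID != obsID && r.TargetID != obsID {
			continue // relation does not involve obsID
		}
		// Same rule as the SQL CASE: if obsID is the source, take the target,
		// otherwise take the source.
		other := r.SourceID
		if r.SourceID == obsID {
			other = r.TargetID
		}
		if !seen[other] {
			seen[other] = true
			ids = append(ids, other)
		}
	}
	return ids
}

func main() {
	rels := []relation{
		{1, 2, 0.9}, // outgoing from 1
		{3, 1, 0.8}, // incoming to 1
		{1, 2, 0.7}, // duplicate neighbour
		{1, 4, 0.2}, // below threshold
		{5, 6, 0.9}, // unrelated
	}
	fmt.Println(relatedIDs(rels, 1, 0.5)) // [2 3]
}
```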
|
||||||
|
|
||||||
|
// scanRelationRows scans multiple relations from rows.
|
||||||
|
func (s *RelationStore) scanRelationRows(rows *sql.Rows) ([]*models.ObservationRelation, error) {
|
||||||
|
var relations []*models.ObservationRelation
|
||||||
|
for rows.Next() {
|
||||||
|
var r models.ObservationRelation
|
||||||
|
var reason sql.NullString
|
||||||
|
if err := rows.Scan(
|
||||||
|
&r.ID, &r.SourceID, &r.TargetID,
|
||||||
|
&r.RelationType, &r.Confidence, &r.DetectionSource, &reason,
|
||||||
|
&r.CreatedAt, &r.CreatedAtEpoch,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if reason.Valid {
|
||||||
|
r.Reason = reason.String
|
||||||
|
}
|
||||||
|
relations = append(relations, &r)
|
||||||
|
}
|
||||||
|
return relations, rows.Err()
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateObservationFeedback updates the user feedback for an observation.
// Feedback values: -1 (thumbs down), 0 (neutral), 1 (thumbs up).
func (s *ObservationStore) UpdateObservationFeedback(ctx context.Context, id int64, feedback int) error {
	const query = `
		UPDATE observations
		SET user_feedback = ?, score_updated_at_epoch = ?
		WHERE id = ?
	`
	_, err := s.store.ExecContext(ctx, query, feedback, time.Now().UnixMilli(), id)
	return err
}
|
||||||
|
|
||||||
|
// IncrementRetrievalCount increments the retrieval counter for the given observation IDs.
|
||||||
|
// This is called when observations are returned in search results.
|
||||||
|
func (s *ObservationStore) IncrementRetrievalCount(ctx context.Context, ids []int64) error {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
// Build query with placeholders
|
||||||
|
// #nosec G202 -- query uses parameterized placeholders, not user input
|
||||||
|
query := `
|
||||||
|
UPDATE observations
|
||||||
|
SET retrieval_count = COALESCE(retrieval_count, 0) + 1,
|
||||||
|
last_retrieved_at_epoch = ?
|
||||||
|
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
|
||||||
|
`
|
||||||
|
|
||||||
|
args := make([]interface{}, 0, len(ids)+1)
|
||||||
|
args = append(args, now)
|
||||||
|
for _, id := range ids {
|
||||||
|
args = append(args, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := s.store.db.ExecContext(ctx, query, args...)
|
||||||
|
return err
|
||||||
|
}
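IncrementRetrievalCount builds its `IN (...)` list from one literal `?` plus `repeatPlaceholders(len(ids)-1)`. That helper is defined elsewhere in the package and is not shown in this diff, so the version below is an assumption about its shape, named `repeatPlaceholdersSketch` to keep it distinct. The point is that only `?` markers are concatenated into the SQL text, while the IDs themselves travel as bound arguments.

```go
package main

import (
	"fmt"
	"strings"
)

// repeatPlaceholdersSketch returns n additional ", ?" markers.
// Assumption: the real repeatPlaceholders helper behaves similarly.
func repeatPlaceholdersSketch(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat(", ?", n)
}

// inClause builds "(?, ?, ...)" for n >= 1 values.
func inClause(n int) string {
	return "(?" + repeatPlaceholdersSketch(n-1) + ")"
}

// buildArgs puts the timestamp first, followed by one argument per ID,
// matching the placeholder order in the UPDATE statement.
func buildArgs(now int64, ids []int64) []any {
	args := make([]any, 0, len(ids)+1)
	args = append(args, now)
	for _, id := range ids {
		args = append(args, id)
	}
	return args
}

func main() {
	ids := []int64{7, 8, 9}
	fmt.Println(inClause(len(ids)))           // (?, ?, ?)
	fmt.Println(len(buildArgs(1700000, ids))) // 4
}
```

Because the clause always starts with one `?`, an empty ID list would yield `(?)` with no matching argument, which is why the early `len(ids) == 0` return in the function above matters.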
|
||||||
|
|
||||||
|
// UpdateImportanceScore updates the importance score for a single observation.
func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, score float64) error {
	const query = `
		UPDATE observations
		SET importance_score = ?, score_updated_at_epoch = ?
		WHERE id = ?
	`
	_, err := s.store.ExecContext(ctx, query, score, time.Now().UnixMilli(), id)
	return err
}
|
||||||
|
|
||||||
|
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||||
|
// This is more efficient than individual updates for batch recalculation.
|
||||||
|
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||||
|
if len(scores) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
|
stmt, err := tx.PrepareContext(ctx, `
|
||||||
|
UPDATE observations
|
||||||
|
SET importance_score = ?, score_updated_at_epoch = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for id, score := range scores {
|
||||||
|
if _, err := stmt.ExecContext(ctx, score, now, id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||||
|
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||||
|
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||||
|
cutoff := time.Now().Add(-threshold).UnixMilli()
|
||||||
|
|
||||||
|
query := `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE score_updated_at_epoch IS NULL OR score_updated_at_epoch < ?
|
||||||
|
ORDER BY created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, cutoff, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConceptWeights returns all concept weights from the database.
|
||||||
|
func (s *ObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||||
|
const query = `SELECT concept, weight FROM concept_weights`
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
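// Intended as a fallback for older databases without the table. Note that QueryContext
// does not report sql.ErrNoRows (that error comes from Row.Scan), so a missing table
// surfaces as a driver error and this branch is not expected to fire.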
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return models.DefaultConceptWeights, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
weights := make(map[string]float64)
|
||||||
|
for rows.Next() {
|
||||||
|
var concept string
|
||||||
|
var weight float64
|
||||||
|
if err := rows.Scan(&concept, &weight); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
weights[concept] = weight
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no weights found, use defaults
|
||||||
|
if len(weights) == 0 {
|
||||||
|
return models.DefaultConceptWeights, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return weights, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConceptWeight updates a single concept weight.
func (s *ObservationStore) UpdateConceptWeight(ctx context.Context, concept string, weight float64) error {
	const query = `
		INSERT INTO concept_weights (concept, weight, updated_at)
		VALUES (?, ?, datetime('now'))
		ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
	`
	_, err := s.store.ExecContext(ctx, query, concept, weight)
	return err
}
|
||||||
|
|
||||||
|
// UpdateConceptWeights upserts multiple concept weights in a single transaction.
|
||||||
|
func (s *ObservationStore) UpdateConceptWeights(ctx context.Context, weights map[string]float64) error {
|
||||||
|
if len(weights) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.store.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback() // harmless after a successful Commit (returns sql.ErrTxDone, ignored)
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, `
|
||||||
|
INSERT INTO concept_weights (concept, weight, updated_at)
|
||||||
|
VALUES (?, ?, datetime('now'))
|
||||||
|
ON CONFLICT(concept) DO UPDATE SET weight = excluded.weight, updated_at = excluded.updated_at
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for concept, weight := range weights {
|
||||||
|
if _, err := stmt.ExecContext(ctx, concept, weight); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
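// GetObservationFeedbackStats returns statistics about user feedback.
// An empty project aggregates over all observations; otherwise the stats cover
// that project's observations plus any with scope 'global'.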
|
||||||
|
func (s *ObservationStore) GetObservationFeedbackStats(ctx context.Context, project string) (*FeedbackStats, error) {
|
||||||
|
var query string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if project == "" {
|
||||||
|
query = `
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||||
|
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||||
|
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||||
|
FROM observations
|
||||||
|
`
|
||||||
|
} else {
|
||||||
|
query = `
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = 1 THEN 1 ELSE 0 END), 0) as positive,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = -1 THEN 1 ELSE 0 END), 0) as negative,
|
||||||
|
COALESCE(SUM(CASE WHEN user_feedback = 0 THEN 1 ELSE 0 END), 0) as neutral,
|
||||||
|
COALESCE(AVG(COALESCE(importance_score, 1.0)), 0) as avg_score,
|
||||||
|
COALESCE(AVG(COALESCE(retrieval_count, 0)), 0) as avg_retrieval
|
||||||
|
FROM observations
|
||||||
|
WHERE project = ? OR scope = 'global'
|
||||||
|
`
|
||||||
|
args = append(args, project)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stats FeedbackStats
|
||||||
|
err := s.store.QueryRowContext(ctx, query, args...).Scan(
|
||||||
|
&stats.Total,
|
||||||
|
&stats.Positive,
|
||||||
|
&stats.Negative,
|
||||||
|
&stats.Neutral,
|
||||||
|
&stats.AvgScore,
|
||||||
|
&stats.AvgRetrieval,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FeedbackStats contains statistics about observation feedback and scoring.
|
||||||
|
type FeedbackStats struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
Positive int `json:"positive"`
|
||||||
|
Negative int `json:"negative"`
|
||||||
|
Neutral int `json:"neutral"`
|
||||||
|
AvgScore float64 `json:"avg_score"`
|
||||||
|
AvgRetrieval float64 `json:"avg_retrieval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTopScoringObservations returns the highest-scoring observations.
|
||||||
|
func (s *ObservationStore) GetTopScoringObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
|
var query string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if project == "" {
|
||||||
|
query = `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
args = append(args, limit)
|
||||||
|
} else {
|
||||||
|
query = `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE project = ? OR scope = 'global'
|
||||||
|
ORDER BY COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
args = append(args, project, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMostRetrievedObservations returns the most frequently retrieved observations.
|
||||||
|
func (s *ObservationStore) GetMostRetrievedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||||
|
var query string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
if project == "" {
|
||||||
|
query = `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE retrieval_count > 0
|
||||||
|
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
args = append(args, limit)
|
||||||
|
} else {
|
||||||
|
query = `SELECT ` + observationColumns + `
|
||||||
|
FROM observations
|
||||||
|
WHERE (project = ? OR scope = 'global') AND retrieval_count > 0
|
||||||
|
ORDER BY retrieval_count DESC, created_at_epoch DESC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
args = append(args, project, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.store.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanObservationRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetObservationScores resets all observation scores to their default values.
|
||||||
|
// This is useful for testing or when changing the scoring algorithm.
|
||||||
|
func (s *ObservationStore) ResetObservationScores(ctx context.Context) error {
|
||||||
|
const query = `
|
||||||
|
UPDATE observations
|
||||||
|
SET importance_score = 1.0, score_updated_at_epoch = NULL
|
||||||
|
`
|
||||||
|
_, err := s.store.ExecContext(ctx, query)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,698 @@
|
|||||||
|
// Package sqlite provides SQLite database operations for claude-mnemonic.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testScoringObservationStore creates an ObservationStore with scoring columns for testing.
|
||||||
|
func testScoringObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, _, cleanup := testDB(t)
|
||||||
|
createBaseTables(t, db)
|
||||||
|
createConceptWeightsTable(t, db)
|
||||||
|
|
||||||
|
// Add the importance index if it does not exist (the scoring columns are already created by createBaseTables)
|
||||||
|
if _, err := db.Exec(`CREATE INDEX IF NOT EXISTS idx_observations_importance ON observations(importance_score DESC, created_at_epoch DESC)`); err != nil {
|
||||||
|
t.Fatalf("create importance index: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newStoreFromDB(db)
|
||||||
|
obsStore := NewObservationStore(store)
|
||||||
|
|
||||||
|
return obsStore, store, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// createConceptWeightsTable creates the concept_weights table for testing.
|
||||||
|
func createConceptWeightsTable(t *testing.T, db *sql.DB) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, err := db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS concept_weights (
|
||||||
|
concept TEXT PRIMARY KEY,
|
||||||
|
weight REAL NOT NULL DEFAULT 0.1,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create concept_weights: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScoringStoreSuite is a test suite for scoring-related database operations.
|
||||||
|
type ScoringStoreSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
obsStore *ObservationStore
|
||||||
|
store *Store
|
||||||
|
cleanup func()
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) SetupTest() {
|
||||||
|
s.obsStore, s.store, s.cleanup = testScoringObservationStore(s.T())
|
||||||
|
s.ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TearDownTest() {
|
||||||
|
if s.cleanup != nil {
|
||||||
|
s.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScoringStoreSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ScoringStoreSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// FEEDBACK TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Positive() {
|
||||||
|
// Create observation
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
Title: "Test feedback",
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Update feedback to positive
|
||||||
|
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(1, retrieved.UserFeedback)
|
||||||
|
s.True(retrieved.ScoreUpdatedAt.Valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Negative() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, -1)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(-1, retrieved.UserFeedback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_Neutral() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeChange,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// First set to positive
|
||||||
|
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Then reset to neutral
|
||||||
|
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, 0)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(0, retrieved.UserFeedback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateObservationFeedback_NonExistent() {
|
||||||
|
// Updating non-existent observation should not fail (just no rows affected)
|
||||||
|
err := s.obsStore.UpdateObservationFeedback(s.ctx, 99999, 1)
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// RETRIEVAL COUNT TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Single() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(1, retrieved.RetrievalCount)
|
||||||
|
s.True(retrieved.LastRetrievedAt.Valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Multiple() {
|
||||||
|
var ids []int64
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.obsStore.IncrementRetrievalCount(s.ctx, ids)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(1, retrieved.RetrievalCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Cumulative() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Increment multiple times
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
err = s.obsStore.IncrementRetrievalCount(s.ctx, []int64{id})
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(5, retrieved.RetrievalCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestIncrementRetrievalCount_Empty() {
|
||||||
|
err := s.obsStore.IncrementRetrievalCount(s.ctx, []int64{})
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// IMPORTANCE SCORE TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateImportanceScore_Single() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.InDelta(1.5, retrieved.ImportanceScore, 0.001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Batch() {
|
||||||
|
var ids []int64
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
scores := map[int64]float64{
|
||||||
|
ids[0]: 1.5,
|
||||||
|
ids[1]: 0.8,
|
||||||
|
ids[2]: 1.2,
|
||||||
|
ids[3]: 0.5,
|
||||||
|
ids[4]: 2.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.obsStore.UpdateImportanceScores(s.ctx, scores)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
for id, expectedScore := range scores {
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.InDelta(expectedScore, retrieved.ImportanceScore, 0.001)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateImportanceScores_Empty() {
|
||||||
|
err := s.obsStore.UpdateImportanceScores(s.ctx, map[int64]float64{})
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// OBSERVATIONS NEEDING SCORE UPDATE TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_NeverUpdated() {
|
||||||
|
// Observations without score_updated_at_epoch should need update
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Len(observations, 1)
|
||||||
|
s.Equal(id, observations[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_RecentlyUpdated() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Update score (this sets score_updated_at_epoch)
|
||||||
|
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 1.5)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Should not need update (just updated)
|
||||||
|
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Empty(observations)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationsNeedingScoreUpdate_Limit() {
|
||||||
|
// Create 10 observations
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request only 5
|
||||||
|
observations, err := s.obsStore.GetObservationsNeedingScoreUpdate(s.ctx, 6*time.Hour, 5)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Len(observations, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// CONCEPT WEIGHTS TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetConceptWeights_Empty() {
|
||||||
|
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(models.DefaultConceptWeights, weights)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateConceptWeight_NewConcept() {
|
||||||
|
err := s.obsStore.UpdateConceptWeight(s.ctx, "new-concept", 0.42)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(0.42, weights["new-concept"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateConceptWeight_UpdateExisting() {
|
||||||
|
// Insert first
|
||||||
|
err := s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.1)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Update
|
||||||
|
err = s.obsStore.UpdateConceptWeight(s.ctx, "test-concept", 0.9)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
weights, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(0.9, weights["test-concept"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Batch() {
|
||||||
|
weightsToSet := map[string]float64{
|
||||||
|
"security": 0.5,
|
||||||
|
"performance": 0.3,
|
||||||
|
"testing": 0.2,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.obsStore.UpdateConceptWeights(s.ctx, weightsToSet)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetConceptWeights(s.ctx)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
for concept, expected := range weightsToSet {
|
||||||
|
s.Equal(expected, retrieved[concept])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestUpdateConceptWeights_Empty() {
|
||||||
|
err := s.obsStore.UpdateConceptWeights(s.ctx, map[string]float64{})
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// FEEDBACK STATS TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_Empty() {
|
||||||
|
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(0, stats.Total)
|
||||||
|
s.Equal(0, stats.Positive)
|
||||||
|
s.Equal(0, stats.Negative)
|
||||||
|
s.Equal(0, stats.Neutral)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_WithData() {
|
||||||
|
// Create observations with different feedback
|
||||||
|
feedbacks := []int{1, 1, 1, -1, -1, 0, 0, 0, 0, 0}
|
||||||
|
for i, fb := range feedbacks {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
if fb != 0 {
|
||||||
|
err = s.obsStore.UpdateObservationFeedback(s.ctx, id, fb)
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "")
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(10, stats.Total)
|
||||||
|
s.Equal(3, stats.Positive)
|
||||||
|
s.Equal(2, stats.Negative)
|
||||||
|
s.Equal(5, stats.Neutral)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetObservationFeedbackStats_ByProject() {
|
||||||
|
// Project A observations
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.NoError(s.obsStore.UpdateObservationFeedback(s.ctx, id, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Project B observations
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeFeature,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.NoError(s.obsStore.UpdateObservationFeedback(s.ctx, id, -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check project A stats
|
||||||
|
statsA, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-a")
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(5, statsA.Total)
|
||||||
|
s.Equal(5, statsA.Positive)
|
||||||
|
|
||||||
|
// Check project B stats
|
||||||
|
statsB, err := s.obsStore.GetObservationFeedbackStats(s.ctx, "project-b")
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(3, statsB.Total)
|
||||||
|
s.Equal(3, statsB.Negative)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// TOP SCORING OBSERVATIONS TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetTopScoringObservations() {
|
||||||
|
// Create observations with different scores
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
// Set different scores
|
||||||
|
err = s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1)*0.5)
|
||||||
|
s.NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get top 3
|
||||||
|
top, err := s.obsStore.GetTopScoringObservations(s.ctx, "", 3)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Len(top, 3)
|
||||||
|
|
||||||
|
// Verify ordered by score descending
|
||||||
|
s.GreaterOrEqual(top[0].ImportanceScore, top[1].ImportanceScore)
|
||||||
|
s.GreaterOrEqual(top[1].ImportanceScore, top[2].ImportanceScore)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetTopScoringObservations_ByProject() {
|
||||||
|
// Project A with high scores
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.NoError(s.obsStore.UpdateImportanceScore(s.ctx, id, 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Project B with low scores
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeChange,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-b", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.NoError(s.obsStore.UpdateImportanceScore(s.ctx, id, 0.5))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get top for project A
|
||||||
|
topA, err := s.obsStore.GetTopScoringObservations(s.ctx, "project-a", 10)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Len(topA, 3)
|
||||||
|
for _, obs := range topA {
|
||||||
|
s.Equal("project-a", obs.Project)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// MOST RETRIEVED OBSERVATIONS TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations() {
|
||||||
|
// Create observations with different retrieval counts
|
||||||
|
var ids []int64
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set different retrieval counts
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
s.NoError(s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[0]})) // 10 retrievals
|
||||||
|
}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.NoError(s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[1]})) // 5 retrievals
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
s.NoError(s.obsStore.IncrementRetrievalCount(s.ctx, []int64{ids[2]})) // 3 retrievals
|
||||||
|
}
|
||||||
|
// ids[3] and ids[4] have 0 retrievals
|
||||||
|
|
||||||
|
// Get top 3
|
||||||
|
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 3)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Len(most, 3)
|
||||||
|
|
||||||
|
// Verify ordered by retrieval count descending
|
||||||
|
s.Equal(10, most[0].RetrievalCount)
|
||||||
|
s.Equal(5, most[1].RetrievalCount)
|
||||||
|
s.Equal(3, most[2].RetrievalCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestGetMostRetrievedObservations_NoRetrievals() {
|
||||||
|
// Create observations without any retrievals
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
_, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
most, err := s.obsStore.GetMostRetrievedObservations(s.ctx, "", 10)
|
||||||
|
s.NoError(err)
|
||||||
|
s.Empty(most) // No observations with retrieval_count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// RESET OBSERVATION SCORES TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestResetObservationScores() {
|
||||||
|
// Create observations with various scores
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeDiscovery,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, i, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
s.NoError(s.obsStore.UpdateImportanceScore(s.ctx, id, float64(i+1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset all scores
|
||||||
|
err := s.obsStore.ResetObservationScores(s.ctx)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Verify all scores are reset to 1.0
|
||||||
|
observations, err := s.obsStore.GetAllRecentObservations(s.ctx, 100)
|
||||||
|
s.NoError(err)
s.NotEmpty(observations) // guard: the loop below would pass vacuously on an empty result
|
||||||
|
for _, obs := range observations {
|
||||||
|
s.InDelta(1.0, obs.ImportanceScore, 0.001)
|
||||||
|
s.False(obs.ScoreUpdatedAt.Valid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EDGE CASES
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestScoring_ZeroScore() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
// Set score to 0
|
||||||
|
err = s.obsStore.UpdateImportanceScore(s.ctx, id, 0.0)
|
||||||
|
s.NoError(err)
|
||||||
|
|
||||||
|
retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
|
||||||
|
s.NoError(err)
|
||||||
|
s.InDelta(0.0, retrieved.ImportanceScore, 0.001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScoringStoreSuite) TestScoring_NegativeScore() {
|
||||||
|
obs := &models.ParsedObservation{
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
}
|
||||||
|
id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
|
||||||
|
+	s.NoError(err)
+
+	// Set negative score (calculator shouldn't produce this, but test DB handling)
+	err = s.obsStore.UpdateImportanceScore(s.ctx, id, -0.5)
+	s.NoError(err)
+
+	retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
+	s.NoError(err)
+	s.InDelta(-0.5, retrieved.ImportanceScore, 0.001)
+}
+
+func (s *ScoringStoreSuite) TestScoring_LargeScore() {
+	obs := &models.ParsedObservation{
+		Type: models.ObsTypeBugfix,
+	}
+	id, _, err := s.obsStore.StoreObservation(s.ctx, "session-1", "project-a", obs, 1, 100)
+	s.NoError(err)
+
+	// Set very large score
+	err = s.obsStore.UpdateImportanceScore(s.ctx, id, 999.999)
+	s.NoError(err)
+
+	retrieved, err := s.obsStore.GetObservationByID(s.ctx, id)
+	s.NoError(err)
+	s.InDelta(999.999, retrieved.ImportanceScore, 0.001)
+}
+
+func (s *ScoringStoreSuite) TestConceptWeight_ZeroWeight() {
+	err := s.obsStore.UpdateConceptWeight(s.ctx, "zero-concept", 0.0)
+	s.NoError(err)
+
+	weights, err := s.obsStore.GetConceptWeights(s.ctx)
+	s.NoError(err)
+	s.Equal(0.0, weights["zero-concept"])
+}
+
+func (s *ScoringStoreSuite) TestConceptWeight_ExactBoundary() {
+	err := s.obsStore.UpdateConceptWeight(s.ctx, "max-concept", 1.0)
+	s.NoError(err)
+
+	weights, err := s.obsStore.GetConceptWeights(s.ctx)
+	s.NoError(err)
+	s.Equal(1.0, weights["max-concept"])
+}
+
+// =============================================================================
+// STANDALONE TESTS
+// =============================================================================
+
+func TestFeedbackStats_Structure(t *testing.T) {
+	stats := FeedbackStats{
+		Total:        100,
+		Positive:     30,
+		Negative:     10,
+		Neutral:      60,
+		AvgScore:     1.5,
+		AvgRetrieval: 5.0,
+	}
+
+	assert.Equal(t, 100, stats.Total)
+	assert.Equal(t, 30, stats.Positive)
+	assert.Equal(t, 10, stats.Negative)
+	assert.Equal(t, 60, stats.Neutral)
+	assert.Equal(t, 1.5, stats.AvgScore)
+	assert.Equal(t, 5.0, stats.AvgRetrieval)
+}
+
+func TestScoringStore_Integration(t *testing.T) {
+	obsStore, _, cleanup := testScoringObservationStore(t)
+	defer cleanup()
+
+	ctx := context.Background()
+
+	// Full integration test: store, feedback, retrieval, score update
+	obs := &models.ParsedObservation{
+		Type:     models.ObsTypeBugfix,
+		Title:    "Integration test observation",
+		Concepts: []string{"security"},
+	}
+	id, _, err := obsStore.StoreObservation(ctx, "session-int", "project-int", obs, 1, 100)
+	require.NoError(t, err)
+
+	// Add feedback
+	err = obsStore.UpdateObservationFeedback(ctx, id, 1)
+	require.NoError(t, err)
+
+	// Increment retrieval
+	err = obsStore.IncrementRetrievalCount(ctx, []int64{id})
+	require.NoError(t, err)
+
+	// Update score
+	err = obsStore.UpdateImportanceScore(ctx, id, 1.75)
+	require.NoError(t, err)
+
+	// Verify final state
+	retrieved, err := obsStore.GetObservationByID(ctx, id)
+	require.NoError(t, err)
+	assert.Equal(t, 1, retrieved.UserFeedback)
+	assert.Equal(t, 1, retrieved.RetrievalCount)
+	assert.InDelta(t, 1.75, retrieved.ImportanceScore, 0.001)
+	assert.True(t, retrieved.ScoreUpdatedAt.Valid)
+	assert.True(t, retrieved.LastRetrievedAt.Valid)
+}
@@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]
 
 	return scanSummaryRows(rows)
 }
+
+// GetAllSummaries retrieves all summaries (for vector rebuild).
+func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
+	const query = `
+		SELECT id, sdk_session_id, project, request, investigated, learned, completed,
+		       next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
+		FROM session_summaries
+		ORDER BY id
+	`
+
+	rows, err := s.store.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	return scanSummaryRows(rows)
+}
@@ -103,6 +103,12 @@ func createBaseTables(t *testing.T, db *sql.DB) {
 			discovery_tokens INTEGER DEFAULT 0,
 			created_at TEXT NOT NULL,
 			created_at_epoch INTEGER NOT NULL,
+			importance_score REAL DEFAULT 1.0,
+			user_feedback INTEGER DEFAULT 0,
+			retrieval_count INTEGER DEFAULT 0,
+			last_retrieved_at_epoch INTEGER,
+			score_updated_at_epoch INTEGER,
+			is_superseded INTEGER DEFAULT 0,
 			FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
 		)
 	`)
@@ -110,6 +116,27 @@ func createBaseTables(t *testing.T, db *sql.DB) {
 		t.Fatalf("create observations: %v", err)
 	}
 
+	// Create observation_conflicts table for conflict detection
+	_, err = db.Exec(`
+		CREATE TABLE IF NOT EXISTS observation_conflicts (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			newer_obs_id INTEGER NOT NULL,
+			older_obs_id INTEGER NOT NULL,
+			conflict_type TEXT NOT NULL CHECK(conflict_type IN ('superseded', 'contradicts', 'outdated_pattern')),
+			resolution TEXT NOT NULL CHECK(resolution IN ('prefer_newer', 'prefer_older', 'manual')),
+			reason TEXT,
+			detected_at TEXT NOT NULL,
+			detected_at_epoch INTEGER NOT NULL,
+			resolved INTEGER DEFAULT 0,
+			resolved_at TEXT,
+			FOREIGN KEY(newer_obs_id) REFERENCES observations(id) ON DELETE CASCADE,
+			FOREIGN KEY(older_obs_id) REFERENCES observations(id) ON DELETE CASCADE
+		)
+	`)
+	if err != nil {
+		t.Fatalf("create observation_conflicts: %v", err)
+	}
+
 	_, err = db.Exec(`
 		CREATE TABLE IF NOT EXISTS session_summaries (
 			id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -150,6 +177,31 @@ func createBaseTables(t *testing.T, db *sql.DB) {
 		t.Fatalf("create user_prompts: %v", err)
 	}
 
+	_, err = db.Exec(`
+		CREATE TABLE IF NOT EXISTS patterns (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			name TEXT NOT NULL,
+			type TEXT NOT NULL CHECK(type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice')),
+			description TEXT,
+			signature TEXT,
+			recommendation TEXT,
+			frequency INTEGER DEFAULT 1,
+			projects TEXT,
+			observation_ids TEXT,
+			status TEXT DEFAULT 'active' CHECK(status IN ('active', 'deprecated', 'merged')),
+			merged_into_id INTEGER,
+			confidence REAL DEFAULT 0.5,
+			last_seen_at TEXT NOT NULL,
+			last_seen_at_epoch INTEGER NOT NULL,
+			created_at TEXT NOT NULL,
+			created_at_epoch INTEGER NOT NULL,
+			FOREIGN KEY(merged_into_id) REFERENCES patterns(id) ON DELETE SET NULL
+		)
+	`)
+	if err != nil {
+		t.Fatalf("create patterns: %v", err)
+	}
+
 	indexes := []string{
 		`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id)`,
 		`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id)`,
@@ -0,0 +1 @@
|
|||||||
|
bge-small-en-v1.5
|
||||||
Binary file not shown.
@@ -1,21 +1,7 @@
|
|||||||
{
|
{
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
"truncation": {
|
"truncation": null,
|
||||||
"direction": "Right",
|
"padding": null,
|
||||||
"max_length": 128,
|
|
||||||
"strategy": "LongestFirst",
|
|
||||||
"stride": 0
|
|
||||||
},
|
|
||||||
"padding": {
|
|
||||||
"strategy": {
|
|
||||||
"Fixed": 128
|
|
||||||
},
|
|
||||||
"direction": "Right",
|
|
||||||
"pad_to_multiple_of": null,
|
|
||||||
"pad_id": 0,
|
|
||||||
"pad_type_id": 0,
|
|
||||||
"pad_token": "[PAD]"
|
|
||||||
},
|
|
||||||
"added_tokens": [
|
"added_tokens": [
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": 0,
|
||||||
|
|||||||
@@ -0,0 +1,157 @@
+// Package embedding provides text embedding generation with swappable models.
+package embedding
+
+import (
+	"fmt"
+	"sync"
+)
+
+// PoolingStrategy defines how to pool token embeddings into sentence embeddings.
+type PoolingStrategy string
+
+const (
+	// PoolingNone means the model already outputs sentence embeddings directly.
+	PoolingNone PoolingStrategy = "none"
+	// PoolingMean averages all token embeddings (weighted by attention mask).
+	PoolingMean PoolingStrategy = "mean"
+	// PoolingCLS uses only the [CLS] token embedding.
+	PoolingCLS PoolingStrategy = "cls"
+)
+
+// ONNXConfig describes ONNX-specific model configuration.
+// This allows different models to specify their tensor names and pooling needs.
+type ONNXConfig struct {
+	// InputNames are the ONNX input tensor names in order.
+	InputNames []string
+	// OutputNames are the ONNX output tensor names.
+	OutputNames []string
+	// Pooling specifies how to convert token embeddings to sentence embeddings.
+	// If PoolingNone, the model outputs sentence embeddings directly.
+	Pooling PoolingStrategy
+	// HiddenSize is the embedding dimension (used for pooling calculations).
+	HiddenSize int
+}
+
+// EmbeddingModel represents a text embedding model.
+type EmbeddingModel interface {
+	// Name returns the human-readable model name (e.g., "bge-small-en-v1.5").
+	Name() string
+
+	// Version returns a short version string for storage (e.g., "bge-v1.5").
+	Version() string
+
+	// Dimensions returns the embedding vector size.
+	Dimensions() int
+
+	// Embed generates an embedding for a single text.
+	Embed(text string) ([]float32, error)
+
+	// EmbedBatch generates embeddings for multiple texts.
+	EmbedBatch(texts []string) ([][]float32, error)
+
+	// Close releases model resources.
+	Close() error
+}
+
+// ONNXConfigurer is an optional interface that models can implement
+// to expose their ONNX configuration for introspection.
+type ONNXConfigurer interface {
+	// ONNXConfig returns the model's ONNX configuration.
+	ONNXConfig() ONNXConfig
+}
+
+// ModelMetadata describes an embedding model for UI/config.
+type ModelMetadata struct {
+	Name        string `json:"name"`        // Human-readable name
+	Version     string `json:"version"`     // Short ID for DB storage
+	Dimensions  int    `json:"dimensions"`  // Vector size
+	Description string `json:"description"` // Brief description
+	Default     bool   `json:"default"`     // Is this the default model?
+}
+
+// ModelFactory creates a new instance of an embedding model.
+type ModelFactory func() (EmbeddingModel, error)
+
+// ModelRegistry provides model lookup by version.
+type ModelRegistry struct {
+	mu           sync.RWMutex
+	models       map[string]ModelFactory
+	metadata     map[string]ModelMetadata
+	defaultModel string
+}
+
+// NewModelRegistry creates a new model registry.
+func NewModelRegistry() *ModelRegistry {
+	return &ModelRegistry{
+		models:   make(map[string]ModelFactory),
+		metadata: make(map[string]ModelMetadata),
+	}
+}
+
+// Register adds a model factory to the registry.
+func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	r.models[meta.Version] = factory
+	r.metadata[meta.Version] = meta
+
+	if meta.Default {
+		r.defaultModel = meta.Version
+	}
+}
+
+// Get creates a new instance of the model with the given version.
+func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
+	r.mu.RLock()
+	factory, ok := r.models[version]
+	r.mu.RUnlock()
+
+	if !ok {
+		return nil, fmt.Errorf("unknown model version: %s", version)
+	}
+
+	return factory()
+}
+
+// Default returns the default model version.
+func (r *ModelRegistry) Default() string {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	return r.defaultModel
+}
+
+// List returns metadata for all registered models.
+func (r *ModelRegistry) List() []ModelMetadata {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	result := make([]ModelMetadata, 0, len(r.metadata))
+	for _, meta := range r.metadata {
+		result = append(result, meta)
+	}
+	return result
+}
+
+// DefaultRegistry is the global model registry with all available models.
+var DefaultRegistry = NewModelRegistry()
+
+// RegisterModel adds a model to the default registry.
+func RegisterModel(meta ModelMetadata, factory ModelFactory) {
+	DefaultRegistry.Register(meta, factory)
+}
+
+// GetModel creates a model instance from the default registry.
+func GetModel(version string) (EmbeddingModel, error) {
+	return DefaultRegistry.Get(version)
+}
+
+// GetDefaultModel returns the default model version from the default registry.
+func GetDefaultModel() string {
+	return DefaultRegistry.Default()
+}
+
+// ListModels returns metadata for all models in the default registry.
+func ListModels() []ModelMetadata {
+	return DefaultRegistry.List()
+}
|
||||||
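A minimal sketch of how the registry added in this file might be exercised. It is not part of the commit: the types below are trimmed copies of the ones in the diff so the program compiles on its own, and `fakeModel` with its `"fake-v0"` version string is a hypothetical stand-in for a real ONNX-backed model.

```go
package main

import (
	"fmt"
	"sync"
)

// EmbeddingModel is trimmed to the methods this sketch needs;
// the interface in the diff also has Name, EmbedBatch and Close.
type EmbeddingModel interface {
	Version() string
	Dimensions() int
	Embed(text string) ([]float32, error)
}

// ModelMetadata mirrors the diff's struct minus Description and JSON tags.
type ModelMetadata struct {
	Name       string
	Version    string
	Dimensions int
	Default    bool
}

type ModelFactory func() (EmbeddingModel, error)

type ModelRegistry struct {
	mu           sync.RWMutex
	models       map[string]ModelFactory
	metadata     map[string]ModelMetadata
	defaultModel string
}

func NewModelRegistry() *ModelRegistry {
	return &ModelRegistry{
		models:   make(map[string]ModelFactory),
		metadata: make(map[string]ModelMetadata),
	}
}

// Register stores the factory under meta.Version and records the default.
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.models[meta.Version] = factory
	r.metadata[meta.Version] = meta
	if meta.Default {
		r.defaultModel = meta.Version
	}
}

// Get looks up the factory under the read lock, then builds a new instance.
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
	r.mu.RLock()
	factory, ok := r.models[version]
	r.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("unknown model version: %s", version)
	}
	return factory()
}

func (r *ModelRegistry) Default() string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.defaultModel
}

// fakeModel is a hypothetical model that returns zero vectors.
type fakeModel struct{ dim int }

func (f fakeModel) Version() string { return "fake-v0" }
func (f fakeModel) Dimensions() int { return f.dim }
func (f fakeModel) Embed(string) ([]float32, error) {
	return make([]float32, f.dim), nil
}

func main() {
	reg := NewModelRegistry()
	reg.Register(
		ModelMetadata{Name: "fake", Version: "fake-v0", Dimensions: 4, Default: true},
		func() (EmbeddingModel, error) { return fakeModel{dim: 4}, nil },
	)

	m, err := reg.Get(reg.Default())
	if err != nil {
		panic(err)
	}
	vec, _ := m.Embed("hello")
	fmt.Println(m.Version(), len(vec)) // fake-v0 4

	_, err = reg.Get("missing")
	fmt.Println(err) // unknown model version: missing
}
```

Note that `Get` calls the factory on every lookup, so each call yields a fresh model instance; callers that want a shared instance would need to cache it themselves.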
+277
-56
@@ -1,4 +1,4 @@
|
|||||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
// Package embedding provides text embedding generation with swappable models.
|
||||||
package embedding
|
package embedding
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -15,19 +15,50 @@ import (
|
|||||||
ort "github.com/yalue/onnxruntime_go"
|
ort "github.com/yalue/onnxruntime_go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2.
|
// EmbeddingDim is the dimension of embeddings produced by the current model.
|
||||||
|
// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
|
||||||
const EmbeddingDim = 384
|
const EmbeddingDim = 384
|
||||||
|
|
||||||
// Service provides thread-safe text embedding generation.
|
// Model version constants
|
||||||
type Service struct {
|
const (
|
||||||
|
// BGEModelVersion is the version string for bge-small-en-v1.5
|
||||||
|
BGEModelVersion = "bge-v1.5"
|
||||||
|
// BGEModelName is the human-readable name for bge-small-en-v1.5
|
||||||
|
BGEModelName = "bge-small-en-v1.5"
|
||||||
|
// DefaultModelVersion is the default model to use
|
||||||
|
DefaultModelVersion = BGEModelVersion
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxSequenceLength is the maximum token sequence length for the model.
|
||||||
|
const MaxSequenceLength = 512
|
||||||
|
|
||||||
|
// bgeONNXConfig defines the ONNX configuration for BGE models.
|
||||||
|
// BGE outputs last_hidden_state and requires mean pooling.
|
||||||
|
var bgeONNXConfig = ONNXConfig{
|
||||||
|
InputNames: []string{"input_ids", "attention_mask", "token_type_ids"},
|
||||||
|
OutputNames: []string{"last_hidden_state"},
|
||||||
|
Pooling: PoolingMean,
|
||||||
|
HiddenSize: EmbeddingDim,
|
||||||
|
}
|
||||||
|
|
||||||
|
// bgeModel is the ONNX-based embedding model implementation.
|
||||||
|
// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
|
||||||
|
type bgeModel struct {
|
||||||
tk *tokenizer.Tokenizer
|
tk *tokenizer.Tokenizer
|
||||||
session *ort.DynamicAdvancedSession
|
session *ort.DynamicAdvancedSession
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
libDir string // temp directory containing extracted libraries
|
libDir string // temp directory containing extracted libraries
|
||||||
|
config ONNXConfig // ONNX configuration for this model
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewService creates a new embedding service using bundled ONNX runtime and model.
|
// Compile-time check that bgeModel implements EmbeddingModel
|
||||||
func NewService() (*Service, error) {
|
var _ EmbeddingModel = (*bgeModel)(nil)
|
||||||
|
|
||||||
|
// Compile-time check that bgeModel implements ONNXConfigurer
|
||||||
|
var _ ONNXConfigurer = (*bgeModel)(nil)
|
||||||
|
|
||||||
|
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
|
||||||
|
func newBGEModel() (EmbeddingModel, error) {
|
||||||
// Extract ONNX runtime library to temp directory
|
// Extract ONNX runtime library to temp directory
|
||||||
libDir, err := extractONNXLibrary()
|
libDir, err := extractONNXLibrary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,22 +80,41 @@ func NewService() (*Service, error) {
|
|||||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create ONNX session with embedded model
|
// Create ONNX session using model-specific configuration
|
||||||
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
config := bgeONNXConfig
|
||||||
outputNames := []string{"sentence_embedding"}
|
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
|
||||||
|
|
||||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create ONNX session: %w", err)
|
return nil, fmt.Errorf("create ONNX session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Service{
|
return &bgeModel{
|
||||||
tk: tk,
|
tk: tk,
|
||||||
session: session,
|
session: session,
|
||||||
libDir: libDir,
|
libDir: libDir,
|
||||||
|
config: config,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ONNXConfig returns the model's ONNX configuration.
|
||||||
|
func (m *bgeModel) ONNXConfig() ONNXConfig {
|
||||||
|
return m.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the human-readable model name.
|
||||||
|
func (m *bgeModel) Name() string {
|
||||||
|
return BGEModelName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the short version string for storage.
|
||||||
|
func (m *bgeModel) Version() string {
|
||||||
|
return BGEModelVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dimensions returns the embedding vector size.
|
||||||
|
func (m *bgeModel) Dimensions() int {
|
||||||
|
return EmbeddingDim
|
||||||
|
}
|
||||||
|
|
||||||
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
|
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
|
||||||
// Uses content hash to avoid re-extracting if already present.
|
// Uses content hash to avoid re-extracting if already present.
|
||||||
func extractONNXLibrary() (string, error) {
|
func extractONNXLibrary() (string, error) {
|
||||||
@@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) {
|
|||||||
|
|
||||||
// Embed generates an embedding for a single text.
|
// Embed generates an embedding for a single text.
|
||||||
// Returns a 384-dimensional float32 vector.
|
// Returns a 384-dimensional float32 vector.
|
||||||
func (s *Service) Embed(text string) ([]float32, error) {
|
func (m *bgeModel) Embed(text string) ([]float32, error) {
|
||||||
s.mu.Lock()
|
m.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if text == "" {
|
if text == "" {
|
||||||
return make([]float32, EmbeddingDim), nil
|
return make([]float32, EmbeddingDim), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := s.computeBatch([]string{text})
|
results, err := m.computeBatch([]string{text})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) {
|
|||||||
|
|
||||||
// EmbedBatch generates embeddings for multiple texts.
|
// EmbedBatch generates embeddings for multiple texts.
|
||||||
// Returns slice of 384-dimensional float32 vectors.
|
// Returns slice of 384-dimensional float32 vectors.
|
||||||
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
|
||||||
if len(texts) == 0 {
|
if len(texts) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
m.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
// Filter out empty texts and track indices
|
// Filter out empty texts and track indices
|
||||||
nonEmpty := make([]string, 0, len(texts))
|
nonEmpty := make([]string, 0, len(texts))
|
||||||
@@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute embeddings for non-empty texts
|
// Compute embeddings for non-empty texts
|
||||||
embeddings, err := s.computeBatch(nonEmpty)
|
embeddings, err := m.computeBatch(nonEmpty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
||||||
}
|
}
|
||||||
@@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
||||||
func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||||
if len(sentences) == 0 {
|
if len(sentences) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
|||||||
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
|
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
|
||||||
}
|
}
|
||||||
|
|
||||||
encodings, err := s.tk.EncodeBatch(inputBatch, true)
|
encodings, err := m.tk.EncodeBatch(inputBatch, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("tokenize: %w", err)
|
return nil, fmt.Errorf("tokenize: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
batchSize := len(encodings)
|
batchSize := len(encodings)
|
||||||
seqLength := len(encodings[0].Ids)
|
hiddenSize := m.config.HiddenSize
|
||||||
hiddenSize := EmbeddingDim
|
|
||||||
|
// Find max sequence length across all encodings (tokenizer may not pad uniformly)
|
||||||
|
// Also enforce MaxSequenceLength to prevent model errors
|
||||||
|
seqLength := 0
|
||||||
|
for _, enc := range encodings {
|
||||||
|
if len(enc.Ids) > seqLength {
|
||||||
|
seqLength = len(enc.Ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Truncate to max model sequence length
|
||||||
|
if seqLength > MaxSequenceLength {
|
||||||
|
seqLength = MaxSequenceLength
|
||||||
|
}
|
||||||
|
|
||||||
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
||||||
|
|
||||||
// Create input tensors
|
// Create input tensors (pre-filled with zeros for padding)
|
||||||
inputIdsData := make([]int64, batchSize*seqLength)
|
inputIdsData := make([]int64, batchSize*seqLength)
|
||||||
attentionMaskData := make([]int64, batchSize*seqLength)
|
attentionMaskData := make([]int64, batchSize*seqLength)
|
||||||
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
||||||
|
|
||||||
for b := 0; b < batchSize; b++ {
|
for b := 0; b < batchSize; b++ {
|
||||||
for i, id := range encodings[b].Ids {
|
// Copy actual token data (rest remains 0 as padding)
|
||||||
inputIdsData[b*seqLength+i] = int64(id)
|
// Truncate to seqLength to handle long inputs
|
||||||
|
copyLen := len(encodings[b].Ids)
|
||||||
|
if copyLen > seqLength {
|
||||||
|
copyLen = seqLength
|
||||||
}
|
}
|
||||||
for i, mask := range encodings[b].AttentionMask {
|
for i := 0; i < copyLen; i++ {
|
||||||
attentionMaskData[b*seqLength+i] = int64(mask)
|
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
|
||||||
}
|
}
|
||||||
for i, typeId := range encodings[b].TypeIds {
|
copyLen = len(encodings[b].AttentionMask)
|
||||||
tokenTypeIdsData[b*seqLength+i] = int64(typeId)
|
if copyLen > seqLength {
|
||||||
|
copyLen = seqLength
|
||||||
|
}
|
||||||
|
for i := 0; i < copyLen; i++ {
|
||||||
|
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
|
||||||
|
}
|
||||||
|
copyLen = len(encodings[b].TypeIds)
|
||||||
|
if copyLen > seqLength {
|
||||||
|
copyLen = seqLength
|
||||||
|
}
|
||||||
|
for i := 0; i < copyLen; i++ {
|
||||||
|
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
|||||||
}
|
}
|
||||||
defer tokenTypeIdsTensor.Destroy()
|
defer tokenTypeIdsTensor.Destroy()
|
||||||
|
|
||||||
sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize))
|
// Create output tensor based on pooling strategy
|
||||||
sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape)
|
var outputShape ort.Shape
|
||||||
|
|
||||||
|
switch m.config.Pooling {
|
||||||
|
case PoolingNone:
|
||||||
|
// Direct sentence embedding output: [batch, hidden]
|
||||||
|
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||||
|
case PoolingMean, PoolingCLS:
|
||||||
|
// Token-level output: [batch, seq_len, hidden]
|
||||||
|
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
|
||||||
|
default:
|
||||||
|
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create output tensor: %w", err)
|
return nil, fmt.Errorf("create output tensor: %w", err)
|
||||||
}
|
}
|
||||||
defer sentenceOutputTensor.Destroy()
|
defer outputTensor.Destroy()
|
||||||
|
|
||||||
// Run inference
|
// Run inference
|
||||||
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
||||||
outputTensors := []ort.Value{sentenceOutputTensor}
|
outputTensors := []ort.Value{outputTensor}
|
||||||
|
|
||||||
if err := s.session.Run(inputTensors, outputTensors); err != nil {
|
if err := m.session.Run(inputTensors, outputTensors); err != nil {
|
||||||
return nil, fmt.Errorf("run inference: %w", err)
|
return nil, fmt.Errorf("run inference: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract results
|
// Extract and pool results based on strategy
|
||||||
flatOutput := sentenceOutputTensor.GetData()
|
flatOutput := outputTensor.GetData()
|
||||||
expectedSize := batchSize * hiddenSize
|
|
||||||
if len(flatOutput) != expectedSize {
|
|
||||||
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
switch m.config.Pooling {
|
||||||
|
case PoolingNone:
|
||||||
|
// Direct output, no pooling needed
|
||||||
|
expectedSize := batchSize * hiddenSize
|
||||||
|
if len(flatOutput) != expectedSize {
|
||||||
|
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
||||||
|
}
|
||||||
|
results := make([][]float32, batchSize)
|
||||||
|
for i := 0; i < batchSize; i++ {
|
||||||
|
start := i * hiddenSize
|
||||||
|
end := start + hiddenSize
|
||||||
|
results[i] = make([]float32, hiddenSize)
|
||||||
|
copy(results[i], flatOutput[start:end])
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
|
||||||
|
case PoolingMean:
|
||||||
|
// Mean pooling over tokens (weighted by attention mask)
|
||||||
|
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
|
||||||
|
|
||||||
|
case PoolingCLS:
|
||||||
|
// CLS token pooling (first token of each sequence)
|
||||||
|
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// meanPooling applies mean pooling over token embeddings, weighted by attention mask.
|
||||||
|
// Input shape: [batch, seq_len, hidden], attention mask: [batch, seq_len]
|
||||||
|
// Output shape: [batch, hidden]
|
||||||
|
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
|
||||||
results := make([][]float32, batchSize)
|
results := make([][]float32, batchSize)
|
||||||
for i := 0; i < batchSize; i++ {
|
|
||||||
start := i * hiddenSize
|
for b := 0; b < batchSize; b++ {
|
||||||
end := start + hiddenSize
|
result := make([]float32, hiddenSize)
|
||||||
results[i] = make([]float32, hiddenSize)
|
var maskSum float32
|
||||||
copy(results[i], flatOutput[start:end])
|
|
||||||
|
// Sum embeddings weighted by attention mask
|
||||||
|
for s := 0; s < seqLen; s++ {
|
||||||
|
maskVal := float32(attentionMask[b*seqLen+s])
|
||||||
|
maskSum += maskVal
|
||||||
|
|
||||||
|
if maskVal > 0 {
|
||||||
|
embOffset := (b*seqLen + s) * hiddenSize
|
||||||
|
for h := 0; h < hiddenSize; h++ {
|
||||||
|
result[h] += embeddings[embOffset+h] * maskVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize by mask sum (avoid division by zero)
|
||||||
|
if maskSum > 0 {
|
||||||
|
for h := 0; h < hiddenSize; h++ {
|
||||||
|
result[h] /= maskSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results[b] = result
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// clsPooling extracts the [CLS] token embedding (first token).
|
||||||
|
// Input shape: [batch, seq_len, hidden]
|
||||||
|
// Output shape: [batch, hidden]
|
||||||
|
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
|
||||||
|
results := make([][]float32, batchSize)
|
||||||
|
|
||||||
|
for b := 0; b < batchSize; b++ {
|
||||||
|
result := make([]float32, hiddenSize)
|
||||||
|
// CLS token is at position 0
|
||||||
|
embOffset := b * seqLen * hiddenSize
|
||||||
|
copy(result, embeddings[embOffset:embOffset+hiddenSize])
|
||||||
|
results[b] = result
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
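To make the two pooling helpers concrete, here is a small worked example, offered as a sketch and not as part of the commit. The function bodies follow the ones in the diff (lightly condensed); the input values are invented for illustration: one sequence of three tokens with hidden size 2, where the third token is padding and is masked out.

```go
package main

import "fmt"

// meanPooling averages token embeddings per sequence, weighted by the
// attention mask. embeddings is a flattened [batch, seqLen, hidden] tensor.
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
	results := make([][]float32, batchSize)
	for b := 0; b < batchSize; b++ {
		result := make([]float32, hiddenSize)
		var maskSum float32
		for s := 0; s < seqLen; s++ {
			maskVal := float32(attentionMask[b*seqLen+s])
			maskSum += maskVal
			if maskVal > 0 {
				embOffset := (b*seqLen + s) * hiddenSize
				for h := 0; h < hiddenSize; h++ {
					result[h] += embeddings[embOffset+h] * maskVal
				}
			}
		}
		// Normalize by the number of unmasked tokens (skip if all masked).
		if maskSum > 0 {
			for h := 0; h < hiddenSize; h++ {
				result[h] /= maskSum
			}
		}
		results[b] = result
	}
	return results
}

// clsPooling copies the first token's embedding of each sequence.
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
	results := make([][]float32, batchSize)
	for b := 0; b < batchSize; b++ {
		result := make([]float32, hiddenSize)
		embOffset := b * seqLen * hiddenSize
		copy(result, embeddings[embOffset:embOffset+hiddenSize])
		results[b] = result
	}
	return results
}

func main() {
	// Tokens: [1,2], [3,4], and a padding token [100,100] with mask 0.
	emb := []float32{1, 2, 3, 4, 100, 100}
	mask := []int64{1, 1, 0}

	fmt.Println(meanPooling(emb, mask, 1, 3, 2)) // [[2 3]]
	fmt.Println(clsPooling(emb, 1, 3, 2))        // [[1 2]]
}
```

The padding token contributes nothing to the mean: the sum is (1+3, 2+4) = (4, 6), divided by a mask sum of 2. This is why `computeBatch` passes `attentionMaskData` to `meanPooling` after zero-padding the input tensors.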
|
|
||||||
// Close releases model resources.
|
// Close releases model resources.
|
||||||
func (s *Service) Close() error {
|
func (m *bgeModel) Close() error {
|
||||||
s.mu.Lock()
|
m.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
if s.session != nil {
|
if m.session != nil {
|
||||||
if err := s.session.Destroy(); err != nil {
|
if err := m.session.Destroy(); err != nil {
|
||||||
errs = append(errs, fmt.Errorf("destroy session: %w", err))
|
errs = append(errs, fmt.Errorf("destroy session: %w", err))
|
||||||
}
|
}
|
||||||
s.session = nil
|
m.session = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ort.DestroyEnvironment(); err != nil {
|
if err := ort.DestroyEnvironment(); err != nil {
|
||||||
@@ -282,10 +438,75 @@ func (s *Service) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Optionally clean up extracted library (leave for caching)
|
// Optionally clean up extracted library (leave for caching)
|
||||||
// os.RemoveAll(s.libDir)
|
// os.RemoveAll(m.libDir)
|
||||||
|
|
||||||
if len(errs) > 0 {
|
if len(errs) > 0 {
|
||||||
return errs[0]
|
return errs[0]
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register the BGE model with the default registry at init time
|
||||||
|
func init() {
|
||||||
|
RegisterModel(ModelMetadata{
|
||||||
|
Name: BGEModelName,
|
||||||
|
Version: BGEModelVersion,
|
||||||
|
Dimensions: EmbeddingDim,
|
||||||
|
Description: "High-quality semantic search model",
|
||||||
|
Default: true,
|
||||||
|
}, newBGEModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Service provides thread-safe text embedding generation with model abstraction.
|
||||||
|
type Service struct {
|
||||||
|
model EmbeddingModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewService creates a new embedding service using the default model.
|
||||||
|
func NewService() (*Service, error) {
|
||||||
|
return NewServiceWithModel(DefaultModelVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServiceWithModel creates a new embedding service using the specified model.
|
||||||
|
func NewServiceWithModel(version string) (*Service, error) {
|
||||||
|
if version == "" {
|
||||||
|
version = DefaultModelVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetModel(version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get model %s: %w", version, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Service{model: model}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the human-readable model name.
|
||||||
|
func (s *Service) Name() string {
|
||||||
|
return s.model.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the short version string for storage.
|
||||||
|
func (s *Service) Version() string {
|
||||||
|
return s.model.Version()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dimensions returns the embedding vector size.
|
||||||
|
func (s *Service) Dimensions() int {
|
||||||
|
return s.model.Dimensions()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embed generates an embedding for a single text.
|
||||||
|
func (s *Service) Embed(text string) ([]float32, error) {
|
||||||
|
return s.model.Embed(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedBatch generates embeddings for multiple texts.
|
||||||
|
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||||
|
return s.model.EmbedBatch(texts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases model resources.
|
||||||
|
func (s *Service) Close() error {
|
||||||
|
return s.model.Close()
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,8 +22,10 @@ func TestNewService(t *testing.T) {
|
|||||||
|
|
||||||
defer svc.Close()
|
defer svc.Close()
|
||||||
|
|
||||||
assert.NotNil(t, svc.tk)
|
// Verify the service is properly initialized via public methods
|
||||||
assert.NotNil(t, svc.session)
|
assert.NotEmpty(t, svc.Name())
|
||||||
|
assert.NotEmpty(t, svc.Version())
|
||||||
|
assert.Equal(t, EmbeddingDim, svc.Dimensions())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestEmbed_SingleText tests embedding a single text.
|
// TestEmbed_SingleText tests embedding a single text.
|
||||||
@@ -269,8 +271,8 @@ func TestClose(t *testing.T) {
|
|||||||
err = svc.Close()
|
err = svc.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Session should be nil after close
|
// After close, embedding should fail (model resources released)
|
||||||
assert.Nil(t, svc.session)
|
// Note: This behavior is model-specific; some models may still work after close
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestEmbedBatch_SingleItem tests batch embedding with single item.
|
// TestEmbedBatch_SingleItem tests batch embedding with single item.
|
||||||
|
|||||||
@@ -0,0 +1,438 @@
|
|||||||
|
// Package pattern provides pattern detection and recognition functionality.
|
||||||
|
package pattern
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DetectorConfig contains configuration for the pattern detector.
|
||||||
|
type DetectorConfig struct {
|
||||||
|
// MinMatchScore is the minimum similarity score to consider a match (0.0-1.0).
|
||||||
|
MinMatchScore float64
|
||||||
|
// MinFrequencyForPattern is the minimum occurrences before creating a pattern.
|
||||||
|
MinFrequencyForPattern int
|
||||||
|
// AnalysisInterval is how often to run background pattern analysis.
|
||||||
|
AnalysisInterval time.Duration
|
||||||
|
// MaxPatternsToTrack is the maximum number of active patterns.
|
||||||
|
MaxPatternsToTrack int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns the default detector configuration.
|
||||||
|
func DefaultConfig() DetectorConfig {
|
||||||
|
return DetectorConfig{
|
||||||
|
MinMatchScore: 0.3, // 30% similarity threshold
|
||||||
|
MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern
|
||||||
|
AnalysisInterval: 5 * time.Minute,
|
||||||
|
MaxPatternsToTrack: 1000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatternSyncFunc is a callback for syncing patterns to vector store.
|
||||||
|
type PatternSyncFunc func(pattern *models.Pattern)
|
||||||
|
|
||||||
|
// Detector detects and tracks recurring patterns across observations.
|
||||||
|
type Detector struct {
|
||||||
|
config DetectorConfig
|
||||||
|
patternStore *sqlite.PatternStore
|
||||||
|
observationStore *sqlite.ObservationStore
|
||||||
|
|
||||||
|
// Vector sync callback
|
||||||
|
syncFunc PatternSyncFunc
|
||||||
|
|
||||||
|
// Candidate tracking (patterns not yet confirmed)
|
||||||
|
candidates map[string]*candidatePattern
|
||||||
|
candidatesMu sync.RWMutex
|
||||||
|
|
||||||
|
// Background analysis
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSyncFunc sets the callback for syncing patterns to vector store.
|
||||||
|
func (d *Detector) SetSyncFunc(fn PatternSyncFunc) {
|
||||||
|
d.syncFunc = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// candidatePattern tracks a potential pattern before it reaches frequency threshold.
|
||||||
|
type candidatePattern struct {
|
||||||
|
signature []string
|
||||||
|
observationIDs []int64
|
||||||
|
projects []string
|
||||||
|
patternType models.PatternType
|
||||||
|
title string
|
||||||
|
lastSeenEpoch int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDetector creates a new pattern detector.
|
||||||
|
func NewDetector(patternStore *sqlite.PatternStore, observationStore *sqlite.ObservationStore, config DetectorConfig) *Detector {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &Detector{
|
||||||
|
config: config,
|
||||||
|
patternStore: patternStore,
|
||||||
|
observationStore: observationStore,
|
||||||
|
candidates: make(map[string]*candidatePattern),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins background pattern analysis.
|
||||||
|
func (d *Detector) Start() {
|
||||||
|
d.wg.Add(1)
|
||||||
|
go d.backgroundAnalysis()
|
||||||
|
log.Info().Dur("interval", d.config.AnalysisInterval).Msg("Pattern detector started")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops background pattern analysis.
|
||||||
|
func (d *Detector) Stop() {
|
||||||
|
d.cancel()
|
||||||
|
d.wg.Wait()
|
||||||
|
log.Info().Msg("Pattern detector stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// backgroundAnalysis runs periodic pattern analysis.
|
||||||
|
func (d *Detector) backgroundAnalysis() {
|
||||||
|
defer d.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(d.config.AnalysisInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := d.AnalyzeRecentObservations(d.ctx); err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Background pattern analysis failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnalyzeObservation processes a new observation for pattern detection.
|
||||||
|
// This is called synchronously when a new observation is stored.
|
||||||
|
func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observation) (*DetectionResult, error) {
|
||||||
|
result := &DetectionResult{}
|
||||||
|
|
||||||
|
// Extract signature from observation
|
||||||
|
signature := models.ExtractSignature(
|
||||||
|
obs.Concepts,
|
||||||
|
obs.Title.String,
|
||||||
|
obs.Narrative.String,
|
||||||
|
)
|
||||||
|
if len(signature) == 0 {
|
||||||
|
return result, nil // Nothing to detect
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check against existing patterns
|
||||||
|
matches, err := d.patternStore.FindMatchingPatterns(ctx, signature, d.config.MinMatchScore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) > 0 {
|
||||||
|
// Found existing pattern match
|
||||||
|
bestMatch := matches[0]
|
||||||
|
for _, m := range matches[1:] {
|
||||||
|
if models.CalculateMatchScore(signature, m.Signature) > models.CalculateMatchScore(signature, bestMatch.Signature) {
|
||||||
|
bestMatch = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the pattern with new occurrence
|
||||||
|
if err := d.patternStore.IncrementPatternFrequency(ctx, bestMatch.ID, obs.Project, obs.ID); err != nil {
|
||||||
|
log.Warn().Err(err).Int64("pattern_id", bestMatch.ID).Msg("Failed to update pattern frequency")
|
||||||
|
}
|
||||||
|
|
||||||
|
result.MatchedPattern = bestMatch
|
||||||
|
result.MatchScore = models.CalculateMatchScore(signature, bestMatch.Signature)
|
||||||
|
result.IsNewPattern = false
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Int64("pattern_id", bestMatch.ID).
|
||||||
|
Str("pattern_name", bestMatch.Name).
|
||||||
|
Float64("score", result.MatchScore).
|
||||||
|
Msg("Observation matched existing pattern")
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// No existing pattern match - check candidates
|
||||||
|
candidateKey := generateCandidateKey(signature)
|
||||||
|
d.candidatesMu.Lock()
|
||||||
|
defer d.candidatesMu.Unlock()
|
||||||
|
|
||||||
|
if candidate, exists := d.candidates[candidateKey]; exists {
|
||||||
|
// Update existing candidate
|
||||||
|
candidate.observationIDs = append(candidate.observationIDs, obs.ID)
|
||||||
|
|
||||||
|
// Add project if not already present
|
||||||
|
found := false
|
||||||
|
for _, p := range candidate.projects {
|
||||||
|
if p == obs.Project {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
candidate.projects = append(candidate.projects, obs.Project)
|
||||||
|
}
|
||||||
|
candidate.lastSeenEpoch = time.Now().UnixMilli()
|
||||||
|
|
||||||
|
// Check if candidate should be promoted to pattern
|
||||||
|
if len(candidate.observationIDs) >= d.config.MinFrequencyForPattern {
|
||||||
|
pattern, err := d.promoteCandidate(ctx, candidateKey, candidate)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to promote candidate to pattern")
|
||||||
|
} else {
|
||||||
|
result.MatchedPattern = pattern
|
||||||
|
result.IsNewPattern = true
|
||||||
|
log.Info().
|
||||||
|
Str("pattern_name", pattern.Name).
|
||||||
|
Int("frequency", pattern.Frequency).
|
||||||
|
Msg("New pattern detected and stored")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Create new candidate
|
||||||
|
patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String)
|
||||||
|
d.candidates[candidateKey] = &candidatePattern{
|
||||||
|
signature: signature,
|
||||||
|
observationIDs: []int64{obs.ID},
|
||||||
|
projects: []string{obs.Project},
|
||||||
|
patternType: patternType,
|
||||||
|
title: obs.Title.String,
|
||||||
|
lastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("candidate_key", candidateKey).
|
||||||
|
Strs("signature", signature).
|
||||||
|
Msg("New pattern candidate created")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// promoteCandidate converts a candidate to a stored pattern.
|
||||||
|
func (d *Detector) promoteCandidate(ctx context.Context, key string, candidate *candidatePattern) (*models.Pattern, error) {
|
||||||
|
// Generate pattern name from signature
|
||||||
|
name := generatePatternName(candidate.patternType, candidate.signature, candidate.title)
|
||||||
|
|
||||||
|
// Create base pattern using NewPattern with first observation
|
||||||
|
firstProject := ""
|
||||||
|
if len(candidate.projects) > 0 {
|
||||||
|
firstProject = candidate.projects[0]
|
||||||
|
}
|
||||||
|
var firstObsID int64
|
||||||
|
if len(candidate.observationIDs) > 0 {
|
||||||
|
firstObsID = candidate.observationIDs[0]
|
||||||
|
}
|
||||||
|
pattern := models.NewPattern(
|
||||||
|
name,
|
||||||
|
candidate.patternType,
|
||||||
|
"Automatically detected pattern from recurring observations",
|
||||||
|
candidate.signature,
|
||||||
|
firstProject,
|
||||||
|
firstObsID,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add remaining projects and observations
|
||||||
|
for i := 1; i < len(candidate.projects); i++ {
|
||||||
|
pattern.Projects = append(pattern.Projects, candidate.projects[i])
|
||||||
|
}
|
||||||
|
for i := 1; i < len(candidate.observationIDs); i++ {
|
||||||
|
pattern.ObservationIDs = append(pattern.ObservationIDs, candidate.observationIDs[i])
|
||||||
|
}
|
||||||
|
pattern.Frequency = len(candidate.observationIDs)
|
||||||
|
|
||||||
|
id, err := d.patternStore.StorePattern(ctx, pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pattern.ID = id
|
||||||
|
|
||||||
|
// Sync to vector store if callback is set
|
||||||
|
if d.syncFunc != nil {
|
||||||
|
d.syncFunc(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from candidates
|
||||||
|
delete(d.candidates, key)
|
||||||
|
|
||||||
|
return pattern, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnalyzeRecentObservations analyzes recent observations for pattern detection.
|
||||||
|
// This is used for background batch analysis.
|
||||||
|
func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error {
|
||||||
|
// Fetch up to 100 recent observations; this call passes no explicit time window or already-analyzed filter
|
||||||
|
observations, err := d.observationStore.GetRecentObservations(ctx, "", 100)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
analyzed := 0
|
||||||
|
patternsFound := 0
|
||||||
|
for _, obs := range observations {
|
||||||
|
result, err := d.AnalyzeObservation(ctx, obs)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Int64("obs_id", obs.ID).Msg("Failed to analyze observation")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
analyzed++
|
||||||
|
if result.MatchedPattern != nil {
|
||||||
|
patternsFound++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if analyzed > 0 {
|
||||||
|
log.Info().
|
||||||
|
Int("analyzed", analyzed).
|
||||||
|
Int("patterns_found", patternsFound).
|
||||||
|
Msg("Background pattern analysis completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up old candidates (older than 7 days)
|
||||||
|
d.cleanupOldCandidates()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldCandidates removes candidates that haven't been seen recently.
|
||||||
|
func (d *Detector) cleanupOldCandidates() {
|
||||||
|
d.candidatesMu.Lock()
|
||||||
|
defer d.candidatesMu.Unlock()
|
||||||
|
|
||||||
|
threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli()
|
||||||
|
for key, candidate := range d.candidates {
|
||||||
|
if candidate.lastSeenEpoch < threshold {
|
||||||
|
delete(d.candidates, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPatternInsight returns a formatted insight string for a pattern.
|
||||||
|
func (d *Detector) GetPatternInsight(ctx context.Context, patternID int64) (string, error) {
|
||||||
|
pattern, err := d.patternStore.GetPatternByID(ctx, patternID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return formatPatternInsight(pattern), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectionResult contains the result of pattern detection.
|
||||||
|
type DetectionResult struct {
|
||||||
|
MatchedPattern *models.Pattern
|
||||||
|
MatchScore float64
|
||||||
|
IsNewPattern bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCandidateKey creates a unique key for a signature.
|
||||||
|
func generateCandidateKey(signature []string) string {
|
||||||
|
if len(signature) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
key := ""
|
||||||
|
for _, s := range signature {
|
||||||
|
key += s + "|"
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// generatePatternName creates a human-readable name for a pattern.
|
||||||
|
func generatePatternName(patternType models.PatternType, signature []string, title string) string {
|
||||||
|
// Use title if available and meaningful
|
||||||
|
if title != "" && len(title) < 60 {
|
||||||
|
return title
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise generate from type and signature
|
||||||
|
prefix := ""
|
||||||
|
switch patternType {
|
||||||
|
case models.PatternTypeBug:
|
||||||
|
prefix = "Bug Pattern: "
|
||||||
|
case models.PatternTypeRefactor:
|
||||||
|
prefix = "Refactor Pattern: "
|
||||||
|
case models.PatternTypeArchitecture:
|
||||||
|
prefix = "Architecture Pattern: "
|
||||||
|
case models.PatternTypeAntiPattern:
|
||||||
|
prefix = "Anti-Pattern: "
|
||||||
|
case models.PatternTypeBestPractice:
|
||||||
|
prefix = "Best Practice: "
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use first few signature elements
|
||||||
|
if len(signature) > 0 {
|
||||||
|
name := prefix
|
||||||
|
for i, s := range signature {
|
||||||
|
if i >= 3 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
name += ", "
|
||||||
|
}
|
||||||
|
name += s
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + "Unnamed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatPatternInsight creates a human-readable insight from a pattern.
|
||||||
|
func formatPatternInsight(pattern *models.Pattern) string {
|
||||||
|
insight := "I've encountered this pattern " +
|
||||||
|
itoa(pattern.Frequency) + " times"
|
||||||
|
|
||||||
|
if len(pattern.Projects) > 1 {
|
||||||
|
insight += " across " + itoa(len(pattern.Projects)) + " projects"
|
||||||
|
}
|
||||||
|
|
||||||
|
insight += ". "
|
||||||
|
|
||||||
|
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
|
||||||
|
insight += "What works: " + pattern.Recommendation.String
|
||||||
|
} else {
|
||||||
|
switch pattern.Type {
|
||||||
|
case models.PatternTypeBug:
|
||||||
|
insight += "This appears to be a recurring bug pattern."
|
||||||
|
case models.PatternTypeAntiPattern:
|
||||||
|
insight += "This is an identified anti-pattern to avoid."
|
||||||
|
case models.PatternTypeBestPractice:
|
||||||
|
insight += "This is a validated best practice."
|
||||||
|
default:
|
||||||
|
insight += "This is a recognized pattern in the codebase."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return insight
|
||||||
|
}
|
||||||
|
|
||||||
|
// itoa converts int to string without importing strconv.
|
||||||
|
func itoa(n int) string {
|
||||||
|
if n == 0 {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
negative := false
|
||||||
|
if n < 0 {
|
||||||
|
negative = true
|
||||||
|
n = -n
|
||||||
|
}
|
||||||
|
var digits []byte
|
||||||
|
for n > 0 {
|
||||||
|
digits = append([]byte{byte('0' + n%10)}, digits...)
|
||||||
|
n /= 10
|
||||||
|
}
|
||||||
|
if negative {
|
||||||
|
digits = append([]byte{'-'}, digits...)
|
||||||
|
}
|
||||||
|
return string(digits)
|
||||||
|
}
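generateCandidateKey above builds the key by appending each signature element followed by "|", so two signatures share a candidate only if their elements appear in the same order. Whether models.ExtractSignature already returns elements in a canonical order is not visible in this diff. The sketch below reproduces the key format with strings.Builder and contrasts it with a hypothetical order-insensitive variant, `sortedCandidateKey`, which is an illustration only and not part of this commit.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// candidateKey reproduces the key format from the diff:
// every element is followed by a "|" separator.
func candidateKey(signature []string) string {
	if len(signature) == 0 {
		return ""
	}
	var b strings.Builder
	for _, s := range signature {
		b.WriteString(s)
		b.WriteByte('|')
	}
	return b.String()
}

// sortedCandidateKey is a hypothetical variant that sorts a copy of the
// signature first, so element order no longer affects the key.
func sortedCandidateKey(signature []string) string {
	sorted := append([]string(nil), signature...)
	sort.Strings(sorted)
	return candidateKey(sorted)
}

func main() {
	a := []string{"nil", "error-handling"}
	b := []string{"error-handling", "nil"}

	fmt.Println(candidateKey(a))                                // nil|error-handling|
	fmt.Println(candidateKey(a) == candidateKey(b))             // false
	fmt.Println(sortedCandidateKey(a) == sortedCandidateKey(b)) // true
}
```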
|
||||||
@@ -0,0 +1,450 @@
|
|||||||
|
package pattern
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDetector(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
if detector == nil {
|
||||||
|
t.Fatal("Expected non-nil detector")
|
||||||
|
}
|
||||||
|
|
||||||
|
if detector.config.MinMatchScore != config.MinMatchScore {
|
||||||
|
t.Errorf("Expected MinMatchScore %f, got %f",
|
||||||
|
config.MinMatchScore, detector.config.MinMatchScore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_StartStop(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.AnalysisInterval = 100 * time.Millisecond // Short interval for testing
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
|
||||||
|
// Start
|
||||||
|
detector.Start()
|
||||||
|
|
||||||
|
// Wait a bit to ensure background goroutine is running
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Stop
|
||||||
|
detector.Stop()
|
||||||
|
|
||||||
|
// Verify we can stop without hanging
|
||||||
|
// (if this test hangs, the Stop logic is broken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_AnalyzeObservation_NewCandidate(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.MinFrequencyForPattern = 2
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create an observation
|
||||||
|
obs := createTestObservation(1, "nil-handling", []string{"nil", "error-handling"})
|
||||||
|
|
||||||
|
// Analyze first observation
|
||||||
|
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First observation should create a candidate, not a pattern
|
||||||
|
if result.MatchedPattern != nil {
|
||||||
|
t.Errorf("Expected no pattern match for first observation")
|
||||||
|
}
|
||||||
|
if result.IsNewPattern {
|
||||||
|
t.Errorf("Expected IsNewPattern to be false for first observation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_AnalyzeObservation_PromoteToPattern(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.MinFrequencyForPattern = 2
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create two similar observations
|
||||||
|
obs1 := createTestObservation(1, "Nil pointer handling", []string{"nil", "error-handling"})
|
||||||
|
obs2 := createTestObservation(2, "Nil pointer handling", []string{"nil", "error-handling"})
|
||||||
|
|
||||||
|
// Analyze first observation
|
||||||
|
_, err := detector.AnalyzeObservation(ctx, obs1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AnalyzeObservation(obs1) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Analyze second observation - should promote to pattern
|
||||||
|
result, err := detector.AnalyzeObservation(ctx, obs2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AnalyzeObservation(obs2) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.MatchedPattern == nil {
|
||||||
|
t.Fatal("Expected pattern to be created after second occurrence")
|
||||||
|
}
|
||||||
|
if !result.IsNewPattern {
|
||||||
|
t.Errorf("Expected IsNewPattern to be true for newly promoted pattern")
|
||||||
|
}
|
||||||
|
if result.MatchedPattern.Frequency != 2 {
|
||||||
|
t.Errorf("Expected frequency 2, got %d", result.MatchedPattern.Frequency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_AnalyzeObservation_MatchExisting(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create an existing pattern
|
||||||
|
pattern := &models.Pattern{
|
||||||
|
Name: "Existing Pattern",
|
||||||
|
Type: models.PatternTypeBug,
|
||||||
|
Signature: []string{"nil", "error-handling", "pointer"},
|
||||||
|
Frequency: 5,
|
||||||
|
Projects: []string{"proj1"},
|
||||||
|
ObservationIDs: []int64{1, 2, 3, 4, 5},
|
||||||
|
Status: models.PatternStatusActive,
|
||||||
|
Confidence: 0.7,
|
||||||
|
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
CreatedAt: time.Now().Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
if _, err := patternStore.StorePattern(ctx, pattern); err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
|
||||||
|
|
||||||
|
// Create observation with similar signature
|
||||||
|
obs := createTestObservation(10, "Nil check", []string{"nil", "error-handling"})
|
||||||
|
|
||||||
|
// Analyze - should match existing pattern
|
||||||
|
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.MatchedPattern == nil {
|
||||||
|
t.Fatal("Expected to match existing pattern")
|
||||||
|
}
|
||||||
|
if result.IsNewPattern {
|
||||||
|
t.Errorf("Expected IsNewPattern to be false for existing pattern")
|
||||||
|
}
|
||||||
|
if result.MatchScore < 0.3 {
|
||||||
|
t.Errorf("Expected match score >= 0.3, got %f", result.MatchScore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_AnalyzeObservation_NoMatch(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.MinMatchScore = 0.5 // Higher threshold
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create an existing pattern with specific signature
|
||||||
|
pattern := &models.Pattern{
|
||||||
|
Name: "Specific Pattern",
|
||||||
|
Type: models.PatternTypeBug,
|
||||||
|
Signature: []string{"database", "connection", "pool"},
|
||||||
|
Frequency: 3,
|
||||||
|
Projects: []string{"proj1"},
|
||||||
|
ObservationIDs: []int64{1, 2, 3},
|
||||||
|
Status: models.PatternStatusActive,
|
||||||
|
Confidence: 0.6,
|
||||||
|
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
CreatedAt: time.Now().Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
if _, err := patternStore.StorePattern(ctx, pattern); err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
|
||||||
|
|
||||||
|
// Create observation with completely different signature
|
||||||
|
obs := createTestObservation(10, "UI Component", []string{"frontend", "react", "component"})
|
||||||
|
|
||||||
|
// Analyze - should not match
|
||||||
|
result, err := detector.AnalyzeObservation(ctx, obs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AnalyzeObservation() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.MatchedPattern != nil {
|
||||||
|
t.Errorf("Expected no match for unrelated observation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_CandidateCleanup(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
config.MinFrequencyForPattern = 3 // Higher threshold
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
|
||||||
|
// Add an old candidate manually
|
||||||
|
oldKey := "old|candidate|"
|
||||||
|
detector.candidates[oldKey] = &candidatePattern{
|
||||||
|
signature: []string{"old", "candidate"},
|
||||||
|
observationIDs: []int64{1},
|
||||||
|
projects: []string{"proj1"},
|
||||||
|
patternType: models.PatternTypeBug,
|
||||||
|
title: "Old Candidate",
|
||||||
|
lastSeenEpoch: time.Now().Add(-8 * 24 * time.Hour).UnixMilli(), // 8 days ago
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a recent candidate
|
||||||
|
recentKey := "recent|candidate|"
|
||||||
|
detector.candidates[recentKey] = &candidatePattern{
|
||||||
|
signature: []string{"recent", "candidate"},
|
||||||
|
observationIDs: []int64{2},
|
||||||
|
projects: []string{"proj1"},
|
||||||
|
patternType: models.PatternTypeBug,
|
||||||
|
title: "Recent Candidate",
|
||||||
|
lastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run cleanup
|
||||||
|
detector.cleanupOldCandidates()
|
||||||
|
|
||||||
|
// Old candidate should be removed
|
||||||
|
if _, exists := detector.candidates[oldKey]; exists {
|
||||||
|
t.Errorf("Expected old candidate to be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recent candidate should remain
|
||||||
|
if _, exists := detector.candidates[recentKey]; !exists {
|
||||||
|
t.Errorf("Expected recent candidate to remain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetector_GetPatternInsight(t *testing.T) {
|
||||||
|
store := setupTestStore(t)
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
patternStore := sqlite.NewPatternStore(store)
|
||||||
|
observationStore := sqlite.NewObservationStore(store)
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
detector := NewDetector(patternStore, observationStore, config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create pattern with recommendation
|
||||||
|
pattern := &models.Pattern{
|
||||||
|
Name: "Insight Test Pattern",
|
||||||
|
Type: models.PatternTypeBestPractice,
|
||||||
|
Signature: []string{"test"},
|
||||||
|
Recommendation: sql.NullString{String: "Always write tests first", Valid: true},
|
||||||
|
Frequency: 12,
|
||||||
|
Projects: []string{"proj1", "proj2", "proj3"},
|
||||||
|
ObservationIDs: []int64{1},
|
||||||
|
Status: models.PatternStatusActive,
|
||||||
|
Confidence: 0.8,
|
||||||
|
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
CreatedAt: time.Now().Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
id, err := patternStore.StorePattern(ctx, pattern)
if err != nil {
t.Fatalf("StorePattern() error = %v", err)
}
|
||||||
|
|
||||||
|
// Get insight
|
||||||
|
insight, err := detector.GetPatternInsight(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPatternInsight() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify insight contains expected elements
|
||||||
|
if insight == "" {
|
||||||
|
t.Error("Expected non-empty insight")
|
||||||
|
}
|
||||||
|
if !containsString(insight, "12 times") {
|
||||||
|
t.Errorf("Expected insight to contain frequency, got: %s", insight)
|
||||||
|
}
|
||||||
|
if !containsString(insight, "3 projects") {
|
||||||
|
t.Errorf("Expected insight to contain project count, got: %s", insight)
|
||||||
|
}
|
||||||
|
if !containsString(insight, "Always write tests first") {
|
||||||
|
t.Errorf("Expected insight to contain recommendation, got: %s", insight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
func TestDefaultConfig(t *testing.T) {
	config := DefaultConfig()

	if config.MinMatchScore <= 0 || config.MinMatchScore > 1 {
		t.Errorf("Invalid MinMatchScore: %f", config.MinMatchScore)
	}
	if config.MinFrequencyForPattern < 1 {
		t.Errorf("Invalid MinFrequencyForPattern: %d", config.MinFrequencyForPattern)
	}
	if config.AnalysisInterval <= 0 {
		t.Errorf("Invalid AnalysisInterval: %v", config.AnalysisInterval)
	}
	if config.MaxPatternsToTrack <= 0 {
		t.Errorf("Invalid MaxPatternsToTrack: %d", config.MaxPatternsToTrack)
	}
}

func TestGeneratePatternName(t *testing.T) {
	tests := []struct {
		patternType models.PatternType
		signature   []string
		title       string
		wantPrefix  string
	}{
		{models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"},
		{models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"},
		{models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"},
		{models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"},
		{models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"},
		{models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly
	}

	for _, tt := range tests {
		name := generatePatternName(tt.patternType, tt.signature, tt.title)
		if !hasPrefix(name, tt.wantPrefix) {
			t.Errorf("generatePatternName(%v, %v, %q) = %q, want prefix %q",
				tt.patternType, tt.signature, tt.title, name, tt.wantPrefix)
		}
	}
}

func TestFormatPatternInsight(t *testing.T) {
	// Pattern without recommendation
	pattern1 := &models.Pattern{
		Type:      models.PatternTypeBug,
		Frequency: 5,
		Projects:  []string{"proj1"},
	}
	insight1 := formatPatternInsight(pattern1)
	if !containsString(insight1, "5 times") {
		t.Errorf("Expected insight to contain frequency")
	}
	if !containsString(insight1, "recurring bug pattern") {
		t.Errorf("Expected bug pattern description")
	}

	// Pattern with recommendation
	pattern2 := &models.Pattern{
		Type:           models.PatternTypeBestPractice,
		Frequency:      10,
		Projects:       []string{"proj1", "proj2"},
		Recommendation: sql.NullString{String: "Do this", Valid: true},
	}
	insight2 := formatPatternInsight(pattern2)
	if !containsString(insight2, "10 times") {
		t.Errorf("Expected insight to contain frequency")
	}
	if !containsString(insight2, "2 projects") {
		t.Errorf("Expected insight to contain project count")
	}
	if !containsString(insight2, "Do this") {
		t.Errorf("Expected insight to contain recommendation")
	}
}

// Helper functions

func setupTestStore(t *testing.T) *sqlite.Store {
	t.Helper()

	// Create temp database file
	tmpFile, err := os.CreateTemp("", "pattern_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	tmpFile.Close()

	t.Cleanup(func() {
		os.Remove(tmpFile.Name())
	})

	store, err := sqlite.NewStore(sqlite.StoreConfig{
		Path:     tmpFile.Name(),
		MaxConns: 1,
		WALMode:  true,
	})
	if err != nil {
		// Check if this is an FTS5 related error
		if containsString(err.Error(), "fts5") || containsString(err.Error(), "no such module") {
			t.Skip("FTS5 not available in this SQLite build")
		}
		t.Fatalf("Failed to create store: %v", err)
	}

	return store
}

func createTestObservation(id int64, title string, concepts []string) *models.Observation {
	return &models.Observation{
		ID:             id,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Scope:          models.ScopeProject,
		Type:           models.ObsTypeBugfix,
		Title:          sql.NullString{String: title, Valid: true},
		Narrative:      sql.NullString{String: "Test narrative", Valid: true},
		Concepts:       concepts,
		CreatedAt:      time.Now().Format(time.RFC3339),
		CreatedAtEpoch: time.Now().UnixMilli(),
	}
}

func containsString(s, substr string) bool {
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

func hasPrefix(s, prefix string) bool {
	if len(s) < len(prefix) {
		return false
	}
	return s[:len(prefix)] == prefix
}
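Review note: the `containsString` and `hasPrefix` test helpers above appear to behave the same as `strings.Contains` and `strings.HasPrefix` from the standard library. The following is a minimal sketch of that substitution, not part of the commit; the sample strings in `main` are made up for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// containsString reports whether substr occurs in s.
// Sketch only: delegates to the standard library instead of manual slicing.
func containsString(s, substr string) bool {
	return strings.Contains(s, substr)
}

// hasPrefix reports whether s begins with prefix.
func hasPrefix(s, prefix string) bool {
	return strings.HasPrefix(s, prefix)
}

func main() {
	// Hypothetical inputs, chosen to resemble the assertions in the tests.
	fmt.Println(containsString("seen 12 times across 3 projects", "12 times"))
	fmt.Println(hasPrefix("Bug Pattern: nil error", "Bug Pattern:"))
}
```

Keeping the wrapper names would leave every call site in the test file unchanged; the only addition would be the `strings` import.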
@@ -0,0 +1,14 @@
// Package reranking provides cross-encoder reranking for search results.
package reranking

import (
	_ "embed"
)

// Cross-encoder model and tokenizer files - embedded for all platforms
//
//go:embed assets/model.onnx
var crossEncoderModelData []byte

//go:embed assets/tokenizer.json
var crossEncoderTokenizerData []byte
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,57 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
@@ -0,0 +1,380 @@
// Package reranking provides cross-encoder reranking for search results.
// Uses MS-MARCO MiniLM L6 v2 cross-encoder model for relevance scoring.
package reranking

import (
	"bytes"
	"fmt"
	"math"
	"sort"
	"sync"

	"github.com/rs/zerolog/log"
	"github.com/sugarme/tokenizer"
	"github.com/sugarme/tokenizer/pretrained"
	ort "github.com/yalue/onnxruntime_go"
)

const (
	// ModelName is the human-readable name for the cross-encoder model
	ModelName = "ms-marco-MiniLM-L6-v2"
	// ModelVersion is the short version string for identification
	ModelVersion = "msmarco-v2"
	// MaxSequenceLength is the maximum combined query+document token length
	MaxSequenceLength = 512
	// DefaultCandidateLimit is the default number of candidates to rerank
	DefaultCandidateLimit = 100
	// DefaultResultLimit is the default number of results to return after reranking
	DefaultResultLimit = 10
)

// Candidate represents a search result candidate for reranking.
type Candidate struct {
	ID         string             // Document ID
	Content    string             // Document text content for scoring
	Score      float64            // Original bi-encoder similarity score
	Metadata   map[string]any     // Preserved metadata
	RerankInfo map[string]float64 // Reranking debug info (optional)
}

// RerankResult represents a reranked search result.
type RerankResult struct {
	ID              string         // Document ID
	Content         string         // Document text content
	OriginalScore   float64        // Original bi-encoder score
	RerankScore     float64        // Cross-encoder relevance score
	CombinedScore   float64        // Weighted combination of scores
	Metadata        map[string]any // Preserved metadata
	OriginalRank    int            // Position before reranking (1-indexed)
	RerankRank      int            // Position after reranking (1-indexed)
	RankImprovement int            // How much the rank improved (positive = moved up)
}

// Service provides cross-encoder reranking functionality.
type Service struct {
	tk      *tokenizer.Tokenizer
	session *ort.DynamicAdvancedSession
	mu      sync.Mutex

	// Weight for combining scores: combined = alpha*rerank + (1-alpha)*original
	// Default 0.7 favors cross-encoder score
	Alpha float64
}

// Config holds configuration for the reranking service.
type Config struct {
	// Alpha is the weight for combining scores (0.0-1.0)
	// Higher values favor cross-encoder scores, lower values favor bi-encoder scores
	Alpha float64
}

// DefaultConfig returns sensible defaults for reranking.
func DefaultConfig() Config {
	return Config{
		Alpha: 0.7, // Favor cross-encoder by default
	}
}

// NewService creates a new cross-encoder reranking service.
// Note: ONNX runtime must be initialized before calling this (via embedding.NewService).
func NewService(cfg Config) (*Service, error) {
	// Load tokenizer from embedded data
	tk, err := pretrained.FromReader(bytes.NewReader(crossEncoderTokenizerData))
	if err != nil {
		return nil, fmt.Errorf("load cross-encoder tokenizer: %w", err)
	}

	// Configure tokenizer for sequence classification (pairs)
	tk.WithTruncation(&tokenizer.TruncationParams{
		MaxLength: MaxSequenceLength,
		Strategy:  tokenizer.LongestFirst,
		Stride:    0,
	})

	// Cross-encoder outputs a single logit for relevance scoring
	inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
	outputNames := []string{"logits"}

	session, err := ort.NewDynamicAdvancedSessionWithONNXData(
		crossEncoderModelData,
		inputNames,
		outputNames,
		nil,
	)
	if err != nil {
		return nil, fmt.Errorf("create cross-encoder ONNX session: %w", err)
	}

	alpha := cfg.Alpha
	if alpha <= 0 || alpha > 1 {
		alpha = 0.7
	}

	return &Service{
		tk:      tk,
		session: session,
		Alpha:   alpha,
	}, nil
}

// Rerank reranks candidates using the cross-encoder model.
// Takes a query and list of candidates, returns reranked results.
func (s *Service) Rerank(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}

	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Score all query-document pairs
	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	// Build results with combined scores
	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		// Normalize cross-encoder score to 0-1 range using sigmoid
		normalizedRerank := sigmoid(scores[i])

		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   normalizedRerank,
			CombinedScore: s.Alpha*normalizedRerank + (1-s.Alpha)*c.Score,
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Sort by combined score (descending)
	sort.Slice(results, func(i, j int) bool {
		return results[i].CombinedScore > results[j].CombinedScore
	})

	// Assign rerank positions and calculate improvement
	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	// Apply limit
	if len(results) > limit {
		results = results[:limit]
	}

	log.Debug().
		Int("candidates", len(candidates)).
		Int("returned", len(results)).
		Float64("alpha", s.Alpha).
		Msg("Cross-encoder reranking completed")

	return results, nil
}

// RerankByScore reranks candidates and returns sorted by pure cross-encoder score.
// Useful when you want to completely replace bi-encoder ranking.
func (s *Service) RerankByScore(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}

	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		normalizedRerank := sigmoid(scores[i])
		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   normalizedRerank,
			CombinedScore: normalizedRerank, // Use pure rerank score
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Sort by rerank score only
	sort.Slice(results, func(i, j int) bool {
		return results[i].RerankScore > results[j].RerankScore
	})

	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	if len(results) > limit {
		results = results[:limit]
	}

	return results, nil
}

// scoreAll scores all query-document pairs using the cross-encoder.
// Returns raw logits (before sigmoid normalization).
func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, error) {
	batchSize := len(candidates)

	// Tokenize all query-document pairs
	pairs := make([]tokenizer.EncodeInput, batchSize)
	for i, c := range candidates {
		// Cross-encoder takes query and document as a pair
		pairs[i] = tokenizer.NewDualEncodeInput(
			tokenizer.NewRawInputSequence(query),
			tokenizer.NewRawInputSequence(c.Content),
		)
	}

	encodings, err := s.tk.EncodeBatch(pairs, true)
	if err != nil {
		return nil, fmt.Errorf("tokenize pairs: %w", err)
	}

	// Find max sequence length
	seqLength := 0
	for _, enc := range encodings {
		if len(enc.Ids) > seqLength {
			seqLength = len(enc.Ids)
		}
	}
	if seqLength > MaxSequenceLength {
		seqLength = MaxSequenceLength
	}

	inputShape := ort.NewShape(int64(batchSize), int64(seqLength))

	// Create input tensors
	inputIdsData := make([]int64, batchSize*seqLength)
	attentionMaskData := make([]int64, batchSize*seqLength)
	tokenTypeIdsData := make([]int64, batchSize*seqLength)

	for b := 0; b < batchSize; b++ {
		copyLen := len(encodings[b].Ids)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
		}

		copyLen = len(encodings[b].AttentionMask)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
		}

		copyLen = len(encodings[b].TypeIds)
		if copyLen > seqLength {
			copyLen = seqLength
		}
		for i := 0; i < copyLen; i++ {
			tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
		}
	}

	inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
	if err != nil {
		return nil, fmt.Errorf("create input_ids tensor: %w", err)
	}
	defer inputIdsTensor.Destroy()

	attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
	if err != nil {
		return nil, fmt.Errorf("create attention_mask tensor: %w", err)
	}
	defer attentionMaskTensor.Destroy()

	tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
	if err != nil {
		return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
	}
	defer tokenTypeIdsTensor.Destroy()

	// Cross-encoder outputs [batch, 1] logits
	outputShape := ort.NewShape(int64(batchSize), 1)
	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
	if err != nil {
		return nil, fmt.Errorf("create output tensor: %w", err)
	}
	defer outputTensor.Destroy()

	// Run inference
	inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
	outputTensors := []ort.Value{outputTensor}

	if err := s.session.Run(inputTensors, outputTensors); err != nil {
		return nil, fmt.Errorf("run cross-encoder inference: %w", err)
	}

	// Extract scores
	flatOutput := outputTensor.GetData()
	scores := make([]float64, batchSize)
	for i := 0; i < batchSize; i++ {
		scores[i] = float64(flatOutput[i])
	}

	return scores, nil
}

// Score scores a single query-document pair.
// Returns the raw cross-encoder logit and normalized score.
func (s *Service) Score(query, document string) (rawScore, normalizedScore float64, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	scores, err := s.scoreAll(query, []Candidate{{Content: document}})
	if err != nil {
		return 0, 0, err
	}

	rawScore = scores[0]
	normalizedScore = sigmoid(rawScore)
	return rawScore, normalizedScore, nil
}

// Close releases model resources.
func (s *Service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.session != nil {
		if err := s.session.Destroy(); err != nil {
			return fmt.Errorf("destroy cross-encoder session: %w", err)
		}
		s.session = nil
	}

	return nil
}

// sigmoid applies the sigmoid function to normalize scores to 0-1 range.
func sigmoid(x float64) float64 {
	if x > 20 {
		return 1.0
	}
	if x < -20 {
		return 0.0
	}
	return 1.0 / (1.0 + math.Exp(-x))
}
@@ -0,0 +1,448 @@
package reranking

import (
	"sync"
	"testing"

	"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
)

// initONNX initializes the ONNX runtime via the embedding service.
// Must be called before creating reranking service.
func initONNX(t *testing.T) func() {
	t.Helper()
	embSvc, err := embedding.NewService()
	if err != nil {
		t.Fatalf("Failed to initialize ONNX via embedding service: %v", err)
	}
	return func() {
		embSvc.Close()
	}
}

// TestSigmoid tests the sigmoid normalization function.
func TestSigmoid(t *testing.T) {
	tests := []struct {
		name    string
		input   float64
		wantMin float64
		wantMax float64
	}{
		{"positive large", 10, 0.9999, 1.0},
		{"positive small", 1, 0.7, 0.8},
		{"zero", 0, 0.4999, 0.5001},
		{"negative small", -1, 0.2, 0.3},
		{"negative large", -10, 0, 0.0001},
		{"very positive", 25, 0.999999, 1.0},
		{"very negative", -25, 0, 0.000001},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := sigmoid(tt.input)
			if got < tt.wantMin || got > tt.wantMax {
				t.Errorf("sigmoid(%v) = %v, want in range [%v, %v]",
					tt.input, got, tt.wantMin, tt.wantMax)
			}
		})
	}
}

// TestDefaultConfig tests the default configuration values.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.Alpha < 0 || cfg.Alpha > 1 {
		t.Errorf("DefaultConfig().Alpha = %v, want in range [0, 1]", cfg.Alpha)
	}
	if cfg.Alpha != 0.7 {
		t.Errorf("DefaultConfig().Alpha = %v, want 0.7", cfg.Alpha)
	}
}

// TestNewService tests service creation.
func TestNewService(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	if svc == nil {
		t.Fatal("NewService() returned nil")
	}

	if svc.Alpha != cfg.Alpha {
		t.Errorf("Service.Alpha = %v, want %v", svc.Alpha, cfg.Alpha)
	}
}

// TestScore tests single pair scoring.
func TestScore(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "What is the capital of France?"
	relevant := "Paris is the capital and largest city of France."
	irrelevant := "Dogs are popular pets known for their loyalty."

	// Score relevant document
	_, relevantNorm, err := svc.Score(query, relevant)
	if err != nil {
		t.Fatalf("Score(relevant) error = %v", err)
	}

	// Score irrelevant document
	_, irrelevantNorm, err := svc.Score(query, irrelevant)
	if err != nil {
		t.Fatalf("Score(irrelevant) error = %v", err)
	}

	// Relevant document should score higher
	if relevantNorm <= irrelevantNorm {
		t.Errorf("Expected relevant (%v) > irrelevant (%v)",
			relevantNorm, irrelevantNorm)
	}
}

// TestRerank tests the reranking functionality.
func TestRerank(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "How to handle errors in Go?"
	candidates := []Candidate{
		{
			ID:      "1",
			Content: "Python exception handling with try/except blocks.",
			Score:   0.8, // High bi-encoder score but irrelevant
		},
		{
			ID:      "2",
			Content: "Go error handling uses explicit return values. Functions return error as the last value.",
			Score:   0.6, // Lower bi-encoder score but relevant
		},
		{
			ID:      "3",
			Content: "JavaScript uses Promise.catch for async error handling.",
			Score:   0.7,
		},
	}

	results, err := svc.Rerank(query, candidates, 3)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	if len(results) != 3 {
		t.Fatalf("Rerank() returned %d results, want 3", len(results))
	}

	// The Go error handling document should rank higher after reranking
	var goRank int
	for i, r := range results {
		if r.ID == "2" {
			goRank = i + 1
			break
		}
	}

	if goRank == 0 {
		t.Error("Go document not found in results")
	}

	// Verify all results have required fields populated
	for i, r := range results {
		if r.ID == "" {
			t.Errorf("Result %d has empty ID", i)
		}
		if r.Content == "" {
			t.Errorf("Result %d has empty Content", i)
		}
		if r.RerankRank != i+1 {
			t.Errorf("Result %d has RerankRank %d, want %d", i, r.RerankRank, i+1)
		}
	}
}

// TestRerankEmpty tests reranking with empty input.
func TestRerankEmpty(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	results, err := svc.Rerank("test query", nil, 10)
	if err != nil {
		t.Fatalf("Rerank(nil) error = %v", err)
	}

	if results != nil {
		t.Errorf("Rerank(nil) = %v, want nil", results)
	}

	results, err = svc.Rerank("test query", []Candidate{}, 10)
	if err != nil {
		t.Fatalf("Rerank([]) error = %v", err)
	}

	if results != nil {
		t.Errorf("Rerank([]) = %v, want nil", results)
	}
}

// TestRerankLimit tests that limit is respected.
func TestRerankLimit(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	candidates := make([]Candidate, 20)
	for i := range candidates {
		candidates[i] = Candidate{
			ID:      string(rune('A' + i)),
			Content: "Test document content for ranking.",
			Score:   0.5,
		}
	}

	results, err := svc.Rerank("test query", candidates, 5)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	if len(results) != 5 {
		t.Errorf("Rerank() returned %d results, want 5", len(results))
	}
}

// TestRerankByScore tests pure cross-encoder ranking.
func TestRerankByScore(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	query := "machine learning algorithms"
	candidates := []Candidate{
		{
			ID:      "1",
			Content: "Cooking recipes for Italian pasta dishes.",
			Score:   0.9, // High original score
		},
		{
			ID:      "2",
			Content: "Neural networks are a type of machine learning algorithm.",
			Score:   0.3, // Low original score
		},
	}

	results, err := svc.RerankByScore(query, candidates, 2)
	if err != nil {
		t.Fatalf("RerankByScore() error = %v", err)
	}

	// Document 2 should rank first since it's about ML
	if results[0].ID != "2" {
		t.Errorf("Expected ML document to rank first, got %v", results[0].ID)
	}

	// CombinedScore should equal RerankScore when using RerankByScore
	for _, r := range results {
		if r.CombinedScore != r.RerankScore {
			t.Errorf("RerankByScore: CombinedScore (%v) != RerankScore (%v)",
				r.CombinedScore, r.RerankScore)
		}
	}
}

// TestRankImprovement tests that rank improvement is calculated correctly.
func TestRankImprovement(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	// Create candidates where we know the expected reranking
	candidates := []Candidate{
		{ID: "A", Content: "Unrelated content about weather forecasting.", Score: 0.9},
		{ID: "B", Content: "How to fix memory leaks in Go programs.", Score: 0.8},
		{ID: "C", Content: "More unrelated content about gardening tips.", Score: 0.7},
	}

	results, err := svc.Rerank("debugging memory issues in Go", candidates, 3)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	for _, r := range results {
		// RankImprovement = OriginalRank - RerankRank
		// Positive means moved up, negative means moved down
		expectedImprovement := r.OriginalRank - r.RerankRank
		if r.RankImprovement != expectedImprovement {
			t.Errorf("ID %s: RankImprovement = %d, want %d (orig=%d, new=%d)",
				r.ID, r.RankImprovement, expectedImprovement,
				r.OriginalRank, r.RerankRank)
		}
	}
}

||||||
|
// TestConcurrentRerank tests concurrent reranking calls.
func TestConcurrentRerank(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	candidates := []Candidate{
		{ID: "1", Content: "Test document one.", Score: 0.5},
		{ID: "2", Content: "Test document two.", Score: 0.5},
	}

	var wg sync.WaitGroup
	// Buffered to the number of goroutines so that no sender can block.
	// Named errCh to avoid shadowing the standard errors package.
	errCh := make(chan error, 10)

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, err := svc.Rerank("concurrent test query", candidates, 2)
			if err != nil {
				errCh <- err
			}
		}()
	}

	wg.Wait()
	close(errCh)

	for err := range errCh {
		t.Errorf("Concurrent Rerank error: %v", err)
	}
}

// TestClose tests service cleanup.
func TestClose(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}

	err = svc.Close()
	if err != nil {
		t.Errorf("Close() error = %v", err)
	}

	// Double close should not panic
	err = svc.Close()
	if err != nil {
		t.Errorf("Close() on closed service error = %v", err)
	}
}

// TestMetadataPreserved tests that metadata is preserved through reranking.
func TestMetadataPreserved(t *testing.T) {
	cleanup := initONNX(t)
	defer cleanup()

	cfg := DefaultConfig()
	svc, err := NewService(cfg)
	if err != nil {
		t.Fatalf("NewService() error = %v", err)
	}
	defer svc.Close()

	candidates := []Candidate{
		{
			ID:       "1",
			Content:  "Test content.",
			Score:    0.5,
			Metadata: map[string]any{"custom": "value1", "num": 42},
		},
		{
			ID:       "2",
			Content:  "Another test.",
			Score:    0.5,
			Metadata: map[string]any{"custom": "value2"},
		},
	}

	results, err := svc.Rerank("query", candidates, 2)
	if err != nil {
		t.Fatalf("Rerank() error = %v", err)
	}

	for _, r := range results {
		if r.Metadata == nil {
			t.Errorf("Result %s has nil metadata", r.ID)
			continue
		}

		// Find original candidate
		var original *Candidate
		for i := range candidates {
			if candidates[i].ID == r.ID {
				original = &candidates[i]
				break
			}
		}

		if original == nil {
			t.Errorf("Could not find original for result %s", r.ID)
			continue
		}

		// Check metadata preserved
		if original.Metadata["custom"] != r.Metadata["custom"] {
			t.Errorf("Metadata not preserved for %s", r.ID)
		}
	}
}
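The rank-improvement rule checked in TestRankImprovement (RankImprovement = OriginalRank - RerankRank) can be sketched in isolation. This is only an illustration under stated assumptions: the `ranked` type and `assignRanks` helper are hypothetical and not part of the package, ranks are assumed to be 1-based, and the rerank scores are made-up values rather than real cross-encoder output.

```go
package main

import (
	"fmt"
	"sort"
)

// ranked is a hypothetical stand-in for the service's result type.
type ranked struct {
	ID              string
	Score           float64 // first-stage score (assumed: higher is better)
	RerankScore     float64 // second-stage score (illustrative values only)
	OriginalRank    int
	RerankRank      int
	RankImprovement int
}

// assignRanks fills in 1-based ranks for both orderings and returns a copy
// of the items sorted by RerankScore, best first.
func assignRanks(items []ranked) []ranked {
	out := append([]ranked(nil), items...)

	// Rank by the original score first.
	sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score })
	for i := range out {
		out[i].OriginalRank = i + 1
	}

	// Then rank by the rerank score and record how far each item moved.
	sort.SliceStable(out, func(i, j int) bool { return out[i].RerankScore > out[j].RerankScore })
	for i := range out {
		out[i].RerankRank = i + 1
		// Positive means the item moved up, negative means it moved down.
		out[i].RankImprovement = out[i].OriginalRank - out[i].RerankRank
	}
	return out
}

func main() {
	results := assignRanks([]ranked{
		{ID: "A", Score: 0.9, RerankScore: 0.10},
		{ID: "B", Score: 0.8, RerankScore: 0.95},
		{ID: "C", Score: 0.7, RerankScore: 0.05},
	})
	for _, r := range results {
		fmt.Printf("%s: %d -> %d (%+d)\n", r.ID, r.OriginalRank, r.RerankRank, r.RankImprovement)
	}
}
```

With these illustrative scores, B moves from second to first place (+1), A drops from first to second (-1), and C stays third (+0).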
@@ -0,0 +1,168 @@
// Package scoring provides importance score calculation for observations.
package scoring

import (
	"math"
	"time"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)

// Calculator computes importance scores for observations.
type Calculator struct {
	config *models.ScoringConfig
}

// NewCalculator creates a new scoring calculator.
// If config is nil, uses the default configuration.
func NewCalculator(config *models.ScoringConfig) *Calculator {
	if config == nil {
		config = models.DefaultScoringConfig()
	}
	return &Calculator{config: config}
}

// Calculate computes the importance score for an observation at the given time.
//
// The scoring formula:
//
//	FinalScore = (BaseScore × TypeWeight × RecencyDecay) + FeedbackContrib + ConceptContrib + RetrievalContrib
//
// Where:
//   - BaseScore = 1.0
//   - TypeWeight = observation type multiplier (e.g., bugfix=1.3, change=0.9)
//   - RecencyDecay = 0.5^(age_days / half_life_days) - halves every 7 days by default
//   - FeedbackContrib = user_feedback × feedback_weight
//   - ConceptContrib = sum(concept_weights) × concept_weight_factor
//   - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight
//
// The result is never lower than the configured MinScore.
func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
	// 1. Get base type weight
	typeWeight := models.TypeBaseScore(obs.Type)

	// 2. Calculate recency decay: 0.5^(age_days / half_life_days)
	ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
	if ageDays < 0 {
		ageDays = 0 // Handle future timestamps gracefully
	}
	recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)

	// Core score = 1.0 × type_weight × recency_decay
	coreScore := 1.0 * typeWeight * recencyDecay

	// 3. User feedback contribution: feedback × weight
	feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight

	// 4. Concept boost contribution: sum of matching concept weights × factor
	conceptBoost := 0.0
	for _, concept := range obs.Concepts {
		if weight, ok := c.config.ConceptWeights[concept]; ok {
			conceptBoost += weight
		}
	}
	conceptContrib := conceptBoost * c.config.ConceptWeight

	// 5. Retrieval boost: log2(count + 1) × 0.1 × weight (diminishing returns)
	retrievalContrib := 0.0
	if obs.RetrievalCount > 0 {
		// log2(count + 1) gives diminishing returns: 1→1, 3→2, 7→3, 15→4, etc.
		retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
		retrievalContrib = retrievalBoost * c.config.RetrievalWeight
	}

	// Final score with minimum threshold
	finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
	if finalScore < c.config.MinScore {
		finalScore = c.config.MinScore
	}

	return finalScore
}

// CalculateComponents returns the individual components of the importance score.
// Useful for debugging and explaining scores to users.
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
	typeWeight := models.TypeBaseScore(obs.Type)

	ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
	if ageDays < 0 {
		ageDays = 0
	}
	recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)

	coreScore := 1.0 * typeWeight * recencyDecay
	feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight

	conceptBoost := 0.0
	for _, concept := range obs.Concepts {
		if weight, ok := c.config.ConceptWeights[concept]; ok {
			conceptBoost += weight
		}
	}
	conceptContrib := conceptBoost * c.config.ConceptWeight

	retrievalContrib := 0.0
	if obs.RetrievalCount > 0 {
		retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
		retrievalContrib = retrievalBoost * c.config.RetrievalWeight
	}

	finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
	if finalScore < c.config.MinScore {
		finalScore = c.config.MinScore
	}

	return ScoreComponents{
		TypeWeight:       typeWeight,
		RecencyDecay:     recencyDecay,
		CoreScore:        coreScore,
		FeedbackContrib:  feedbackContrib,
		ConceptContrib:   conceptContrib,
		RetrievalContrib: retrievalContrib,
		FinalScore:       finalScore,
		AgeDays:          ageDays,
	}
}

// ScoreComponents contains the breakdown of an importance score calculation.
type ScoreComponents struct {
	TypeWeight       float64 `json:"type_weight"`
	RecencyDecay     float64 `json:"recency_decay"`
	CoreScore        float64 `json:"core_score"`
	FeedbackContrib  float64 `json:"feedback_contrib"`
	ConceptContrib   float64 `json:"concept_contrib"`
	RetrievalContrib float64 `json:"retrieval_contrib"`
	FinalScore       float64 `json:"final_score"`
	AgeDays          float64 `json:"age_days"`
}

// BatchCalculate computes scores for multiple observations.
// Returns a map of observation ID to calculated score.
func (c *Calculator) BatchCalculate(observations []*models.Observation, now time.Time) map[int64]float64 {
	scores := make(map[int64]float64, len(observations))
	for _, obs := range observations {
		scores[obs.ID] = c.Calculate(obs, now)
	}
	return scores
}

// RecalculateThreshold returns the minimum duration before an observation
// should have its score recalculated. This prevents excessive recalculation
// while ensuring scores stay reasonably fresh.
func (c *Calculator) RecalculateThreshold() time.Duration {
	// Recalculate at most every 6 hours
	// This balances freshness with performance
	return 6 * time.Hour
}

// UpdateConfig updates the calculator's scoring configuration.
// This allows runtime tuning of scoring parameters.
func (c *Calculator) UpdateConfig(config *models.ScoringConfig) {
	if config != nil {
		c.config = config
	}
}

// GetConfig returns the current scoring configuration.
func (c *Calculator) GetConfig() *models.ScoringConfig {
	return c.config
}
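As a worked example of the formula documented on Calculate, the standalone sketch below recomputes the score for a 7-day-old bugfix with one thumbs-up, the "security" concept, and 3 retrievals. It is a minimal sketch, not the package's code: the `params` type and `score` function are hypothetical stand-ins for models.ScoringConfig and Calculator.Calculate, and the numeric defaults (half-life 7, feedback weight 0.30, concept weight 0.20, retrieval weight 0.15, minimum 0.01, security = 0.30, bugfix = 1.3) are taken from the expectations in the accompanying test comments rather than from the models package itself.

```go
package main

import (
	"fmt"
	"math"
)

// params is a hypothetical stand-in for the fields of models.ScoringConfig.
type params struct {
	halfLifeDays    float64
	feedbackWeight  float64
	conceptWeight   float64
	retrievalWeight float64
	minScore        float64
}

// score mirrors the documented formula:
// (1.0 × typeWeight × decay) + feedback + concepts + retrievals, floored at minScore.
func score(p params, typeWeight, ageDays float64, feedback int, conceptSum float64, retrievals int) float64 {
	if ageDays < 0 {
		ageDays = 0 // future timestamps are treated as "now"
	}
	decay := math.Pow(0.5, ageDays/p.halfLifeDays)

	total := 1.0 * typeWeight * decay
	total += float64(feedback) * p.feedbackWeight
	total += conceptSum * p.conceptWeight
	if retrievals > 0 {
		// log2(count + 1) grows slowly, so popular items get diminishing returns.
		total += math.Log2(float64(retrievals)+1) * 0.1 * p.retrievalWeight
	}
	return math.Max(total, p.minScore)
}

// approxEqual compares two floats within an absolute tolerance.
func approxEqual(a, b, tol float64) bool { return math.Abs(a-b) <= tol }

func main() {
	p := params{
		halfLifeDays:    7,
		feedbackWeight:  0.30,
		conceptWeight:   0.20,
		retrievalWeight: 0.15,
		minScore:        0.01,
	}
	// Core 1.3 × 0.5 = 0.65, feedback 0.30, concept 0.30 × 0.20 = 0.06,
	// retrieval log2(4) × 0.1 × 0.15 = 0.03.
	fmt.Printf("%.2f\n", score(p, 1.3, 7, 1, 0.30, 3)) // prints 1.04
}
```

The four terms sum to 0.65 + 0.30 + 0.06 + 0.03 = 1.04, which matches the value expected by TestCalculate_GoodScenarios_CombinedFactors.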
@@ -0,0 +1,638 @@
// Package scoring provides importance score calculation for observations.
package scoring

import (
	"testing"
	"time"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

// CalculatorSuite is a test suite for the Calculator.
type CalculatorSuite struct {
	suite.Suite
	calc   *Calculator
	config *models.ScoringConfig
	now    time.Time
}

func (s *CalculatorSuite) SetupTest() {
	s.config = models.DefaultScoringConfig()
	s.calc = NewCalculator(s.config)
	s.now = time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
}

func TestCalculatorSuite(t *testing.T) {
	suite.Run(t, new(CalculatorSuite))
}

// =============================================================================
// GOOD SCENARIOS - Expected normal operations
// =============================================================================

func (s *CalculatorSuite) TestCalculate_GoodScenarios_NewObservation() {
	// A brand new observation should have score close to type weight
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)

	// Expected: 1.0 × 1.3 (bugfix weight) × 1.0 (no decay) = 1.3
	s.InDelta(1.3, score, 0.01, "new bugfix should score ~1.3")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_OneWeekOld() {
	// One week old observation should have half the recency score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)

	// Expected: 1.0 × 1.1 (discovery) × 0.5 (7 days half-life) = 0.55
	s.InDelta(0.55, score, 0.05, "7-day old discovery should score ~0.55")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_TwoWeeksOld() {
	// Two weeks old should have 1/4 recency score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeFeature,
		CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)

	// Expected: 1.0 × 1.2 (feature) × 0.25 (14 days = 2 half-lives) = 0.30
	s.InDelta(0.30, score, 0.05, "14-day old feature should score ~0.30")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_PositiveFeedback() {
	// Positive feedback should boost score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.UnixMilli(),
		UserFeedback:   1, // thumbs up
	}

	score := s.calc.Calculate(obs, s.now)

	// Expected: (1.0 × 0.9) + 0.30 (feedback) = 1.20
	s.InDelta(1.20, score, 0.01, "thumbs up should boost score by 0.30")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_NegativeFeedback() {
	// Negative feedback should reduce score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.UnixMilli(),
		UserFeedback:   -1, // thumbs down
	}

	score := s.calc.Calculate(obs, s.now)

	// Expected: (1.0 × 0.9) - 0.30 (feedback) = 0.60
	s.InDelta(0.60, score, 0.01, "thumbs down should reduce score by 0.30")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithConcepts() {
	// Observation with valuable concepts should get boost
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.UnixMilli(),
		Concepts:       []string{"security", "gotcha"},
	}

	score := s.calc.Calculate(obs, s.now)

	// Concept boost: (0.30 + 0.25) × 0.20 = 0.11
	// Expected: 1.3 + 0.11 = 1.41
	s.InDelta(1.41, score, 0.05, "security+gotcha concepts should boost score")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_WithRetrievals() {
	// Popular observations should get retrieval boost
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.UnixMilli(),
		RetrievalCount: 7, // log2(8) = 3
	}

	score := s.calc.Calculate(obs, s.now)

	// Retrieval boost: log2(7+1) × 0.1 × 0.15 = 3 × 0.1 × 0.15 = 0.045
	// Expected: 1.1 + 0.045 ≈ 1.145
	s.InDelta(1.145, score, 0.05, "7 retrievals should add small boost")
}

func (s *CalculatorSuite) TestCalculate_GoodScenarios_CombinedFactors() {
	// Test with all factors combined
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(), // 7 days old
		UserFeedback:   1,
		Concepts:       []string{"security"},
		RetrievalCount: 3,
	}

	score := s.calc.Calculate(obs, s.now)

	// Core: 1.0 × 1.3 × 0.5 = 0.65
	// Feedback: 0.30
	// Concept: 0.30 × 0.20 = 0.06
	// Retrieval: log2(4) × 0.1 × 0.15 = 2 × 0.1 × 0.15 = 0.03
	// Total ≈ 1.04
	s.InDelta(1.04, score, 0.1, "combined factors should result in ~1.04")
}

// =============================================================================
// WORSE SCENARIOS - Degraded but acceptable operations
// =============================================================================

func (s *CalculatorSuite) TestCalculate_WorseScenarios_VeryOldObservation() {
	// Very old observation should have low but non-zero score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.Add(-90 * 24 * time.Hour).UnixMilli(), // 90 days old
	}

	score := s.calc.Calculate(obs, s.now)

	// 90 days = ~12.86 half-lives → decay ≈ 0.00014
	// Core: 1.0 × 0.9 × 0.00014 = 0.000126
	// But minimum score is 0.01
	s.GreaterOrEqual(score, 0.01, "very old observation should still meet minimum")
	s.Less(score, 0.1, "very old observation should be low scoring")
}

func (s *CalculatorSuite) TestCalculate_WorseScenarios_NegativeFeedbackOld() {
	// Old observation with negative feedback should still have minimum score
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.Add(-60 * 24 * time.Hour).UnixMilli(),
		UserFeedback:   -1,
	}

	score := s.calc.Calculate(obs, s.now)

	s.GreaterOrEqual(score, s.config.MinScore, "should never go below minimum score")
}

func (s *CalculatorSuite) TestCalculate_WorseScenarios_UnknownConcepts() {
	// Unknown concepts should not affect score negatively
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.UnixMilli(),
		Concepts:       []string{"unknown-concept", "another-unknown"},
	}

	score := s.calc.Calculate(obs, s.now)

	// Should just be the base score without concept boost
	s.InDelta(1.1, score, 0.01, "unknown concepts should not affect score")
}

func (s *CalculatorSuite) TestCalculate_WorseScenarios_MixedConcepts() {
	// Mix of known and unknown concepts
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.UnixMilli(),
		Concepts:       []string{"security", "unknown-concept"},
	}

	score := s.calc.Calculate(obs, s.now)

	// Only security should contribute
	// Expected: 1.1 + (0.30 × 0.20) = 1.16
	s.InDelta(1.16, score, 0.05, "only known concepts should boost score")
}

// =============================================================================
// BAD SCENARIOS - Edge cases and error conditions
// =============================================================================

func (s *CalculatorSuite) TestCalculate_BadScenarios_FutureTimestamp() {
	// Observation created in the future (clock skew)
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.Add(24 * time.Hour).UnixMilli(), // 1 day in future
	}

	score := s.calc.Calculate(obs, s.now)

	// Should handle gracefully - age should be 0
	s.InDelta(1.3, score, 0.01, "future timestamp should be treated as now")
}

func (s *CalculatorSuite) TestCalculate_BadScenarios_ZeroEpoch() {
	// Missing creation timestamp
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: 0, // Missing timestamp
	}

	score := s.calc.Calculate(obs, s.now)

	// This will be treated as very old (1970)
	s.GreaterOrEqual(score, s.config.MinScore, "should still meet minimum")
}

func (s *CalculatorSuite) TestCalculate_BadScenarios_EmptyObservation() {
	// Minimal observation with defaults
	obs := &models.Observation{
		ID:             1,
		Type:           "", // Empty type
		CreatedAtEpoch: s.now.UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)

	// Unknown type should default to 1.0 weight
	s.InDelta(1.0, score, 0.01, "empty type should use default weight 1.0")
}

func (s *CalculatorSuite) TestCalculate_BadScenarios_ExtremeRetrievalCount() {
	// Very high retrieval count
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.UnixMilli(),
		RetrievalCount: 1000000, // Extreme value
	}

	score := s.calc.Calculate(obs, s.now)

	// log2(1000001) ≈ 19.93, so boost = 19.93 × 0.1 × 0.15 ≈ 0.30
	// Score should be reasonable, not exploding
	s.Less(score, 2.0, "extreme retrieval count should not explode score")
}

func (s *CalculatorSuite) TestCalculate_BadScenarios_NegativeRetrievalCount() {
	// Negative retrieval count (should not happen but test defensively)
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: s.now.UnixMilli(),
		RetrievalCount: -5,
	}

	score := s.calc.Calculate(obs, s.now)

	// Should not panic and should give base score
	s.InDelta(1.1, score, 0.01, "negative retrieval should be ignored")
}

// =============================================================================
// EDGE CASES - Boundary conditions
// =============================================================================

func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyOneHalfLife() {
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange, // 0.9 weight
		CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)
	s.InDelta(0.45, score, 0.01, "exactly 7 days should give 0.5 decay")
}

func (s *CalculatorSuite) TestCalculate_EdgeCases_ExactlyTwoHalfLives() {
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)
	s.InDelta(0.225, score, 0.01, "exactly 14 days should give 0.25 decay")
}

func (s *CalculatorSuite) TestCalculate_EdgeCases_AllTypeWeights() {
	types := []struct {
		t      models.ObservationType
		weight float64
	}{
		{models.ObsTypeBugfix, 1.3},
		{models.ObsTypeFeature, 1.2},
		{models.ObsTypeDiscovery, 1.1},
		{models.ObsTypeDecision, 1.1},
		{models.ObsTypeRefactor, 1.0},
		{models.ObsTypeChange, 0.9},
	}

	for _, tt := range types {
		s.Run(string(tt.t), func() {
			obs := &models.Observation{
				ID:             1,
				Type:           tt.t,
				CreatedAtEpoch: s.now.UnixMilli(),
			}
			score := s.calc.Calculate(obs, s.now)
			s.InDelta(tt.weight, score, 0.01)
		})
	}
}

func (s *CalculatorSuite) TestCalculate_EdgeCases_MinimumScoreEnforced() {
	// Create worst case scenario
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,                                     // Lowest weight 0.9
		CreatedAtEpoch: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli(), // Very old
		UserFeedback:   -1,                                                       // Negative feedback
	}

	score := s.calc.Calculate(obs, s.now)

	s.Equal(s.config.MinScore, score, "should be exactly minimum score")
}

func (s *CalculatorSuite) TestCalculate_EdgeCases_AllConceptsMaxWeight() {
	// Observation with all high-value concepts
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.UnixMilli(),
		Concepts:       []string{"security", "gotcha", "best-practice", "anti-pattern"},
	}

	score := s.calc.Calculate(obs, s.now)

	// security=0.30, gotcha=0.25, best-practice=0.20, anti-pattern=0.20 = 0.95
	// Concept contrib: 0.95 × 0.20 = 0.19
	// Total: 1.3 + 0.19 = 1.49
	s.InDelta(1.49, score, 0.05, "all high-value concepts should boost significantly")
}

// =============================================================================
// CALCULATE COMPONENTS TESTS
// =============================================================================

func (s *CalculatorSuite) TestCalculateComponents_ReturnsAllComponents() {
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeBugfix,
		CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli(),
		UserFeedback:   1,
		Concepts:       []string{"security"},
		RetrievalCount: 7,
	}

	components := s.calc.CalculateComponents(obs, s.now)

	s.InDelta(1.3, components.TypeWeight, 0.01)
	s.InDelta(0.5, components.RecencyDecay, 0.01)
	s.InDelta(0.65, components.CoreScore, 0.05)
	s.InDelta(0.30, components.FeedbackContrib, 0.01)
	s.InDelta(0.06, components.ConceptContrib, 0.02)
	s.Greater(components.RetrievalContrib, 0.0)
	s.InDelta(7.0, components.AgeDays, 0.1)
	s.Greater(components.FinalScore, 0.0)
}

func (s *CalculatorSuite) TestCalculateComponents_MatchesCalculate() {
	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeFeature,
		CreatedAtEpoch: s.now.Add(-3 * 24 * time.Hour).UnixMilli(),
		UserFeedback:   -1,
		Concepts:       []string{"performance", "architecture"},
		RetrievalCount: 15,
	}

	score := s.calc.Calculate(obs, s.now)
	components := s.calc.CalculateComponents(obs, s.now)

	s.InDelta(score, components.FinalScore, 0.001, "Calculate and CalculateComponents should match")
}

// =============================================================================
// BATCH CALCULATE TESTS
// =============================================================================

func (s *CalculatorSuite) TestBatchCalculate_Empty() {
	scores := s.calc.BatchCalculate(nil, s.now)
	s.Empty(scores)

	scores = s.calc.BatchCalculate([]*models.Observation{}, s.now)
	s.Empty(scores)
}

func (s *CalculatorSuite) TestBatchCalculate_Multiple() {
	obs := []*models.Observation{
		{ID: 1, Type: models.ObsTypeBugfix, CreatedAtEpoch: s.now.UnixMilli()},
		{ID: 2, Type: models.ObsTypeFeature, CreatedAtEpoch: s.now.Add(-7 * 24 * time.Hour).UnixMilli()},
		{ID: 3, Type: models.ObsTypeChange, CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli()},
	}

	scores := s.calc.BatchCalculate(obs, s.now)

	s.Len(scores, 3)
	s.Contains(scores, int64(1))
	s.Contains(scores, int64(2))
	s.Contains(scores, int64(3))

	s.InDelta(1.3, scores[1], 0.01)   // New bugfix
	s.InDelta(0.6, scores[2], 0.1)    // 7-day feature
	s.InDelta(0.225, scores[3], 0.05) // 14-day change
}

// =============================================================================
// CONFIGURATION TESTS
// =============================================================================

func (s *CalculatorSuite) TestNewCalculator_NilConfig() {
	calc := NewCalculator(nil)
	s.NotNil(calc)
	s.NotNil(calc.config)
	s.Equal(7.0, calc.config.RecencyHalfLifeDays)
}

func (s *CalculatorSuite) TestUpdateConfig() {
	newConfig := &models.ScoringConfig{
		RecencyHalfLifeDays: 14.0, // Changed from 7
		FeedbackWeight:      0.50,
		ConceptWeight:       0.10,
		RetrievalWeight:     0.05,
		MinScore:            0.001,
		ConceptWeights:      map[string]float64{"test": 0.5},
	}

	s.calc.UpdateConfig(newConfig)

	obs := &models.Observation{
		ID:             1,
		Type:           models.ObsTypeChange,
		CreatedAtEpoch: s.now.Add(-14 * 24 * time.Hour).UnixMilli(),
	}

	score := s.calc.Calculate(obs, s.now)

	// With 14-day half-life, 14 days = exactly one half-life
	// Expected: 1.0 × 0.9 × 0.5 = 0.45
	s.InDelta(0.45, score, 0.01)
}

func (s *CalculatorSuite) TestUpdateConfig_NilIgnored() {
	originalConfig := s.calc.GetConfig()
	s.calc.UpdateConfig(nil)
	s.Equal(originalConfig, s.calc.GetConfig())
}

func (s *CalculatorSuite) TestGetConfig() {
	config := s.calc.GetConfig()
	s.NotNil(config)
	s.Equal(7.0, config.RecencyHalfLifeDays)
}

func (s *CalculatorSuite) TestRecalculateThreshold() {
	threshold := s.calc.RecalculateThreshold()
	s.Equal(6*time.Hour, threshold)
}

// =============================================================================
|
||||||
|
// STANDALONE TESTS (non-suite)
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestNewCalculator_DefaultConfig(t *testing.T) {
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
require.NotNil(t, calc)
|
||||||
|
assert.Equal(t, 7.0, calc.config.RecencyHalfLifeDays)
|
||||||
|
assert.Equal(t, 0.30, calc.config.FeedbackWeight)
|
||||||
|
assert.Equal(t, 0.01, calc.config.MinScore)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculator_ConcurrentAccess(t *testing.T) {
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Test that calculator is safe for concurrent reads
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int64) {
|
||||||
|
obs := &models.Observation{
|
||||||
|
ID: id,
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
}
|
||||||
|
score := calc.Calculate(obs, now)
|
||||||
|
assert.Greater(t, score, 0.0)
|
||||||
|
done <- true
|
||||||
|
}(int64(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculator_DecayPrecision(t *testing.T) {
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Verify that RecencyDecay halves with each default 7-day half-life
|
||||||
|
testCases := []struct {
|
||||||
|
days int
|
||||||
|
expectedDecay float64
|
||||||
|
}{
|
||||||
|
{0, 1.0},
|
||||||
|
{7, 0.5},
|
||||||
|
{14, 0.25},
|
||||||
|
{21, 0.125},
|
||||||
|
{28, 0.0625},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(string(rune('0'+tc.days/7))+"_half_lives", func(t *testing.T) {
|
||||||
|
obs := &models.Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: models.ObsTypeRefactor, // 1.0 weight
|
||||||
|
CreatedAtEpoch: now.Add(-time.Duration(tc.days) * 24 * time.Hour).UnixMilli(),
|
||||||
|
}
|
||||||
|
components := calc.CalculateComponents(obs, now)
|
||||||
|
assert.InDelta(t, tc.expectedDecay, components.RecencyDecay, 0.001)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
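Note on the decay test above: the expected values (1.0, 0.5, 0.25, ...) are consistent with a plain exponential half-life curve. The Calculator implementation is not part of this hunk, so the following is only a minimal sketch of that arithmetic. `recencyDecay`, `baselineScore` and `approxEqual` are hypothetical helpers, and the sketch assumes the baseline score is simply type weight × decay, with no feedback, concept or retrieval contribution.

```go
package main

import (
	"fmt"
	"math"
)

// recencyDecay returns the half-life decay factor for an observation that is
// ageDays old. Hypothetical helper; not the project's actual implementation.
func recencyDecay(ageDays, halfLifeDays float64) float64 {
	if halfLifeDays <= 0 {
		return 1.0 // assumption: a non-positive half-life means "no decay"
	}
	return math.Pow(0.5, ageDays/halfLifeDays)
}

// baselineScore multiplies a type weight by the decay factor, mirroring the
// arithmetic in the test comments (for example 1.0 × 0.9 × 0.5 = 0.45).
func baselineScore(typeWeight, ageDays, halfLifeDays float64) float64 {
	return typeWeight * recencyDecay(ageDays, halfLifeDays)
}

// approxEqual reports whether a and b differ by at most tol.
func approxEqual(a, b, tol float64) bool {
	return math.Abs(a-b) <= tol
}

func main() {
	for _, days := range []float64{0, 7, 14, 21, 28} {
		fmt.Printf("%2.0f days -> decay %.4f\n", days, recencyDecay(days, 7))
	}
	// The TestUpdateConfig case: a "change" observation (weight 0.9) that is
	// 14 days old under a 14-day half-life.
	fmt.Printf("score: %.2f\n", baselineScore(0.9, 14, 14))
}
```

Under these assumptions the same formula also reproduces the batch test expectations: a 7-day-old feature (weight 1.2) lands at 0.6 and a 14-day-old change (weight 0.9) at 0.225 with the default 7-day half-life.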
|
||||||
|
|
||||||
|
func TestTypeBaseScore_UnknownType(t *testing.T) {
|
||||||
|
score := models.TypeBaseScore("unknown-type")
|
||||||
|
assert.Equal(t, 1.0, score, "unknown type should default to 1.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeBaseScore_AllKnownTypes(t *testing.T) {
|
||||||
|
expected := map[models.ObservationType]float64{
|
||||||
|
models.ObsTypeBugfix: 1.3,
|
||||||
|
models.ObsTypeFeature: 1.2,
|
||||||
|
models.ObsTypeDiscovery: 1.1,
|
||||||
|
models.ObsTypeDecision: 1.1,
|
||||||
|
models.ObsTypeRefactor: 1.0,
|
||||||
|
models.ObsTypeChange: 0.9,
|
||||||
|
}
|
||||||
|
|
||||||
|
for obsType, expectedScore := range expected {
|
||||||
|
t.Run(string(obsType), func(t *testing.T) {
|
||||||
|
score := models.TypeBaseScore(obsType)
|
||||||
|
assert.Equal(t, expectedScore, score)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculator_RetrievalBoostDiminishingReturns(t *testing.T) {
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Test that retrieval boost has diminishing returns
|
||||||
|
// When retrieval count doubles, the boost should NOT double (log2 gives diminishing returns)
|
||||||
|
|
||||||
|
// Collect boosts for different counts
|
||||||
|
boosts := make([]float64, 0)
|
||||||
|
retrievalCounts := []int{1, 3, 7, 15, 31, 63, 127}
|
||||||
|
|
||||||
|
for _, count := range retrievalCounts {
|
||||||
|
obs := &models.Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: models.ObsTypeRefactor,
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
RetrievalCount: count,
|
||||||
|
}
|
||||||
|
components := calc.CalculateComponents(obs, now)
|
||||||
|
boosts = append(boosts, components.RetrievalContrib)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify boost increases but at a decreasing rate
|
||||||
|
for i := 1; i < len(boosts); i++ {
|
||||||
|
// Each boost should be higher than the previous
|
||||||
|
assert.Greater(t, boosts[i], boosts[i-1],
|
||||||
|
"boost should increase with more retrievals")
|
||||||
|
|
||||||
|
// But not proportionally - calculate the ratios
|
||||||
|
if i >= 2 {
|
||||||
|
ratio1 := boosts[i-1] / boosts[i-2]
|
||||||
|
ratio2 := boosts[i] / boosts[i-1]
|
||||||
|
// The growth ratio should be decreasing (diminishing returns)
|
||||||
|
assert.Less(t, ratio2, ratio1+0.01, // Allow small floating point tolerance
|
||||||
|
"growth rate should decrease (diminishing returns)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
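Note on the diminishing-returns test above: the retrieval counts 1, 3, 7, 15, ... are each one less than a power of two, which suggests (but does not prove) a boost of the form weight × log2(1 + count). The sketch below works through that assumed formula. `retrievalBoost` is a hypothetical function and the 0.05 weight is only the value used in TestUpdateConfig, not a confirmed default.

```go
package main

import (
	"fmt"
	"math"
)

// retrievalBoost is a hypothetical log-scaled boost: each doubling of
// (count + 1) adds the same fixed increment, so the relative growth shrinks.
func retrievalBoost(count int, weight float64) float64 {
	if count <= 0 {
		return 0
	}
	return weight * math.Log2(float64(count)+1)
}

// approxEqual reports whether a and b differ by at most tol.
func approxEqual(a, b, tol float64) bool {
	return math.Abs(a-b) <= tol
}

func main() {
	const weight = 0.05 // assumed weight, for illustration only
	prev := 0.0
	for _, count := range []int{1, 3, 7, 15, 31, 63, 127} {
		boost := retrievalBoost(count, weight)
		if prev > 0 {
			fmt.Printf("count=%3d boost=%.2f ratio=%.3f\n", count, boost, boost/prev)
		} else {
			fmt.Printf("count=%3d boost=%.2f\n", count, boost)
		}
		prev = boost
	}
}
```

With this formula the boosts grow by a constant step while the ratio between consecutive boosts falls (2, 1.5, 1.33, ...), which is the property the test asserts.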
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
// Package scoring provides importance score calculation for observations.
|
||||||
|
package scoring
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ObservationStore defines the interface for observation storage operations needed by the recalculator.
|
||||||
|
type ObservationStore interface {
|
||||||
|
GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error)
|
||||||
|
UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error
|
||||||
|
GetConceptWeights(ctx context.Context) (map[string]float64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculator periodically recalculates importance scores for observations.
|
||||||
|
type Recalculator struct {
|
||||||
|
store ObservationStore
|
||||||
|
calculator *Calculator
|
||||||
|
log zerolog.Logger
|
||||||
|
interval time.Duration
|
||||||
|
batchSize int
|
||||||
|
stopCh chan struct{}
|
||||||
|
doneCh chan struct{}
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecalculator creates a new background recalculator.
|
||||||
|
func NewRecalculator(store ObservationStore, calc *Calculator, log zerolog.Logger) *Recalculator {
|
||||||
|
return &Recalculator{
|
||||||
|
store: store,
|
||||||
|
calculator: calc,
|
||||||
|
log: log.With().Str("component", "recalculator").Logger(),
|
||||||
|
interval: 1 * time.Hour, // Run every hour
|
||||||
|
batchSize: 500, // Process 500 observations at a time
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the background recalculation loop.
|
||||||
|
// It blocks until ctx is cancelled or Stop is called, so run it in its own goroutine.
|
||||||
|
func (r *Recalculator) Start(ctx context.Context) {
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.running {
|
||||||
|
r.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.running = true
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
r.mu.Lock()
|
||||||
|
r.running = false
|
||||||
|
r.mu.Unlock()
|
||||||
|
close(r.doneCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Initial run
|
||||||
|
r.recalculate(ctx)
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
interval := r.interval
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
r.log.Info().Msg("recalculator shutting down due to context cancellation")
|
||||||
|
return
|
||||||
|
case <-r.stopCh:
|
||||||
|
r.log.Info().Msg("recalculator stopping")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
r.recalculate(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals the background recalculation loop to exit and waits for it to finish.
|
||||||
|
func (r *Recalculator) Stop() {
|
||||||
|
r.mu.Lock()
|
||||||
|
if !r.running {
|
||||||
|
r.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
close(r.stopCh)
|
||||||
|
<-r.doneCh
|
||||||
|
}
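Note on Stop above: the running check and close(r.stopCh) happen in separate steps, so two goroutines calling Stop at the same moment could both pass the check and close the channel twice, which panics. One possible guard is sync.Once. The sketch below uses a hypothetical `worker` type reduced to the lifecycle fields; it is not the project's Recalculator and it assumes run has been started before stop is called.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// worker is a hypothetical stand-in for Recalculator's lifecycle fields.
type worker struct {
	stopCh   chan struct{}
	doneCh   chan struct{}
	stopOnce sync.Once
}

func newWorker() *worker {
	return &worker{
		stopCh: make(chan struct{}),
		doneCh: make(chan struct{}),
	}
}

// run loops until stop is requested, then signals completion via doneCh.
func (w *worker) run(interval time.Duration) {
	defer close(w.doneCh)
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-w.stopCh:
			return
		case <-ticker.C:
			// periodic work would go here
		}
	}
}

// stop may be called repeatedly and from several goroutines: sync.Once
// ensures stopCh is closed exactly once. It assumes run was started,
// otherwise the receive on doneCh would block forever.
func (w *worker) stop() {
	w.stopOnce.Do(func() { close(w.stopCh) })
	<-w.doneCh
}

func main() {
	w := newWorker()
	go w.run(10 * time.Millisecond)

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			w.stop() // concurrent calls are safe
		}()
	}
	wg.Wait()
	w.stop() // a later call returns immediately
	fmt.Println("stopped cleanly")
}
```

The trade-off is that a worker guarded this way cannot be restarted without fresh channels, which is also true of the code in the diff once stopCh or doneCh has been closed.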
|
||||||
|
|
||||||
|
// recalculate rescores at most batchSize stale observations in a single pass.
|
||||||
|
func (r *Recalculator) recalculate(ctx context.Context) {
|
||||||
|
now := time.Now()
|
||||||
|
threshold := r.calculator.RecalculateThreshold()
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
batchSize := r.batchSize
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
observations, err := r.store.GetObservationsNeedingScoreUpdate(ctx, threshold, batchSize)
|
||||||
|
if err != nil {
|
||||||
|
r.log.Error().Err(err).Msg("failed to get observations for score update")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(observations) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
scores := r.calculator.BatchCalculate(observations, now)
|
||||||
|
|
||||||
|
if err := r.store.UpdateImportanceScores(ctx, scores); err != nil {
|
||||||
|
r.log.Error().Err(err).Msg("failed to update importance scores")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.log.Info().
|
||||||
|
Int("count", len(scores)).
|
||||||
|
Dur("elapsed", time.Since(now)).
|
||||||
|
Msg("recalculated importance scores")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecalculateNow triggers an immediate recalculation.
|
||||||
|
// This is useful for testing or when scores need to be updated urgently.
// Failures are logged by recalculate rather than returned, so the result is always nil.
|
||||||
|
func (r *Recalculator) RecalculateNow(ctx context.Context) error {
|
||||||
|
r.recalculate(ctx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshConceptWeights reloads concept weights from the database.
|
||||||
|
// Call this after updating concept weights to apply changes.
|
||||||
|
func (r *Recalculator) RefreshConceptWeights(ctx context.Context) error {
|
||||||
|
weights, err := r.store.GetConceptWeights(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := r.calculator.GetConfig()
|
||||||
|
config.ConceptWeights = weights
|
||||||
|
r.calculator.UpdateConfig(config)
|
||||||
|
|
||||||
|
r.log.Info().Int("count", len(weights)).Msg("refreshed concept weights")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats holds a snapshot of the recalculator's state and scoring configuration.
|
||||||
|
type Stats struct {
|
||||||
|
Running bool `json:"running"`
|
||||||
|
Interval time.Duration `json:"interval"`
|
||||||
|
BatchSize int `json:"batch_size"`
|
||||||
|
HalfLife float64 `json:"half_life_days"`
|
||||||
|
MinScore float64 `json:"min_score"`
|
||||||
|
ConceptsLen int `json:"concepts_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats returns current recalculator statistics.
|
||||||
|
func (r *Recalculator) GetStats() Stats {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
config := r.calculator.GetConfig()
|
||||||
|
|
||||||
|
return Stats{
|
||||||
|
Running: r.running,
|
||||||
|
Interval: r.interval,
|
||||||
|
BatchSize: r.batchSize,
|
||||||
|
HalfLife: config.RecencyHalfLifeDays,
|
||||||
|
MinScore: config.MinScore,
|
||||||
|
ConceptsLen: len(config.ConceptWeights),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time check that *sqlite.ObservationStore satisfies ObservationStore.
|
||||||
|
var _ ObservationStore = (*sqlite.ObservationStore)(nil)
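Note on the assertion above: assigning a typed nil pointer to the blank identifier makes the compiler verify the method set without allocating anything. A minimal sketch of the idiom follows; `Store` and `memStore` are hypothetical names used only for illustration.

```go
package main

import "fmt"

// Store is a hypothetical interface, standing in for ObservationStore.
type Store interface {
	Count() int
}

// memStore is a hypothetical implementation with a pointer receiver.
type memStore struct{ n int }

func (m *memStore) Count() int { return m.n }

// Compile-time assertion: the build fails on this line if *memStore ever
// stops satisfying Store. (*memStore)(nil) is a typed nil, so no value is
// allocated at run time.
var _ Store = (*memStore)(nil)

func main() {
	var s Store = &memStore{n: 3}
	fmt.Println(s.Count())
}
```

If a method were removed from memStore or its signature changed, the error would point at the assertion line rather than at some distant call site.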
|
||||||
@@ -0,0 +1,447 @@
|
|||||||
|
// Package scoring provides importance score calculation for observations.
|
||||||
|
package scoring
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockObservationStore is a mock implementation of ObservationStore for testing.
|
||||||
|
type MockObservationStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
observations []*models.Observation
|
||||||
|
scores map[int64]float64
|
||||||
|
conceptWeights map[string]float64
|
||||||
|
updateErr error
|
||||||
|
getErr error
|
||||||
|
getConceptsErr error
|
||||||
|
updateScoresCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockObservationStore() *MockObservationStore {
|
||||||
|
return &MockObservationStore{
|
||||||
|
observations: []*models.Observation{},
|
||||||
|
scores: make(map[int64]float64),
|
||||||
|
conceptWeights: make(map[string]float64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.getErr != nil {
|
||||||
|
return nil, m.getErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return observations that haven't been updated within threshold
|
||||||
|
now := time.Now()
|
||||||
|
var result []*models.Observation
|
||||||
|
for _, obs := range m.observations {
|
||||||
|
if !obs.ScoreUpdatedAt.Valid || now.Sub(time.Unix(obs.ScoreUpdatedAt.Int64, 0)) > threshold {
|
||||||
|
result = append(result, obs)
|
||||||
|
if len(result) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.updateScoresCalls++
|
||||||
|
if m.updateErr != nil {
|
||||||
|
return m.updateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, score := range scores {
|
||||||
|
m.scores[id] = score
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) GetConceptWeights(ctx context.Context) (map[string]float64, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.getConceptsErr != nil {
|
||||||
|
return nil, m.getConceptsErr
|
||||||
|
}
|
||||||
|
return m.conceptWeights, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) AddObservation(obs *models.Observation) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.observations = append(m.observations, obs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) SetConceptWeights(weights map[string]float64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.conceptWeights = weights
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) GetScore(id int64) (float64, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
score, ok := m.scores[id]
|
||||||
|
return score, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) GetUpdateScoresCalls() int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.updateScoresCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) SetUpdateError(err error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.updateErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockObservationStore) SetGetError(err error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.getErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// RECALCULATOR TESTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestNewRecalculator(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
require.NotNil(t, recalc)
|
||||||
|
assert.NotNil(t, recalc.store)
|
||||||
|
assert.NotNil(t, recalc.calculator)
|
||||||
|
assert.Equal(t, 1*time.Hour, recalc.interval)
|
||||||
|
assert.Equal(t, 500, recalc.batchSize)
|
||||||
|
assert.False(t, recalc.running)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_RecalculateNow(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Add observations
|
||||||
|
now := time.Now()
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
})
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: 2,
|
||||||
|
Type: models.ObsTypeFeature,
|
||||||
|
CreatedAtEpoch: now.Add(-7 * 24 * time.Hour).UnixMilli(),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RecalculateNow(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify scores were calculated
|
||||||
|
score1, ok := store.GetScore(1)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Greater(t, score1, 0.0)
|
||||||
|
|
||||||
|
score2, ok := store.GetScore(2)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Greater(t, score2, 0.0)
|
||||||
|
assert.Less(t, score2, score1, "older observation should have lower score")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_RefreshConceptWeights(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Set up custom weights in store
|
||||||
|
customWeights := map[string]float64{
|
||||||
|
"security": 0.50,
|
||||||
|
"performance": 0.25,
|
||||||
|
}
|
||||||
|
store.SetConceptWeights(customWeights)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RefreshConceptWeights(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify weights were updated in calculator config
|
||||||
|
config := calc.GetConfig()
|
||||||
|
assert.Equal(t, 0.50, config.ConceptWeights["security"])
|
||||||
|
assert.Equal(t, 0.25, config.ConceptWeights["performance"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_RefreshConceptWeights_Error(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
store.getConceptsErr = assert.AnError
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RefreshConceptWeights(ctx)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_GetStats(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Set fields directly for testing
|
||||||
|
recalc.interval = 2 * time.Hour
|
||||||
|
recalc.batchSize = 250
|
||||||
|
|
||||||
|
stats := recalc.GetStats()
|
||||||
|
|
||||||
|
assert.False(t, stats.Running)
|
||||||
|
assert.Equal(t, 2*time.Hour, stats.Interval)
|
||||||
|
assert.Equal(t, 250, stats.BatchSize)
|
||||||
|
assert.Equal(t, 7.0, stats.HalfLife)
|
||||||
|
assert.Equal(t, 0.01, stats.MinScore)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_StartStop(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Use a short interval for testing
|
||||||
|
recalc.interval = 50 * time.Millisecond
|
||||||
|
|
||||||
|
// Add an observation
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start in goroutine
|
||||||
|
go recalc.Start(ctx)
|
||||||
|
|
||||||
|
// Wait for initial run and at least one tick
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify it ran
|
||||||
|
calls := store.GetUpdateScoresCalls()
|
||||||
|
assert.GreaterOrEqual(t, calls, 1, "should have run at least once")
|
||||||
|
|
||||||
|
// Stop via context cancellation
|
||||||
|
cancel()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify stopped
|
||||||
|
stats := recalc.GetStats()
|
||||||
|
assert.False(t, stats.Running)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_StartTwice(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
recalc.interval = 1 * time.Hour // Long interval so it doesn't tick during test
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start first time
|
||||||
|
go recalc.Start(ctx)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Try to start second time (should return immediately)
|
||||||
|
go recalc.Start(ctx)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should still be running (only once)
|
||||||
|
stats := recalc.GetStats()
|
||||||
|
assert.True(t, stats.Running)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_StopWhenNotRunning(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Stop without starting - should not panic
|
||||||
|
recalc.Stop()
|
||||||
|
|
||||||
|
stats := recalc.GetStats()
|
||||||
|
assert.False(t, stats.Running)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_EmptyStore(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RecalculateNow(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, store.GetUpdateScoresCalls(), "should not call update with no observations")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_GetObservationsError(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
store.SetGetError(assert.AnError)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RecalculateNow(ctx)
|
||||||
|
|
||||||
|
// RecalculateNow logs the fetch failure instead of returning it
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, store.GetUpdateScoresCalls())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_UpdateScoresError(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
|
||||||
|
store.SetUpdateError(assert.AnError)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RecalculateNow(ctx)
|
||||||
|
|
||||||
|
// RecalculateNow logs the update failure instead of returning it
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, store.GetUpdateScoresCalls())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_BatchProcessing(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Set small batch size
|
||||||
|
recalc.batchSize = 3
|
||||||
|
|
||||||
|
// Add 5 observations
|
||||||
|
now := time.Now()
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: int64(i),
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := recalc.RecalculateNow(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should only process batch size (3)
|
||||||
|
scores := 0
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
if _, ok := store.GetScore(int64(i)); ok {
|
||||||
|
scores++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 3, scores, "should only process batch size observations")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_ConcurrentAccess(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
// Add observations
|
||||||
|
now := time.Now()
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
store.AddObservation(&models.Observation{
|
||||||
|
ID: int64(i),
|
||||||
|
Type: models.ObsTypeBugfix,
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Run multiple recalculations concurrently
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = recalc.RecalculateNow(ctx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Should complete without race conditions
|
||||||
|
// (use -race flag to verify)
|
||||||
|
assert.GreaterOrEqual(t, store.GetUpdateScoresCalls(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecalculator_StatsThreadSafe(t *testing.T) {
|
||||||
|
store := NewMockObservationStore()
|
||||||
|
calc := NewCalculator(nil)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
recalc := NewRecalculator(store, calc, log)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Concurrent reads (use -race flag to verify)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = recalc.GetStats()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@@ -0,0 +1,448 @@
|
|||||||
|
// Package expansion provides context-aware query expansion for improved search recall.
|
||||||
|
package expansion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueryIntent represents the detected intent of a query.
|
||||||
|
type QueryIntent string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// IntentQuestion indicates a question-type query (how, why, what, etc.)
|
||||||
|
IntentQuestion QueryIntent = "question"
|
||||||
|
// IntentError indicates an error/debugging query
|
||||||
|
IntentError QueryIntent = "error"
|
||||||
|
// IntentImplementation indicates an implementation/coding query
|
||||||
|
IntentImplementation QueryIntent = "implementation"
|
||||||
|
// IntentArchitecture indicates an architecture/design query
|
||||||
|
IntentArchitecture QueryIntent = "architecture"
|
||||||
|
// IntentGeneral indicates a general lookup query
|
||||||
|
IntentGeneral QueryIntent = "general"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExpandedQuery represents a query variant with metadata.
|
||||||
|
type ExpandedQuery struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0)
|
||||||
|
Source string `json:"source"` // Where this expansion came from
|
||||||
|
Intent QueryIntent `json:"intent"` // Detected intent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expander provides context-aware query expansion.
|
||||||
|
type Expander struct {
|
||||||
|
embedSvc *embedding.Service
|
||||||
|
vocabulary []VocabEntry // Known vocabulary from observations
|
||||||
|
vocabVectors [][]float32 // Embeddings for vocabulary entries
|
||||||
|
vocabMu sync.RWMutex // Protects vocabulary
|
||||||
|
intentPatterns map[QueryIntent][]*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// VocabEntry represents a vocabulary term from observations.
|
||||||
|
type VocabEntry struct {
|
||||||
|
Term string // The term itself
|
||||||
|
Weight float64 // How common/important this term is (0.0-1.0)
|
||||||
|
Source string // Where it came from (title, concept, narrative)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds expander configuration.
|
||||||
|
type Config struct {
|
||||||
|
// MaxExpansions limits the number of expanded queries returned
|
||||||
|
MaxExpansions int
|
||||||
|
// MinSimilarity is the minimum similarity score for vocabulary expansion
|
||||||
|
MinSimilarity float64
|
||||||
|
// EnableVocabularyExpansion enables finding related terms from observations
|
||||||
|
EnableVocabularyExpansion bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns sensible default configuration.
|
||||||
|
func DefaultConfig() Config {
|
||||||
|
return Config{
|
||||||
|
MaxExpansions: 4,
|
||||||
|
MinSimilarity: 0.5,
|
||||||
|
EnableVocabularyExpansion: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewExpander creates a new query expander.
|
||||||
|
func NewExpander(embedSvc *embedding.Service) *Expander {
|
||||||
|
e := &Expander{
|
||||||
|
embedSvc: embedSvc,
|
||||||
|
intentPatterns: buildIntentPatterns(),
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildIntentPatterns creates regex patterns for intent detection.
|
||||||
|
func buildIntentPatterns() map[QueryIntent][]*regexp.Regexp {
|
||||||
|
patterns := make(map[QueryIntent][]*regexp.Regexp)
|
||||||
|
|
||||||
|
// Question patterns
|
||||||
|
patterns[IntentQuestion] = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)^(how|why|what|when|where|which|who)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\?$`),
|
||||||
|
regexp.MustCompile(`(?i)\b(explain|describe|understand)\b`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error/debugging patterns
|
||||||
|
patterns[IntentError] = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)\b(error|bug|issue|problem|fail|crash|exception|panic)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(fix|debug|troubleshoot|resolve)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(doesn't work|not working|broken)\b`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation patterns
|
||||||
|
patterns[IntentImplementation] = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)\b(implement|add|create|build|write|code)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(function|method|handler|endpoint|api)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(feature|functionality)\b`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Architecture patterns
|
||||||
|
patterns[IntentArchitecture] = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`(?i)\b(architecture|design|pattern|structure)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(component|module|layer|service)\b`),
|
||||||
|
regexp.MustCompile(`(?i)\b(flow|pipeline|workflow)\b`),
|
||||||
|
}
|
||||||
|
|
||||||
|
return patterns
|
||||||
|
}
|
||||||
|
|
||||||
|
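// DetectIntent returns the first intent whose patterns match the query.
// Intents are checked in priority order (error, question, implementation,
// architecture); an empty or unmatched query yields IntentGeneral.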
|
||||||
|
func (e *Expander) DetectIntent(query string) QueryIntent {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if query == "" {
|
||||||
|
return IntentGeneral
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check patterns in priority order
|
||||||
|
intentOrder := []QueryIntent{IntentError, IntentQuestion, IntentImplementation, IntentArchitecture}
|
||||||
|
|
||||||
|
for _, intent := range intentOrder {
|
||||||
|
patterns := e.intentPatterns[intent]
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
if pattern.MatchString(query) {
|
||||||
|
return intent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return IntentGeneral
|
||||||
|
}
|
||||||
|
|
||||||
|
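// Expand returns the original query (weight 1.0) first, followed by
// intent-based and, when enabled, vocabulary-based variants. The result is
// deduplicated and capped at cfg.MaxExpansions; an empty query yields nil.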
|
||||||
|
func (e *Expander) Expand(ctx context.Context, query string, cfg Config) []ExpandedQuery {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if query == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intent := e.DetectIntent(query)
|
||||||
|
expansions := make([]ExpandedQuery, 0, cfg.MaxExpansions)
|
||||||
|
|
||||||
|
// Always include the original query with highest weight
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: query,
|
||||||
|
Weight: 1.0,
|
||||||
|
Source: "original",
|
||||||
|
Intent: intent,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Generate intent-based expansions
|
||||||
|
intentExpansions := e.expandByIntent(query, intent)
|
||||||
|
expansions = append(expansions, intentExpansions...)
|
||||||
|
|
||||||
|
	// Generate vocabulary-based expansions if enabled. expandByVocabulary checks
	// for an empty vocabulary itself, under the read lock, so e.vocabulary is
	// not read here without holding vocabMu.
	if cfg.EnableVocabularyExpansion && e.embedSvc != nil {
|
||||||
|
vocabExpansions := e.expandByVocabulary(ctx, query, cfg.MinSimilarity)
|
||||||
|
expansions = append(expansions, vocabExpansions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate and limit
|
||||||
|
expansions = deduplicateExpansions(expansions)
|
||||||
|
if len(expansions) > cfg.MaxExpansions {
|
||||||
|
expansions = expansions[:cfg.MaxExpansions]
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("query", truncate(query, 50)).
|
||||||
|
Str("intent", string(intent)).
|
||||||
|
Int("expansions", len(expansions)).
|
||||||
|
Msg("Query expanded")
|
||||||
|
|
||||||
|
return expansions
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandByIntent generates expansions based on detected query intent.
|
||||||
|
func (e *Expander) expandByIntent(query string, intent QueryIntent) []ExpandedQuery {
|
||||||
|
var expansions []ExpandedQuery
|
||||||
|
|
||||||
|
// Extract key terms from query for context-aware expansion
|
||||||
|
keyTerms := extractKeyTerms(query)
|
||||||
|
|
||||||
|
switch intent {
|
||||||
|
case IntentQuestion:
|
||||||
|
// For questions, create a declarative variant
|
||||||
|
declarative := makeDeclarative(query)
|
||||||
|
if declarative != query {
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: declarative,
|
||||||
|
Weight: 0.85,
|
||||||
|
Source: "intent:declarative",
|
||||||
|
Intent: intent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case IntentError:
|
||||||
|
// For errors, expand with solution-oriented terms
|
||||||
|
if len(keyTerms) > 0 {
|
||||||
|
solutionQuery := strings.Join(keyTerms, " ") + " solution fix"
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: solutionQuery,
|
||||||
|
Weight: 0.8,
|
||||||
|
Source: "intent:solution",
|
||||||
|
Intent: intent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case IntentImplementation:
|
||||||
|
// For implementation queries, focus on the what/how
|
||||||
|
if len(keyTerms) > 0 {
|
||||||
|
howQuery := "how " + strings.Join(keyTerms, " ")
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: howQuery,
|
||||||
|
Weight: 0.75,
|
||||||
|
Source: "intent:how",
|
||||||
|
Intent: intent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case IntentArchitecture:
|
||||||
|
// For architecture queries, expand with design context
|
||||||
|
if len(keyTerms) > 0 {
|
||||||
|
designQuery := strings.Join(keyTerms, " ") + " design structure"
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: designQuery,
|
||||||
|
Weight: 0.75,
|
||||||
|
Source: "intent:design",
|
||||||
|
Intent: intent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case IntentGeneral:
|
||||||
|
		// General queries get no intent-based expansion; they rely on
		// vocabulary expansion instead.
|
||||||
|
}
|
||||||
|
|
||||||
|
return expansions
|
||||||
|
}
|
||||||
|
|
||||||
|
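// expandByVocabulary embeds the query, scores it against the vocabulary
// vectors, and builds expansions from the two best-scoring terms at or above
// minSimilarity, skipping any term the query already contains.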
|
||||||
|
func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimilarity float64) []ExpandedQuery {
|
||||||
|
e.vocabMu.RLock()
|
||||||
|
defer e.vocabMu.RUnlock()
|
||||||
|
|
||||||
|
if len(e.vocabulary) == 0 || e.embedSvc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embed the query
|
||||||
|
queryEmb, err := e.embedSvc.Embed(query)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to embed query for vocabulary expansion")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find similar vocabulary terms
|
||||||
|
type scoredTerm struct {
|
||||||
|
entry VocabEntry
|
||||||
|
score float64
|
||||||
|
}
|
||||||
|
|
||||||
|
var similar []scoredTerm
|
||||||
|
for i, entry := range e.vocabulary {
|
||||||
|
if i >= len(e.vocabVectors) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
score := cosineSimilarity(queryEmb, e.vocabVectors[i])
|
||||||
|
if score >= minSimilarity {
|
||||||
|
similar = append(similar, scoredTerm{entry: entry, score: score})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(similar) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
	// Sort by score (descending) with a simple O(n²) exchange sort
	for i := 0; i < len(similar)-1; i++ {
		for j := i + 1; j < len(similar); j++ {
			if similar[j].score > similar[i].score {
				similar[i], similar[j] = similar[j], similar[i]
			}
		}
	}
|
||||||
|
|
||||||
|
// Create expansion by combining top similar terms with query
|
||||||
|
var expansions []ExpandedQuery
|
||||||
|
if len(similar) > 0 {
|
||||||
|
// Take top 2 similar terms and combine with original key terms
|
||||||
|
keyTerms := extractKeyTerms(query)
|
||||||
|
for i := 0; i < min(2, len(similar)); i++ {
|
||||||
|
term := similar[i].entry.Term
|
||||||
|
// Don't add if term is already in query
|
||||||
|
if strings.Contains(strings.ToLower(query), strings.ToLower(term)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedQuery := strings.Join(keyTerms, " ") + " " + term
|
||||||
|
expansions = append(expansions, ExpandedQuery{
|
||||||
|
Query: combinedQuery,
|
||||||
|
Weight: 0.7 * similar[i].score * similar[i].entry.Weight,
|
||||||
|
Source: "vocabulary:" + term,
|
||||||
|
Intent: IntentGeneral,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return expansions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
// extractKeyTerms lowercases the query, trims surrounding punctuation from
// each word, and drops stop words and tokens shorter than two bytes.
|
||||||
|
func extractKeyTerms(query string) []string {
|
||||||
|
// Common stop words to filter out
|
||||||
|
stopWords := map[string]bool{
|
||||||
|
"a": true, "an": true, "the": true, "is": true, "are": true,
|
||||||
|
"was": true, "were": true, "be": true, "been": true, "being": true,
|
||||||
|
"have": true, "has": true, "had": true, "do": true, "does": true,
|
||||||
|
"did": true, "will": true, "would": true, "could": true, "should": true,
|
||||||
|
"may": true, "might": true, "must": true, "can": true,
|
||||||
|
"i": true, "me": true, "my": true, "we": true, "our": true,
|
||||||
|
"you": true, "your": true, "it": true, "its": true,
|
||||||
|
"this": true, "that": true, "these": true, "those": true,
|
||||||
|
"what": true, "which": true, "who": true, "whom": true,
|
||||||
|
"how": true, "why": true, "when": true, "where": true,
|
||||||
|
"to": true, "for": true, "with": true, "about": true, "from": true,
|
||||||
|
"in": true, "on": true, "at": true, "by": true, "of": true,
|
||||||
|
"and": true, "or": true, "but": true, "if": true, "then": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split and filter
|
||||||
|
words := strings.Fields(strings.ToLower(query))
|
||||||
|
var terms []string
|
||||||
|
|
||||||
|
for _, word := range words {
|
||||||
|
// Remove punctuation
|
||||||
|
word = strings.Trim(word, ".,?!;:'\"()[]{}")
|
||||||
|
if len(word) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if stopWords[word] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
terms = append(terms, word)
|
||||||
|
}
|
||||||
|
|
||||||
|
return terms
|
||||||
|
}
|
||||||
|
|
||||||
|
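// makeDeclarative strips a trailing question mark and a leading question
// prefix (such as "how do i " or "what is ") from the query. Queries without
// a known prefix are returned unchanged apart from the trailing "?".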
|
||||||
|
func makeDeclarative(query string) string {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
|
||||||
|
// Remove question mark
|
||||||
|
query = strings.TrimSuffix(query, "?")
|
||||||
|
|
||||||
|
// Handle common question patterns
|
||||||
|
patterns := []struct {
|
||||||
|
prefix string
|
||||||
|
replacement string
|
||||||
|
}{
|
||||||
|
{"how do i ", ""},
|
||||||
|
{"how to ", ""},
|
||||||
|
{"how does ", ""},
|
||||||
|
{"how is ", ""},
|
||||||
|
{"what is ", ""},
|
||||||
|
{"what are ", ""},
|
||||||
|
{"why does ", ""},
|
||||||
|
{"why is ", ""},
|
||||||
|
{"where is ", ""},
|
||||||
|
{"where are ", ""},
|
||||||
|
{"when does ", ""},
|
||||||
|
{"when is ", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(query)
|
||||||
|
for _, p := range patterns {
|
||||||
|
if strings.HasPrefix(lower, p.prefix) {
|
||||||
|
return strings.TrimSpace(query[len(p.prefix):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// deduplicateExpansions removes duplicate queries while preserving order.
|
||||||
|
func deduplicateExpansions(expansions []ExpandedQuery) []ExpandedQuery {
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
result := make([]ExpandedQuery, 0, len(expansions))
|
||||||
|
|
||||||
|
for _, exp := range expansions {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(exp.Query))
|
||||||
|
if !seen[normalized] {
|
||||||
|
seen[normalized] = true
|
||||||
|
result = append(result, exp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// cosineSimilarity computes cosine similarity between two vectors.
|
||||||
|
func cosineSimilarity(a, b []float32) float64 {
|
||||||
|
if len(a) != len(b) || len(a) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var dot, normA, normB float64
|
||||||
|
for i := range a {
|
||||||
|
dot += float64(a[i]) * float64(b[i])
|
||||||
|
normA += float64(a[i]) * float64(a[i])
|
||||||
|
normB += float64(b[i]) * float64(b[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if normA == 0 || normB == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return dot / (sqrt(normA) * sqrt(normB))
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqrt approximates the square root of x with 10 Newton-Raphson iterations
// starting from z = x, and returns 0 for x <= 0. Ten steps are enough for
// inputs within a few orders of magnitude of 1; much larger or much smaller
// inputs may not have converged.
func sqrt(x float64) float64 {
	if x <= 0 {
		return 0
	}
	z := x
	for i := 0; i < 10; i++ {
		z = (z + x/z) / 2
	}
	return z
}
|
||||||
|
|
||||||
|
// truncate shortens s to at most maxLen bytes, appending "..." when it cuts.
// It slices by byte, so a multi-byte UTF-8 character at the boundary may be split.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen] + "..."
}
|
||||||
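The similarity ranking in this file uses a hand-written square root and an exchange sort. The following is a minimal sketch of the same ranking step built on the standard library (`math.Sqrt`, `sort.SliceStable`). The names `cosine`, `scoredTerm`, and `topTerms` are hypothetical and chosen for illustration only; they are not part of this commit, and the sketch assumes nothing about the embedding service.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// scoredTerm pairs a vocabulary term with its similarity score.
// Hypothetical type for this sketch only.
type scoredTerm struct {
	Term  string
	Score float64
}

// cosine returns the cosine similarity of a and b, or 0 when the vectors
// are empty, differ in length, or either has zero norm.
func cosine(a, b []float32) float64 {
	if len(a) != len(b) || len(a) == 0 {
		return 0
	}
	var dot, normA, normB float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		normA += float64(a[i]) * float64(a[i])
		normB += float64(b[i]) * float64(b[i])
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}

// topTerms scores each term vector against the query vector, keeps the
// terms scoring at or above minSim, and returns at most k of them, best first.
func topTerms(query []float32, terms []string, vecs [][]float32, minSim float64, k int) []scoredTerm {
	var out []scoredTerm
	for i, t := range terms {
		if i >= len(vecs) {
			break // no vector for this term
		}
		if s := cosine(query, vecs[i]); s >= minSim {
			out = append(out, scoredTerm{Term: t, Score: s})
		}
	}
	// Stable sort keeps the original order among equal scores.
	sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score })
	if k >= 0 && len(out) > k {
		out = out[:k]
	}
	return out
}

func main() {
	query := []float32{1, 0}
	terms := []string{"auth", "login", "database"}
	vecs := [][]float32{{1, 1}, {1, 0}, {0, 1}}
	for _, t := range topTerms(query, terms, vecs, 0.5, 2) {
		fmt.Printf("%s %.3f\n", t.Term, t.Score)
	}
}
```

Whether this is preferable depends on the vocabulary size: for a handful of terms the difference is negligible, while `math.Sqrt` avoids the convergence limits of a fixed number of Newton steps.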
@@ -0,0 +1,518 @@
|
|||||||
|
// Package expansion provides context-aware query expansion for improved search recall.
|
||||||
|
package expansion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExpanderSuite tests the Expander functionality.
|
||||||
|
type ExpanderSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
expander *Expander
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpanderSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ExpanderSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ExpanderSuite) SetupTest() {
|
||||||
|
// Create expander without embedding service for basic tests
|
||||||
|
s.expander = NewExpander(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewExpander tests expander creation.
|
||||||
|
func (s *ExpanderSuite) TestNewExpander() {
|
||||||
|
e := NewExpander(nil)
|
||||||
|
s.NotNil(e)
|
||||||
|
s.NotNil(e.intentPatterns)
|
||||||
|
s.Nil(e.embedSvc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectIntent tests intent detection.
|
||||||
|
func (s *ExpanderSuite) TestDetectIntent() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expected QueryIntent
|
||||||
|
}{
|
||||||
|
// Question intents
|
||||||
|
{"how_question", "how do I implement auth?", IntentQuestion},
|
||||||
|
{"why_question", "why does this fail?", IntentError}, // "fail" triggers error first
|
||||||
|
{"what_question", "what is the purpose of this function?", IntentQuestion},
|
||||||
|
{"question_mark", "the handler for auth?", IntentQuestion},
|
||||||
|
{"explain", "explain the architecture", IntentQuestion},
|
||||||
|
|
||||||
|
// Error intents
|
||||||
|
{"error_word", "authentication error in login", IntentError},
|
||||||
|
{"bug_word", "bug in user registration", IntentError},
|
||||||
|
{"fix_word", "fix the memory leak", IntentError},
|
||||||
|
{"not_working", "login not working", IntentError},
|
||||||
|
{"crash", "application crash on startup", IntentError},
|
||||||
|
|
||||||
|
// Implementation intents
|
||||||
|
{"implement", "implement user authentication", IntentImplementation},
|
||||||
|
{"add_feature", "add new endpoint for users", IntentImplementation},
|
||||||
|
{"create", "create a handler for uploads", IntentImplementation},
|
||||||
|
{"function", "function to validate input", IntentImplementation},
|
||||||
|
|
||||||
|
// Architecture intents
|
||||||
|
{"architecture", "architecture of the system", IntentArchitecture},
|
||||||
|
{"design", "design pattern for observers", IntentArchitecture},
|
||||||
|
{"component", "component structure", IntentArchitecture},
|
||||||
|
{"flow", "data flow in the pipeline", IntentArchitecture},
|
||||||
|
|
||||||
|
// General intents
|
||||||
|
{"general", "user authentication", IntentGeneral},
|
||||||
|
{"empty", "", IntentGeneral},
|
||||||
|
{"simple", "database", IntentGeneral},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
result := s.expander.DetectIntent(tt.query)
|
||||||
|
s.Equal(tt.expected, result, "Query: %s", tt.query)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpand tests basic query expansion.
|
||||||
|
func (s *ExpanderSuite) TestExpand() {
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.EnableVocabularyExpansion = false // Disable for unit test
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
minExpansions int
|
||||||
|
hasOriginal bool
|
||||||
|
expectedIntent QueryIntent
|
||||||
|
}{
|
||||||
|
{"question", "how do I implement auth", 1, true, IntentQuestion},
|
||||||
|
{"error", "fix the bug in login", 1, true, IntentError},
|
||||||
|
{"implementation", "implement user handler", 1, true, IntentImplementation},
|
||||||
|
{"architecture", "architecture design", 1, true, IntentArchitecture},
|
||||||
|
{"general", "database connection", 1, true, IntentGeneral},
|
||||||
|
{"empty", "", 0, false, IntentGeneral},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
expansions := s.expander.Expand(ctx, tt.query, cfg)
|
||||||
|
|
||||||
|
if tt.minExpansions == 0 {
|
||||||
|
s.Empty(expansions)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.GreaterOrEqual(len(expansions), tt.minExpansions)
|
||||||
|
|
||||||
|
if tt.hasOriginal {
|
||||||
|
// First expansion should be the original
|
||||||
|
s.Equal(tt.query, expansions[0].Query)
|
||||||
|
s.Equal(1.0, expansions[0].Weight)
|
||||||
|
s.Equal("original", expansions[0].Source)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandWithConfig tests expansion with custom config.
|
||||||
|
func (s *ExpanderSuite) TestExpandWithConfig() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
cfg := Config{
|
||||||
|
MaxExpansions: 2,
|
||||||
|
MinSimilarity: 0.7,
|
||||||
|
EnableVocabularyExpansion: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
expansions := s.expander.Expand(ctx, "how to implement authentication", cfg)
|
||||||
|
s.LessOrEqual(len(expansions), cfg.MaxExpansions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandDeduplication tests that duplicates are removed.
|
||||||
|
func (s *ExpanderSuite) TestExpandDeduplication() {
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
cfg.EnableVocabularyExpansion = false
|
||||||
|
|
||||||
|
// Query that might generate duplicate expansions
|
||||||
|
query := "how to fix authentication"
|
||||||
|
expansions := s.expander.Expand(ctx, query, cfg)
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, exp := range expansions {
|
||||||
|
normalized := exp.Query
|
||||||
|
s.False(seen[normalized], "Duplicate expansion found: %s", exp.Query)
|
||||||
|
seen[normalized] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractKeyTerms tests key term extraction.
|
||||||
|
func TestExtractKeyTerms(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple",
|
||||||
|
query: "user authentication handler",
|
||||||
|
expected: []string{"user", "authentication", "handler"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with_stop_words",
|
||||||
|
query: "how to implement the user login",
|
||||||
|
expected: []string{"implement", "user", "login"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with_punctuation",
|
||||||
|
query: "fix the bug, please!",
|
||||||
|
expected: []string{"fix", "bug", "please"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
query: "",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only_stop_words",
|
||||||
|
query: "the a an is are",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "short_words_filtered",
|
||||||
|
query: "a b c auth",
|
||||||
|
expected: []string{"auth"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractKeyTerms(tt.query)
|
||||||
|
if tt.expected == nil {
|
||||||
|
assert.Empty(t, result)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMakeDeclarative tests question to declarative conversion.
|
||||||
|
func TestMakeDeclarative(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "how_do_i",
|
||||||
|
query: "how do I implement auth?",
|
||||||
|
expected: "implement auth",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "how_to",
|
||||||
|
query: "how to fix the bug",
|
||||||
|
expected: "fix the bug",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "what_is",
|
||||||
|
query: "what is the purpose of this?",
|
||||||
|
expected: "the purpose of this",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "why_does",
|
||||||
|
query: "why does this fail?",
|
||||||
|
expected: "this fail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "already_declarative",
|
||||||
|
query: "user authentication",
|
||||||
|
expected: "user authentication",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "question_mark_only",
|
||||||
|
query: "authentication?",
|
||||||
|
expected: "authentication",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case_insensitive",
|
||||||
|
query: "How To Fix Auth?",
|
||||||
|
expected: "Fix Auth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := makeDeclarative(tt.query)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeduplicateExpansions tests deduplication.
|
||||||
|
func TestDeduplicateExpansions(t *testing.T) {
|
||||||
|
expansions := []ExpandedQuery{
|
||||||
|
{Query: "auth handler", Weight: 1.0},
|
||||||
|
{Query: "AUTH HANDLER", Weight: 0.8}, // Duplicate (case insensitive)
|
||||||
|
{Query: "auth handler ", Weight: 0.7}, // Duplicate (whitespace)
|
||||||
|
{Query: "user auth", Weight: 0.6},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := deduplicateExpansions(expansions)
|
||||||
|
assert.Len(t, result, 2) // "auth handler" and "user auth"
|
||||||
|
assert.Equal(t, "auth handler", result[0].Query)
|
||||||
|
assert.Equal(t, 1.0, result[0].Weight) // First one preserved
|
||||||
|
assert.Equal(t, "user auth", result[1].Query)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCosineSimilarity tests cosine similarity calculation.
|
||||||
|
func TestCosineSimilarity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a []float32
|
||||||
|
b []float32
|
||||||
|
expected float64
|
||||||
|
delta float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical_vectors",
|
||||||
|
a: []float32{1, 0, 0},
|
||||||
|
b: []float32{1, 0, 0},
|
||||||
|
expected: 1.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "orthogonal_vectors",
|
||||||
|
a: []float32{1, 0, 0},
|
||||||
|
b: []float32{0, 1, 0},
|
||||||
|
expected: 0.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opposite_vectors",
|
||||||
|
a: []float32{1, 0, 0},
|
||||||
|
b: []float32{-1, 0, 0},
|
||||||
|
expected: -1.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "similar_vectors",
|
||||||
|
a: []float32{1, 1, 0},
|
||||||
|
b: []float32{1, 0, 0},
|
||||||
|
expected: 0.707,
|
||||||
|
delta: 0.01,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_vectors",
|
||||||
|
a: []float32{},
|
||||||
|
b: []float32{},
|
||||||
|
expected: 0.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different_lengths",
|
||||||
|
a: []float32{1, 0},
|
||||||
|
b: []float32{1, 0, 0},
|
||||||
|
expected: 0.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero_vector",
|
||||||
|
a: []float32{0, 0, 0},
|
||||||
|
b: []float32{1, 1, 1},
|
||||||
|
expected: 0.0,
|
||||||
|
delta: 0.001,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := cosineSimilarity(tt.a, tt.b)
|
||||||
|
assert.InDelta(t, tt.expected, result, tt.delta)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultConfig tests default configuration.
|
||||||
|
func TestDefaultConfig(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
assert.Equal(t, 4, cfg.MaxExpansions)
|
||||||
|
assert.Equal(t, 0.5, cfg.MinSimilarity)
|
||||||
|
assert.True(t, cfg.EnableVocabularyExpansion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandedQueryStruct tests ExpandedQuery struct.
|
||||||
|
func TestExpandedQueryStruct(t *testing.T) {
|
||||||
|
eq := ExpandedQuery{
|
||||||
|
Query: "test query",
|
||||||
|
Weight: 0.85,
|
||||||
|
Source: "vocabulary:auth",
|
||||||
|
Intent: IntentQuestion,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "test query", eq.Query)
|
||||||
|
assert.Equal(t, 0.85, eq.Weight)
|
||||||
|
assert.Equal(t, "vocabulary:auth", eq.Source)
|
||||||
|
assert.Equal(t, IntentQuestion, eq.Intent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVocabEntry tests VocabEntry struct.
|
||||||
|
func TestVocabEntry(t *testing.T) {
|
||||||
|
ve := VocabEntry{
|
||||||
|
Term: "authentication",
|
||||||
|
Weight: 0.9,
|
||||||
|
Source: "concept",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "authentication", ve.Term)
|
||||||
|
assert.Equal(t, 0.9, ve.Weight)
|
||||||
|
assert.Equal(t, "concept", ve.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntentConstants tests intent constant values.
|
||||||
|
func TestIntentConstants(t *testing.T) {
|
||||||
|
assert.Equal(t, QueryIntent("question"), IntentQuestion)
|
||||||
|
assert.Equal(t, QueryIntent("error"), IntentError)
|
||||||
|
assert.Equal(t, QueryIntent("implementation"), IntentImplementation)
|
||||||
|
assert.Equal(t, QueryIntent("architecture"), IntentArchitecture)
|
||||||
|
assert.Equal(t, QueryIntent("general"), IntentGeneral)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTruncate tests the truncate helper.
|
||||||
|
func TestTruncate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"short", "hello", 10, "hello"},
|
||||||
|
{"exact", "hello", 5, "hello"},
|
||||||
|
{"long", "hello world", 5, "hello..."},
|
||||||
|
{"empty", "", 10, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := truncate(tt.input, tt.maxLen)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqrt tests the sqrt helper.
|
||||||
|
func TestSqrt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input float64
|
||||||
|
expected float64
|
||||||
|
delta float64
|
||||||
|
}{
|
||||||
|
{4.0, 2.0, 0.001},
|
||||||
|
{9.0, 3.0, 0.001},
|
||||||
|
{16.0, 4.0, 0.001},
|
||||||
|
{2.0, 1.414, 0.01},
|
||||||
|
{0.0, 0.0, 0.001},
|
||||||
|
{-1.0, 0.0, 0.001},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
result := sqrt(tt.input)
|
||||||
|
assert.InDelta(t, tt.expected, result, tt.delta)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandByIntentError tests error intent expansion.
|
||||||
|
func (s *ExpanderSuite) TestExpandByIntentError() {
|
||||||
|
expansions := s.expander.expandByIntent("fix authentication bug", IntentError)
|
||||||
|
s.NotEmpty(expansions)
|
||||||
|
|
||||||
|
// Should have solution-oriented expansion
|
||||||
|
hasSolution := false
|
||||||
|
for _, exp := range expansions {
|
||||||
|
if exp.Source == "intent:solution" {
|
||||||
|
hasSolution = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.True(hasSolution)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandByIntentQuestion tests question intent expansion.
|
||||||
|
func (s *ExpanderSuite) TestExpandByIntentQuestion() {
|
||||||
|
expansions := s.expander.expandByIntent("how do I implement auth", IntentQuestion)
|
||||||
|
s.NotEmpty(expansions)
|
||||||
|
|
||||||
|
// Should have declarative expansion
|
||||||
|
hasDeclarative := false
|
||||||
|
for _, exp := range expansions {
|
||||||
|
if exp.Source == "intent:declarative" {
|
||||||
|
hasDeclarative = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.True(hasDeclarative)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandByIntentImplementation tests implementation intent expansion.
|
||||||
|
func (s *ExpanderSuite) TestExpandByIntentImplementation() {
|
||||||
|
expansions := s.expander.expandByIntent("implement user handler", IntentImplementation)
|
||||||
|
s.NotEmpty(expansions)
|
||||||
|
|
||||||
|
// Should have how expansion
|
||||||
|
hasHow := false
|
||||||
|
for _, exp := range expansions {
|
||||||
|
if exp.Source == "intent:how" {
|
||||||
|
hasHow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.True(hasHow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandByIntentArchitecture tests architecture intent expansion.
|
||||||
|
func (s *ExpanderSuite) TestExpandByIntentArchitecture() {
|
||||||
|
expansions := s.expander.expandByIntent("system architecture design", IntentArchitecture)
|
||||||
|
s.NotEmpty(expansions)
|
||||||
|
|
||||||
|
// Should have design expansion
|
||||||
|
hasDesign := false
|
||||||
|
for _, exp := range expansions {
|
||||||
|
if exp.Source == "intent:design" {
|
||||||
|
hasDesign = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.True(hasDesign)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExpandByIntentGeneral tests that general intent returns no expansions.
|
||||||
|
func (s *ExpanderSuite) TestExpandByIntentGeneral() {
|
||||||
|
expansions := s.expander.expandByIntent("database", IntentGeneral)
|
||||||
|
s.Empty(expansions) // General intent doesn't add intent-based expansions
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEmptyVocabulary tests expansion with empty vocabulary.
|
||||||
|
func (s *ExpanderSuite) TestEmptyVocabulary() {
|
||||||
|
ctx := context.Background()
|
||||||
|
expansions := s.expander.expandByVocabulary(ctx, "test query", 0.5)
|
||||||
|
s.Empty(expansions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntentPatternsExist tests that all intents have patterns.
|
||||||
|
func (s *ExpanderSuite) TestIntentPatternsExist() {
|
||||||
|
s.NotEmpty(s.expander.intentPatterns[IntentQuestion])
|
||||||
|
s.NotEmpty(s.expander.intentPatterns[IntentError])
|
||||||
|
s.NotEmpty(s.expander.intentPatterns[IntentImplementation])
|
||||||
|
s.NotEmpty(s.expander.intentPatterns[IntentArchitecture])
|
||||||
|
}
|
||||||
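The `TestTruncate` cases above use ASCII input only. Because `truncate` slices by byte, non-ASCII queries could be cut in the middle of a character in the debug log. Below is a minimal sketch of a rune-aware variant; `truncateRunes` is a hypothetical name for illustration and is not part of this commit.

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// truncateRunes keeps at most maxRunes runes of s and appends "..." only
// when something was cut. Hypothetical rune-aware variant of truncate.
func truncateRunes(s string, maxRunes int) string {
	if maxRunes < 0 {
		maxRunes = 0
	}
	if utf8.RuneCountInString(s) <= maxRunes {
		return s
	}
	count := 0
	for i := range s { // i is the byte offset at which each rune starts
		if count == maxRunes {
			return s[:i] + "..."
		}
		count++
	}
	return s // unreachable: the rune count check above guarantees a cut
}

func main() {
	fmt.Println(truncateRunes("hello world", 5))
	fmt.Println(truncateRunes("héllo wörld", 5))
	fmt.Println(truncateRunes("short", 10))
}
```

The cut always falls on a rune boundary, so the result stays valid UTF-8 whenever the input is.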
+29
-15
@@ -35,20 +35,21 @@ func NewManager(
|
|||||||
|
|
||||||
// SearchParams contains parameters for unified search.
|
// SearchParams contains parameters for unified search.
|
||||||
type SearchParams struct {
|
type SearchParams struct {
|
||||||
Query string
|
Query string
|
||||||
Type string // "observations", "sessions", "prompts", or empty for all
|
Type string // "observations", "sessions", "prompts", or empty for all
|
||||||
Project string
|
Project string
|
||||||
ObsType string // Observation type filter
|
ObsType string // Observation type filter
|
||||||
Concepts string
|
Concepts string
|
||||||
Files string
|
Files string
|
||||||
DateStart int64
|
DateStart int64
|
||||||
DateEnd int64
|
DateEnd int64
|
||||||
OrderBy string // "relevance", "date_desc", "date_asc"
|
OrderBy string // "relevance", "date_desc", "date_asc"
|
||||||
Limit int
|
Limit int
|
||||||
Offset int
|
Offset int
|
||||||
Format string // "index" or "full"
|
Format string // "index" or "full"
|
||||||
Scope string // "project", "global", or empty for project+global
|
Scope string // "project", "global", or empty for project+global
|
||||||
IncludeGlobal bool // If true, include global observations along with project-scoped
|
IncludeGlobal bool // If true, include global observations along with project-scoped
|
||||||
|
ExcludeSuperseded bool // If true, exclude observations that have been superseded
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchResult represents a unified search result.
|
// SearchResult represents a unified search result.
|
||||||
@@ -126,6 +127,10 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
|||||||
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
|
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, o := range obs {
|
for _, o := range obs {
|
||||||
|
// Skip superseded observations when requested
|
||||||
|
if params.ExcludeSuperseded && o.IsSuperseded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
results = append(results, m.observationToResult(o, params.Format))
|
results = append(results, m.observationToResult(o, params.Format))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -167,7 +172,16 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
|
|||||||
|
|
||||||
// Search observations
|
// Search observations
|
||||||
if params.Type == "" || params.Type == "observations" {
|
if params.Type == "" || params.Type == "observations" {
|
||||||
obs, err := m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
|
var obs []*models.Observation
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Use active observations (excluding superseded) when requested
|
||||||
|
if params.ExcludeSuperseded {
|
||||||
|
obs, err = m.observationStore.GetActiveObservations(ctx, params.Project, params.Limit)
|
||||||
|
} else {
|
||||||
|
obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, o := range obs {
|
for _, o := range obs {
|
||||||
results = append(results, m.observationToResult(o, params.Format))
|
results = append(results, m.observationToResult(o, params.Format))
|
||||||
|
|||||||
@@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
|||||||
return fmt.Errorf("generate embeddings: %w", err)
|
return fmt.Errorf("generate embeddings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert into vectors table
|
// Insert into vectors table with model version tracking
|
||||||
const insertQuery = `
|
const insertQuery = `
|
||||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
|
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// Get current model version for tracking
|
||||||
|
modelVersion := c.embedSvc.Version()
|
||||||
|
|
||||||
tx, err := c.db.BeginTx(ctx, nil)
|
tx, err := c.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("begin transaction: %w", err)
|
return fmt.Errorf("begin transaction: %w", err)
|
||||||
@@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
|||||||
fieldType,
|
fieldType,
|
||||||
project,
|
project,
|
||||||
scope,
|
scope,
|
||||||
|
modelVersion,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("insert document %s: %w", doc.ID, err)
|
return fmt.Errorf("insert document %s: %w", doc.ID, err)
|
||||||
@@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
|||||||
return fmt.Errorf("commit transaction: %w", err)
|
return fmt.Errorf("commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
|
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,6 +216,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
|||||||
return nil, fmt.Errorf("scan row: %w", err)
|
return nil, fmt.Errorf("scan row: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.Similarity = DistanceToSimilarity(r.Distance)
|
||||||
r.Metadata = map[string]any{
|
r.Metadata = map[string]any{
|
||||||
"sqlite_id": float64(sqliteID), // Keep as float64 for compatibility
|
"sqlite_id": float64(sqliteID), // Keep as float64 for compatibility
|
||||||
"doc_type": docType.String,
|
"doc_type": docType.String,
|
||||||
@@ -252,3 +257,148 @@ func truncateString(s string, maxLen int) string {
|
|||||||
}
|
}
|
||||||
return s[:maxLen] + "..."
|
return s[:maxLen] + "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count returns the total number of vectors in the store.
|
||||||
|
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("count vectors: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelVersion returns the current embedding model version.
|
||||||
|
func (c *Client) ModelVersion() string {
|
||||||
|
return c.embedSvc.Version()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsRebuild checks if vectors need to be rebuilt due to model version change.
|
||||||
|
// Returns true if:
|
||||||
|
// - The vectors table is empty
|
||||||
|
// - Any vectors have a NULL model_version or one that differs from the current model
|
||||||
|
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
currentModel := c.embedSvc.Version()
|
||||||
|
|
||||||
|
// Check total count
|
||||||
|
var totalCount int64
|
||||||
|
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to count vectors for rebuild check")
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalCount == 0 {
|
||||||
|
return true, "empty"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for vectors with different model version
|
||||||
|
var staleCount int64
|
||||||
|
err = c.db.QueryRowContext(ctx,
|
||||||
|
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||||
|
currentModel,
|
||||||
|
).Scan(&staleCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to count stale vectors")
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if staleCount > 0 {
|
||||||
|
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaleVectorInfo contains information about a vector that needs rebuilding.
|
||||||
|
type StaleVectorInfo struct {
|
||||||
|
DocID string
|
||||||
|
SQLiteID int64
|
||||||
|
DocType string
|
||||||
|
FieldType string
|
||||||
|
Project string
|
||||||
|
Scope string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaleVectors returns identifying metadata (doc_id, sqlite_id, doc_type, field_type, project, scope) for vectors with mismatched or null model versions.
|
||||||
|
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||||
|
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
currentModel := c.embedSvc.Version()
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT doc_id, sqlite_id, doc_type, field_type, project, scope
|
||||||
|
FROM vectors
|
||||||
|
WHERE model_version != ? OR model_version IS NULL
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := c.db.QueryContext(ctx, query, currentModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query stale vectors: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var results []StaleVectorInfo
|
||||||
|
for rows.Next() {
|
||||||
|
var info StaleVectorInfo
|
||||||
|
var sqliteID sql.NullInt64
|
||||||
|
var docType, fieldType, project, scope sql.NullString
|
||||||
|
|
||||||
|
if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.SQLiteID = sqliteID.Int64
|
||||||
|
info.DocType = docType.String
|
||||||
|
info.FieldType = fieldType.String
|
||||||
|
info.Project = project.String
|
||||||
|
info.Scope = scope.String
|
||||||
|
|
||||||
|
results = append(results, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
|
||||||
|
// Used for granular rebuild - delete stale vectors before re-adding.
|
||||||
|
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||||
|
if len(docIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
// Build placeholder string
|
||||||
|
placeholders := make([]string, len(docIDs))
|
||||||
|
args := make([]interface{}, len(docIDs))
|
||||||
|
for i, id := range docIDs {
|
||||||
|
placeholders[i] = "?"
|
||||||
|
args[i] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
|
||||||
|
query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
|
||||||
|
strings.Join(placeholders, ","))
|
||||||
|
|
||||||
|
_, err := c.db.ExecContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete vectors by doc_ids: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
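The placeholder construction in DeleteVectorsByDocIDs can be looked at in isolation. The following is a minimal sketch, not the project's code: buildDeleteQuery is a hypothetical helper name, and the doc IDs in main are made-up examples. It shows why the #nosec note holds: only "?" characters are interpolated into the SQL text, while the actual IDs travel separately as bound arguments.

```go
package main

import (
	"fmt"
	"strings"
)

// buildDeleteQuery (hypothetical helper) returns a DELETE statement with one
// "?" placeholder per doc ID, plus the matching argument slice.
func buildDeleteQuery(docIDs []string) (string, []any) {
	placeholders := make([]string, len(docIDs))
	args := make([]any, len(docIDs))
	for i, id := range docIDs {
		placeholders[i] = "?"
		args[i] = id
	}
	// Only "?" strings are formatted into the query; values stay in args.
	query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
		strings.Join(placeholders, ","))
	return query, args
}

func main() {
	// Example doc IDs are illustrative only.
	query, args := buildDeleteQuery([]string{"obs_1_title", "obs_1_narrative"})
	fmt.Println(query)
	fmt.Println(len(args))
}
```

With this shape the query and args would be passed together to ExecContext. For very large ID lists, splitting the slice into batches may be needed, since SQLite limits the number of bound parameters per statement.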
@@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) {
|
|||||||
doc_type TEXT,
|
doc_type TEXT,
|
||||||
field_type TEXT,
|
field_type TEXT,
|
||||||
project TEXT,
|
project TEXT,
|
||||||
scope TEXT
|
scope TEXT,
|
||||||
|
model_version TEXT
|
||||||
)
|
)
|
||||||
`)
|
`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -19,9 +19,32 @@ type Document struct {
|
|||||||
|
|
||||||
// QueryResult represents a search result from vector search.
|
// QueryResult represents a search result from vector search.
|
||||||
type QueryResult struct {
|
type QueryResult struct {
|
||||||
ID string
|
ID string
|
||||||
Distance float64
|
Distance float64
|
||||||
Metadata map[string]any
|
Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance)
|
||||||
|
Metadata map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// DistanceToSimilarity converts sqlite-vec cosine distance to similarity score.
|
||||||
|
// Cosine distance: 0 = identical, 2 = opposite
|
||||||
|
// Similarity: 1.0 = identical, 0.0 = opposite
|
||||||
|
func DistanceToSimilarity(distance float64) float64 {
|
||||||
|
return 1.0 - (distance / 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterByThreshold filters results to only include those at or above the similarity threshold.
|
||||||
|
// If maxResults > 0, also caps the number of results.
|
||||||
|
func FilterByThreshold(results []QueryResult, threshold float64, maxResults int) []QueryResult {
|
||||||
|
var filtered []QueryResult
|
||||||
|
for _, r := range results {
|
||||||
|
if r.Similarity >= threshold {
|
||||||
|
filtered = append(filtered, r)
|
||||||
|
if maxResults > 0 && len(filtered) >= maxResults {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
|
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
|
||||||
|
|||||||
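The distance-to-similarity mapping and the threshold filter added above can be tried on their own. This is a self-contained sketch under stated assumptions: the result type is a reduced stand-in for QueryResult, the lower-case function names are local copies rather than the package's exported API, and the distances in main are invented for illustration.

```go
package main

import "fmt"

// result is a reduced stand-in for sqlitevec.QueryResult (assumption: only
// the fields needed for this illustration are kept).
type result struct {
	ID         string
	Distance   float64
	Similarity float64
}

// distanceToSimilarity maps cosine distance in [0, 2] onto similarity in [0, 1].
func distanceToSimilarity(distance float64) float64 {
	return 1.0 - (distance / 2.0)
}

// filterByThreshold keeps results whose similarity is at or above threshold,
// stopping once maxResults have been collected (0 means no cap).
func filterByThreshold(results []result, threshold float64, maxResults int) []result {
	var filtered []result
	for _, r := range results {
		if r.Similarity >= threshold {
			filtered = append(filtered, r)
			if maxResults > 0 && len(filtered) >= maxResults {
				break
			}
		}
	}
	return filtered
}

func main() {
	// Invented distances, ordered nearest first as a vector query would return them.
	results := []result{
		{ID: "a", Distance: 0.0},
		{ID: "b", Distance: 0.5},
		{ID: "c", Distance: 1.0},
		{ID: "d", Distance: 2.0},
	}
	for i := range results {
		results[i].Similarity = distanceToSimilarity(results[i].Distance)
	}
	for _, r := range filterByThreshold(results, 0.5, 0) {
		fmt.Printf("%s %.2f\n", r.ID, r.Similarity)
	}
}
```

Because the cap stops at the first maxResults matches in input order, it returns the best matches only when the input is already sorted by similarity, which is worth keeping in mind when results from several queries are concatenated before filtering.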
@@ -240,3 +240,101 @@ func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncPattern syncs a single pattern to the vector store.
|
||||||
|
func (s *Sync) SyncPattern(ctx context.Context, pattern *models.Pattern) error {
|
||||||
|
docs := s.formatPatternDocs(pattern)
|
||||||
|
if len(docs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||||
|
return fmt.Errorf("add pattern docs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Int64("patternId", pattern.ID).
|
||||||
|
Int("docCount", len(docs)).
|
||||||
|
Msg("Synced pattern to sqlite-vec")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatPatternDocs formats a pattern into vector documents.
|
||||||
|
func (s *Sync) formatPatternDocs(pattern *models.Pattern) []Document {
|
||||||
|
docs := make([]Document, 0, 3)
|
||||||
|
|
||||||
|
baseMetadata := map[string]any{
|
||||||
|
"sqlite_id": pattern.ID,
|
||||||
|
"doc_type": "pattern",
|
||||||
|
"pattern_type": string(pattern.Type),
|
||||||
|
"status": string(pattern.Status),
|
||||||
|
"scope": "global", // Patterns are always global
|
||||||
|
"frequency": pattern.Frequency,
|
||||||
|
"confidence": pattern.Confidence,
|
||||||
|
"created_at_epoch": pattern.CreatedAtEpoch,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pattern.Signature) > 0 {
|
||||||
|
baseMetadata["signature"] = joinStrings(pattern.Signature, ",")
|
||||||
|
}
|
||||||
|
if len(pattern.Projects) > 0 {
|
||||||
|
baseMetadata["projects"] = joinStrings(pattern.Projects, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern name as document
|
||||||
|
if pattern.Name != "" {
|
||||||
|
docs = append(docs, Document{
|
||||||
|
ID: fmt.Sprintf("pattern_%d_name", pattern.ID),
|
||||||
|
Content: pattern.Name,
|
||||||
|
Metadata: copyMetadata(baseMetadata, "field_type", "name"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern description as document
|
||||||
|
if pattern.Description.Valid && pattern.Description.String != "" {
|
||||||
|
docs = append(docs, Document{
|
||||||
|
ID: fmt.Sprintf("pattern_%d_description", pattern.ID),
|
||||||
|
Content: pattern.Description.String,
|
||||||
|
Metadata: copyMetadata(baseMetadata, "field_type", "description"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern recommendation as document
|
||||||
|
if pattern.Recommendation.Valid && pattern.Recommendation.String != "" {
|
||||||
|
docs = append(docs, Document{
|
||||||
|
ID: fmt.Sprintf("pattern_%d_recommendation", pattern.ID),
|
||||||
|
Content: pattern.Recommendation.String,
|
||||||
|
Metadata: copyMetadata(baseMetadata, "field_type", "recommendation"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return docs
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePatterns removes pattern documents from the vector store.
|
||||||
|
func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
|
||||||
|
if len(patternIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all possible document IDs for these patterns
|
||||||
|
// ID format: pattern_{id}_name, pattern_{id}_description, pattern_{id}_recommendation
|
||||||
|
ids := make([]string, 0, len(patternIDs)*3)
|
||||||
|
|
||||||
|
for _, patternID := range patternIDs {
|
||||||
|
ids = append(ids, fmt.Sprintf("pattern_%d_name", patternID))
|
||||||
|
ids = append(ids, fmt.Sprintf("pattern_%d_description", patternID))
|
||||||
|
ids = append(ids, fmt.Sprintf("pattern_%d_recommendation", patternID))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
|
||||||
|
return fmt.Errorf("delete pattern docs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Int("patternCount", len(patternIDs)).
|
||||||
|
Msg("Deleted patterns from sqlite-vec")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
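SyncPattern and DeletePatterns both rely on the same document ID scheme, one document per embedded field. A small sketch of that scheme follows. The helper name patternDocIDs and the patternFields slice are hypothetical and not part of the diff; they only show how keeping the field list in one place would stop the two functions drifting apart.

```go
package main

import "fmt"

// patternFields lists the pattern fields that are embedded as separate
// documents (taken from the ID strings used in the diff above).
var patternFields = []string{"name", "description", "recommendation"}

// patternDocIDs (hypothetical helper) returns every possible vector document
// ID for the given pattern IDs.
func patternDocIDs(patternIDs []int64) []string {
	ids := make([]string, 0, len(patternIDs)*len(patternFields))
	for _, id := range patternIDs {
		for _, field := range patternFields {
			ids = append(ids, fmt.Sprintf("pattern_%d_%s", id, field))
		}
	}
	return ids
}

func main() {
	for _, id := range patternDocIDs([]int64{7, 42}) {
		fmt.Println(id)
	}
}
```

Deleting by all possible IDs, as DeletePatterns does, means IDs for fields that were never stored (for example an empty description) are simply no-ops.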
+238
-4
@@ -4,13 +4,18 @@ package worker
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||||
@@ -641,6 +646,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleGetModels returns available embedding models.
|
||||||
|
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
models := embedding.ListModels()
|
||||||
|
defaultModel := embedding.GetDefaultModel()
|
||||||
|
|
||||||
|
writeJSON(w, map[string]interface{}{
|
||||||
|
"models": models,
|
||||||
|
"default": defaultModel,
|
||||||
|
"current": s.embedSvc.Version(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// handleGetStats returns worker statistics.
|
// handleGetStats returns worker statistics.
|
||||||
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
@@ -658,6 +675,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
|||||||
"ready": s.ready.Load(),
|
"ready": s.ready.Load(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add embedding model info
|
||||||
|
if s.embedSvc != nil {
|
||||||
|
response["embeddingModel"] = map[string]interface{}{
|
||||||
|
"name": s.embedSvc.Name(),
|
||||||
|
"version": s.embedSvc.Version(),
|
||||||
|
"dimensions": s.embedSvc.Dimensions(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add vector count
|
||||||
|
if s.vectorClient != nil {
|
||||||
|
if count, err := s.vectorClient.Count(r.Context()); err == nil {
|
||||||
|
response["vectorCount"] = count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Include project-specific observation count if project is specified
|
// Include project-specific observation count if project is specified
|
||||||
if project != "" {
|
if project != "" {
|
||||||
count, err := s.observationStore.GetObservationCount(r.Context(), project)
|
count, err := s.observationStore.GetObservationCount(r.Context(), project)
|
||||||
@@ -715,15 +748,69 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
var observations []*models.Observation
|
var observations []*models.Observation
|
||||||
var err error
|
var err error
|
||||||
var usedVector bool
|
var usedVector bool
|
||||||
|
similarityScores := make(map[int64]float64) // Track similarity per observation
|
||||||
|
|
||||||
|
// Get threshold settings from config
|
||||||
|
threshold := s.config.ContextRelevanceThreshold
|
||||||
|
maxResults := s.config.ContextMaxPromptResults
|
||||||
|
|
||||||
|
// Generate expanded queries if query expander is available
|
||||||
|
var expandedQueries []expansion.ExpandedQuery
|
||||||
|
var detectedIntent string
|
||||||
|
if s.queryExpander != nil {
|
||||||
|
cfg := expansion.DefaultConfig()
|
||||||
|
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||||
|
expandedQueries = s.queryExpander.Expand(r.Context(), query, cfg)
|
||||||
|
if len(expandedQueries) > 0 {
|
||||||
|
detectedIntent = string(expandedQueries[0].Intent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(expandedQueries) == 0 {
|
||||||
|
// Fallback to just the original query
|
||||||
|
expandedQueries = []expansion.ExpandedQuery{
|
||||||
|
{Query: query, Weight: 1.0, Source: "original"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Try vector search first if available
|
// Try vector search first if available
|
||||||
if s.vectorClient != nil && s.vectorClient.IsConnected() {
|
if s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||||
|
|
||||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
// Search with each expanded query and merge results
|
||||||
if vecErr == nil && len(vectorResults) > 0 {
|
allVectorResults := make([]sqlitevec.QueryResult, 0)
|
||||||
|
queryWeights := make(map[string]float64) // Track weights for score merging
|
||||||
|
|
||||||
|
for _, eq := range expandedQueries {
|
||||||
|
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
||||||
|
if vecErr == nil && len(vectorResults) > 0 {
|
||||||
|
// Apply weight to similarity scores before merging
|
||||||
|
for i := range vectorResults {
|
||||||
|
vectorResults[i].Similarity *= eq.Weight
|
||||||
|
}
|
||||||
|
allVectorResults = append(allVectorResults, vectorResults...)
|
||||||
|
queryWeights[eq.Query] = eq.Weight
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allVectorResults) > 0 {
|
||||||
|
// Filter by relevance threshold before extracting IDs
|
||||||
|
// Use a slightly lower threshold for expanded queries
|
||||||
|
effectiveThreshold := threshold * 0.9 // 10% below the configured threshold
|
||||||
|
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
|
||||||
|
|
||||||
|
// Build similarity map for filtered results (keeping highest weighted score per observation)
|
||||||
|
for _, vr := range filteredResults {
|
||||||
|
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
|
||||||
|
id := int64(sqliteID)
|
||||||
|
// Keep the highest score for each observation
|
||||||
|
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
|
||||||
|
similarityScores[id] = vr.Similarity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Extract observation IDs with project/scope filtering using shared helper
|
// Extract observation IDs with project/scope filtering using shared helper
|
||||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
|
||||||
|
|
||||||
if len(obsIDs) > 0 {
|
if len(obsIDs) > 0 {
|
||||||
// Fetch full observations from SQLite
|
// Fetch full observations from SQLite
|
||||||
@@ -773,23 +860,132 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
freshObservations = append(freshObservations, obs)
|
freshObservations = append(freshObservations, obs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply cross-encoder reranking if available
|
||||||
|
var reranked bool
|
||||||
|
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
|
||||||
|
// Build candidates from observations with their bi-encoder scores
|
||||||
|
candidates := make([]reranking.Candidate, len(freshObservations))
|
||||||
|
for i, obs := range freshObservations {
|
||||||
|
content := obs.Title.String
|
||||||
|
if obs.Narrative.Valid && obs.Narrative.String != "" {
|
||||||
|
content = content + " " + obs.Narrative.String
|
||||||
|
}
|
||||||
|
candidates[i] = reranking.Candidate{
|
||||||
|
ID: fmt.Sprintf("%d", obs.ID),
|
||||||
|
Content: content,
|
||||||
|
Score: similarityScores[obs.ID],
|
||||||
|
Metadata: map[string]any{"obs_idx": i},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rerank using cross-encoder - use pure mode or combined scores
|
||||||
|
var rerankResults []reranking.RerankResult
|
||||||
|
var rerankErr error
|
||||||
|
if s.config.RerankingPureMode {
|
||||||
|
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
|
||||||
|
} else {
|
||||||
|
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
|
||||||
|
}
|
||||||
|
if rerankErr != nil {
|
||||||
|
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
|
||||||
|
} else if len(rerankResults) > 0 {
|
||||||
|
// Update similarity scores with reranked scores
|
||||||
|
for _, rr := range rerankResults {
|
||||||
|
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||||
|
similarityScores[id] = rr.CombinedScore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reorder observations based on rerank results
|
||||||
|
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
|
||||||
|
obsMap := make(map[int64]*models.Observation)
|
||||||
|
for _, obs := range freshObservations {
|
||||||
|
obsMap[obs.ID] = obs
|
||||||
|
}
|
||||||
|
for _, rr := range rerankResults {
|
||||||
|
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||||
|
if obs, ok := obsMap[id]; ok {
|
||||||
|
reorderedObs = append(reorderedObs, obs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
freshObservations = reorderedObs
|
||||||
|
reranked = true
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Int("candidates", len(candidates)).
|
||||||
|
Int("returned", len(rerankResults)).
|
||||||
|
Msg("Cross-encoder reranking complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Cluster similar observations to remove duplicates
|
// Cluster similar observations to remove duplicates
|
||||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||||
|
|
||||||
|
// Sort by similarity score (highest first) if we have scores and didn't rerank
|
||||||
|
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
|
||||||
|
sort.Slice(clusteredObservations, func(i, j int) bool {
|
||||||
|
scoreI := similarityScores[clusteredObservations[i].ID]
|
||||||
|
scoreJ := similarityScores[clusteredObservations[j].ID]
|
||||||
|
return scoreI > scoreJ
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply max results cap if configured
|
||||||
|
if maxResults > 0 && len(clusteredObservations) > maxResults {
|
||||||
|
clusteredObservations = clusteredObservations[:maxResults]
|
||||||
|
}
|
||||||
|
|
||||||
// Record retrieval stats (no verification done, so verified=0, deleted=0)
|
// Record retrieval stats (no verification done, so verified=0, deleted=0)
|
||||||
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, true)
|
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, true)
|
||||||
|
|
||||||
|
// Increment retrieval counts for scoring (async, non-blocking)
|
||||||
|
if len(clusteredObservations) > 0 {
|
||||||
|
ids := make([]int64, len(clusteredObservations))
|
||||||
|
for i, obs := range clusteredObservations {
|
||||||
|
ids[i] = obs.ID
|
||||||
|
}
|
||||||
|
s.incrementRetrievalCounts(ids)
|
||||||
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("project", project).
|
Str("project", project).
|
||||||
Str("query", query).
|
Str("query", query).
|
||||||
|
Str("intent", detectedIntent).
|
||||||
|
Int("expansions", len(expandedQueries)).
|
||||||
Int("found", len(clusteredObservations)).
|
Int("found", len(clusteredObservations)).
|
||||||
Int("stale_excluded", staleCount).
|
Int("stale_excluded", staleCount).
|
||||||
|
Float64("threshold", threshold).
|
||||||
Msg("Prompt-based observation search")
|
Msg("Prompt-based observation search")
|
||||||
|
|
||||||
|
// Build response with similarity scores
|
||||||
|
obsWithScores := make([]map[string]interface{}, len(clusteredObservations))
|
||||||
|
for i, obs := range clusteredObservations {
|
||||||
|
obsMap := obs.ToMap()
|
||||||
|
if score, ok := similarityScores[obs.ID]; ok {
|
||||||
|
obsMap["similarity"] = score
|
||||||
|
}
|
||||||
|
obsWithScores[i] = obsMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build expansion info for response
|
||||||
|
expansionInfo := make([]map[string]interface{}, len(expandedQueries))
|
||||||
|
for i, eq := range expandedQueries {
|
||||||
|
expansionInfo[i] = map[string]interface{}{
|
||||||
|
"query": eq.Query,
|
||||||
|
"weight": eq.Weight,
|
||||||
|
"source": eq.Source,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
writeJSON(w, map[string]interface{}{
|
writeJSON(w, map[string]interface{}{
|
||||||
"project": project,
|
"project": project,
|
||||||
"query": query,
|
"query": query,
|
||||||
"observations": clusteredObservations,
|
"intent": detectedIntent,
|
||||||
|
"expansions": expansionInfo,
|
||||||
|
"observations": obsWithScores,
|
||||||
|
"threshold": threshold,
|
||||||
|
"max_results": maxResults,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -857,6 +1053,15 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Record retrieval stats (no verification done)
|
// Record retrieval stats (no verification done)
|
||||||
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, false)
|
s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, false)
|
||||||
|
|
||||||
|
// Increment retrieval counts for scoring (async, non-blocking)
|
||||||
|
if len(clusteredObservations) > 0 {
|
||||||
|
ids := make([]int64, len(clusteredObservations))
|
||||||
|
for i, obs := range clusteredObservations {
|
||||||
|
ids[i] = obs.ID
|
||||||
|
}
|
||||||
|
s.incrementRetrievalCounts(ids)
|
||||||
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("project", project).
|
Str("project", project).
|
||||||
Int("total", len(observations)).
|
Int("total", len(observations)).
|
||||||
@@ -1015,6 +1220,35 @@ func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
components = append(components, sseStatus)
|
components = append(components, sseStatus)
|
||||||
|
|
||||||
|
// Check Cross-Encoder Reranker
|
||||||
|
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
|
||||||
|
if !s.config.RerankingEnabled {
|
||||||
|
rerankerStatus.Status = "degraded"
|
||||||
|
rerankerStatus.Message = "Disabled in config"
|
||||||
|
if overall == "healthy" {
|
||||||
|
overall = "degraded"
|
||||||
|
}
|
||||||
|
} else if s.reranker == nil {
|
||||||
|
rerankerStatus.Status = "degraded"
|
||||||
|
rerankerStatus.Message = "Not initialized"
|
||||||
|
if overall == "healthy" {
|
||||||
|
overall = "degraded"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Verify reranker is functional using Score
|
||||||
|
_, normalizedScore, err := s.reranker.Score("test query", "test document")
|
||||||
|
if err != nil {
|
||||||
|
rerankerStatus.Status = "unhealthy"
|
||||||
|
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
|
||||||
|
if overall == "healthy" {
|
||||||
|
overall = "degraded"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
components = append(components, rerankerStatus)
|
||||||
|
|
||||||
// Calculate uptime
|
// Calculate uptime
|
||||||
uptime := time.Since(s.startTime).Round(time.Second).String()
|
uptime := time.Since(s.startTime).Round(time.Second).String()
|
||||||
|
|
||||||
|
|||||||
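The merge step in handleSearchByPrompt (weight each expanded query's hits, drop low scores, keep the best score per observation) can be sketched as a pure function. Everything named here is an assumption made for illustration: the hit and expandedQuery types, the mergeBestScores helper, and the sample queries and weights. The real handler works on sqlitevec.QueryResult values and applies its threshold through FilterByThreshold.

```go
package main

import "fmt"

// hit is a simplified vector search hit (assumption: one observation ID and
// one similarity score per hit).
type hit struct {
	ObsID      int64
	Similarity float64
}

// expandedQuery mirrors the Query and Weight fields used in the diff.
type expandedQuery struct {
	Query  string
	Weight float64
}

// mergeBestScores scales each hit by its query's weight, drops hits below
// threshold, and keeps the highest weighted score per observation ID.
func mergeBestScores(queries []expandedQuery, hitsByQuery map[string][]hit, threshold float64) map[int64]float64 {
	best := make(map[int64]float64)
	for _, q := range queries {
		for _, h := range hitsByQuery[q.Query] {
			score := h.Similarity * q.Weight
			if score < threshold {
				continue
			}
			if existing, ok := best[h.ObsID]; !ok || score > existing {
				best[h.ObsID] = score
			}
		}
	}
	return best
}

func main() {
	// Sample queries, weights, and scores are invented for illustration.
	queries := []expandedQuery{
		{Query: "fix login bug", Weight: 1.0},
		{Query: "authentication error", Weight: 0.5},
	}
	hits := map[string][]hit{
		"fix login bug":        {{ObsID: 1, Similarity: 0.5}, {ObsID: 2, Similarity: 0.25}},
		"authentication error": {{ObsID: 1, Similarity: 1.0}, {ObsID: 3, Similarity: 1.0}},
	}
	scores := mergeBestScores(queries, hits, 0.5)
	fmt.Println(len(scores), scores[1], scores[3])
}
```

Taking the maximum rather than summing scores means an observation matched by several expansions is not boosted for that; whether that is the desired behaviour is a design choice the diff makes implicitly.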
@@ -0,0 +1,292 @@
+// Package worker provides the main worker service for claude-mnemonic.
+package worker
+
+import (
+	"encoding/json"
+	"net/http"
+	"strconv"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+)
+
+// DefaultPatternsLimit is the default number of patterns to return.
+const DefaultPatternsLimit = 100
+
+// handleGetPatterns returns all active patterns, optionally filtered by type or project.
+func (s *Service) handleGetPatterns(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Parse query parameters
+	limit := DefaultPatternsLimit
+	if l := r.URL.Query().Get("limit"); l != "" {
+		if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
+			limit = parsed
+		}
+	}
+
+	patternType := r.URL.Query().Get("type")
+	project := r.URL.Query().Get("project")
+
+	var patterns []*models.Pattern
+	var err error
+
+	if patternType != "" {
+		// Filter by type
+		patterns, err = store.GetPatternsByType(r.Context(), models.PatternType(patternType), limit)
+	} else if project != "" {
+		// Filter by project
+		patterns, err = store.GetPatternsByProject(r.Context(), project, limit)
+	} else {
+		// Get all active patterns
+		patterns, err = store.GetActivePatterns(r.Context(), limit)
+	}
+
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, patterns)
+}
+
+// handleGetPatternStats returns aggregate statistics about patterns.
+func (s *Service) handleGetPatternStats(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	stats, err := store.GetPatternStats(r.Context())
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, stats)
+}
+
+// handleGetPatternByID returns a single pattern by ID.
+func (s *Service) handleGetPatternByID(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
+		return
+	}
+
+	pattern, err := store.GetPatternByID(r.Context(), id)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	if pattern == nil {
+		http.Error(w, "pattern not found", http.StatusNotFound)
+		return
+	}
+
+	writeJSON(w, pattern)
+}
+
+// handleGetPatternInsight returns a formatted insight string for a pattern.
+func (s *Service) handleGetPatternInsight(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	detector := s.patternDetector
+	s.initMu.RUnlock()
+
+	if detector == nil {
+		http.Error(w, "pattern detector not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
+		return
+	}
+
+	insight, err := detector.GetPatternInsight(r.Context(), id)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, map[string]string{"insight": insight})
+}
+
+// handleDeletePattern deletes a pattern by ID.
+func (s *Service) handleDeletePattern(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
+		return
+	}
+
+	if err := store.DeletePattern(r.Context(), id); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, map[string]string{"status": "deleted"})
+}
+
+// handleDeprecatePattern marks a pattern as deprecated.
+func (s *Service) handleDeprecatePattern(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid pattern ID", http.StatusBadRequest)
+		return
+	}
+
+	if err := store.MarkPatternDeprecated(r.Context(), id); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, map[string]string{"status": "deprecated"})
+}
+
+// MergePatternsRequest is the request body for merging patterns.
+type MergePatternsRequest struct {
+	SourceID int64 `json:"source_id"`
+	TargetID int64 `json:"target_id"`
+}
+
+// handleSearchPatterns performs full-text search on patterns.
+func (s *Service) handleSearchPatterns(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	query := r.URL.Query().Get("q")
+	if query == "" {
+		http.Error(w, "query parameter 'q' is required", http.StatusBadRequest)
+		return
+	}
+
+	limit := DefaultPatternsLimit
+	if l := r.URL.Query().Get("limit"); l != "" {
+		if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
+			limit = parsed
+		}
+	}
+
+	patterns, err := store.SearchPatternsFTS(r.Context(), query, limit)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, patterns)
+}
+
+// handleGetPatternByName returns a pattern by its name.
+func (s *Service) handleGetPatternByName(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	name := r.URL.Query().Get("name")
+	if name == "" {
+		http.Error(w, "query parameter 'name' is required", http.StatusBadRequest)
+		return
+	}
+
+	pattern, err := store.GetPatternByName(r.Context(), name)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	if pattern == nil {
+		http.Error(w, "pattern not found", http.StatusNotFound)
+		return
+	}
+
+	writeJSON(w, pattern)
+}
+
+// handleMergePatterns merges a source pattern into a target pattern.
+func (s *Service) handleMergePatterns(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	store := s.patternStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		http.Error(w, "pattern store not initialized", http.StatusServiceUnavailable)
+		return
+	}
+
+	var req MergePatternsRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	if req.SourceID == 0 || req.TargetID == 0 {
+		http.Error(w, "source_id and target_id are required", http.StatusBadRequest)
+		return
+	}
+
+	if req.SourceID == req.TargetID {
+		http.Error(w, "source_id and target_id cannot be the same", http.StatusBadRequest)
+		return
+	}
+
+	if err := store.MergePatterns(r.Context(), req.SourceID, req.TargetID); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, map[string]string{"status": "merged"})
+}
@@ -0,0 +1,174 @@
+// Package worker provides the main worker service for claude-mnemonic.
+package worker
+
+import (
+	"net/http"
+	"strconv"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+)
+
+// DefaultRelationsLimit is the default number of relations to return.
+const DefaultRelationsLimit = 50
+
+// handleGetRelations returns relations for an observation.
+func (s *Service) handleGetRelations(w http.ResponseWriter, r *http.Request) {
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid observation id", http.StatusBadRequest)
+		return
+	}
+
+	relations, err := s.relationStore.GetRelationsWithDetails(r.Context(), id)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	if relations == nil {
+		relations = []*models.RelationWithDetails{}
+	}
+
+	writeJSON(w, relations)
+}
+
+// handleGetRelationGraph returns the relation graph for an observation.
+func (s *Service) handleGetRelationGraph(w http.ResponseWriter, r *http.Request) {
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid observation id", http.StatusBadRequest)
+		return
+	}
+
+	// Get depth parameter (default 2)
+	depth := 2
+	if depthStr := r.URL.Query().Get("depth"); depthStr != "" {
+		if d, err := strconv.Atoi(depthStr); err == nil && d > 0 && d <= 5 {
+			depth = d
+		}
+	}
+
+	graph, err := s.relationStore.GetRelationGraph(r.Context(), id, depth)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, graph)
+}
+
+// handleGetRelatedObservations returns observations related to a given one.
+func (s *Service) handleGetRelatedObservations(w http.ResponseWriter, r *http.Request) {
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid observation id", http.StatusBadRequest)
+		return
+	}
+
+	// Get minimum confidence parameter (default 0.4)
+	minConfidence := 0.4
+	if confStr := r.URL.Query().Get("min_confidence"); confStr != "" {
+		if c, err := strconv.ParseFloat(confStr, 64); err == nil && c >= 0 && c <= 1 {
+			minConfidence = c
+		}
+	}
+
+	// Get related observation IDs
+	relatedIDs, err := s.relationStore.GetRelatedObservationIDs(r.Context(), id, minConfidence)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	if len(relatedIDs) == 0 {
+		writeJSON(w, []*models.Observation{})
+		return
+	}
+
+	// Fetch full observations
+	observations, err := s.observationStore.GetObservationsByIDs(r.Context(), relatedIDs, "importance", 50)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	if observations == nil {
+		observations = []*models.Observation{}
+	}
+
+	writeJSON(w, observations)
+}
+
+// handleGetRelationsByType returns all relations of a specific type.
+func (s *Service) handleGetRelationsByType(w http.ResponseWriter, r *http.Request) {
+	relType := chi.URLParam(r, "type")
+
+	// Validate relation type
+	validType := false
+	for _, t := range models.AllRelationTypes {
+		if string(t) == relType {
+			validType = true
+			break
+		}
+	}
+	if !validType {
+		http.Error(w, "invalid relation type", http.StatusBadRequest)
+		return
+	}
+
+	limit := DefaultRelationsLimit
+	if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
+		if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 100 {
+			limit = l
+		}
+	}
+
+	relations, err := s.relationStore.GetRelationsByType(r.Context(), models.RelationType(relType), limit)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	if relations == nil {
+		relations = []*models.ObservationRelation{}
+	}
+
+	writeJSON(w, relations)
+}
+
+// handleGetRelationStats returns statistics about relations.
+func (s *Service) handleGetRelationStats(w http.ResponseWriter, r *http.Request) {
+	// Get total relation count
+	totalCount, err := s.relationStore.GetTotalRelationCount(r.Context())
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	// Get high confidence relations count
+	highConfRelations, err := s.relationStore.GetHighConfidenceRelations(r.Context(), 0.7, 1000)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	// Count by relation type
+	typeCounts := make(map[string]int)
+	for _, t := range models.AllRelationTypes {
+		relations, err := s.relationStore.GetRelationsByType(r.Context(), t, 1000)
+		if err == nil {
+			typeCounts[string(t)] = len(relations)
+		}
+	}
+
+	writeJSON(w, map[string]interface{}{
+		"total_count":         totalCount,
+		"high_confidence":     len(highConfRelations),
+		"by_type":             typeCounts,
+		"min_confidence_used": 0.4,
+	})
+}
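Review note: `handleGetRelationsByType` validates its path parameter with a hand-written loop over `models.AllRelationTypes`. The same allow-list check can be sketched with the standard `slices` package (Go 1.21+), as below. `RelationType`, `allRelationTypes`, and the three example values are stand-ins declared locally so the snippet compiles on its own; the real definitions live in `pkg/models` and are not shown in this diff.

```go
package main

import (
	"fmt"
	"slices"
)

// RelationType is a local stand-in for models.RelationType (assumed to be a
// string-backed type, as the handler's string(t) conversion suggests).
type RelationType string

// allRelationTypes is a stand-in for models.AllRelationTypes. The values are
// illustrative only and may not match the real list.
var allRelationTypes = []RelationType{"fixes", "depends_on", "supersedes"}

// parseRelationType converts raw input to a RelationType and reports whether
// it is one of the known types.
func parseRelationType(raw string) (RelationType, bool) {
	t := RelationType(raw)
	return t, slices.Contains(allRelationTypes, t)
}

func main() {
	for _, raw := range []string{"fixes", "unknown"} {
		t, ok := parseRelationType(raw)
		fmt.Printf("%q valid=%v\n", t, ok)
	}
}
```

Returning the typed value together with the flag lets a caller skip the second `models.RelationType(relType)` conversion that the handler performs after validation.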
@@ -0,0 +1,354 @@
+// Package worker provides the main worker service for claude-mnemonic.
+package worker
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"strconv"
+	"time"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+)
+
+// FeedbackRequest represents a user feedback submission.
+type FeedbackRequest struct {
+	Feedback int `json:"feedback"` // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
+}
+
+// handleObservationFeedback handles user feedback (thumbs up/down) for an observation.
+// POST /api/observations/{id}/feedback
+func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Request) {
+	// Parse observation ID
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid observation id", http.StatusBadRequest)
+		return
+	}
+
+	// Parse request body
+	var req FeedbackRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "invalid request body", http.StatusBadRequest)
+		return
+	}
+
+	// Validate feedback value
+	if req.Feedback < -1 || req.Feedback > 1 {
+		http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
+		return
+	}
+
+	// Get required components
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	scoreCalculator := s.scoreCalculator
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Update feedback in database
+	if err := observationStore.UpdateObservationFeedback(r.Context(), id, req.Feedback); err != nil {
+		http.Error(w, "failed to update feedback", http.StatusInternalServerError)
+		return
+	}
+
+	// Recalculate score immediately if calculator is available
+	var newScore float64
+	if scoreCalculator != nil {
+		obs, err := observationStore.GetObservationByID(r.Context(), id)
+		if err == nil && obs != nil {
+			obs.UserFeedback = req.Feedback // Apply the new feedback
+			newScore = scoreCalculator.Calculate(obs, time.Now())
+			if err := observationStore.UpdateImportanceScore(r.Context(), id, newScore); err != nil {
+				// Log but don't fail - feedback was recorded
+				// Score will be updated on next recalculation cycle
+			}
+		}
+	}
+
+	// Broadcast update via SSE
+	s.sseBroadcaster.Broadcast(map[string]interface{}{
+		"type":     "observation_feedback",
+		"id":       id,
+		"feedback": req.Feedback,
+		"score":    newScore,
+	})
+
+	writeJSON(w, map[string]interface{}{
+		"status":   "ok",
+		"id":       id,
+		"feedback": req.Feedback,
+		"score":    newScore,
+	})
+}
+
+// handleGetScoringStats returns scoring statistics and configuration.
+// GET /api/scoring/stats
+func (s *Service) handleGetScoringStats(w http.ResponseWriter, r *http.Request) {
+	project := r.URL.Query().Get("project")
+
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	recalculator := s.recalculator
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Get feedback statistics
+	feedbackStats, err := observationStore.GetObservationFeedbackStats(r.Context(), project)
+	if err != nil {
+		http.Error(w, "failed to get feedback stats", http.StatusInternalServerError)
+		return
+	}
+
+	response := map[string]interface{}{
+		"feedback": feedbackStats,
+	}
+
+	// Add recalculator stats if available
+	if recalculator != nil {
+		response["recalculator"] = recalculator.GetStats()
+	}
+
+	writeJSON(w, response)
+}
+
+// handleGetTopObservations returns the highest-scoring observations.
+// GET /api/observations/top
+func (s *Service) handleGetTopObservations(w http.ResponseWriter, r *http.Request) {
+	limit := parseIntParam(r, "limit", 10)
+	project := r.URL.Query().Get("project")
+
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	observations, err := observationStore.GetTopScoringObservations(r.Context(), project, limit)
+	if err != nil {
+		http.Error(w, "failed to get top observations", http.StatusInternalServerError)
+		return
+	}
+
+	if observations == nil {
+		observations = []*models.Observation{}
+	}
+
+	writeJSON(w, observations)
+}
+
+// handleGetMostRetrieved returns the most frequently retrieved observations.
+// GET /api/observations/most-retrieved
+func (s *Service) handleGetMostRetrieved(w http.ResponseWriter, r *http.Request) {
+	limit := parseIntParam(r, "limit", 10)
+	project := r.URL.Query().Get("project")
+
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	observations, err := observationStore.GetMostRetrievedObservations(r.Context(), project, limit)
+	if err != nil {
+		http.Error(w, "failed to get most retrieved observations", http.StatusInternalServerError)
+		return
+	}
+
+	if observations == nil {
+		observations = []*models.Observation{}
+	}
+
+	writeJSON(w, observations)
+}
+
+// handleExplainScore returns a breakdown of how an observation's score was calculated.
+// GET /api/observations/{id}/score
+func (s *Service) handleExplainScore(w http.ResponseWriter, r *http.Request) {
+	// Parse observation ID
+	idStr := chi.URLParam(r, "id")
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil {
+		http.Error(w, "invalid observation id", http.StatusBadRequest)
+		return
+	}
+
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	scoreCalculator := s.scoreCalculator
+	s.initMu.RUnlock()
+
+	if observationStore == nil || scoreCalculator == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Get observation
+	obs, err := observationStore.GetObservationByID(r.Context(), id)
+	if err != nil {
+		http.Error(w, "failed to get observation", http.StatusInternalServerError)
+		return
+	}
+	if obs == nil {
+		http.Error(w, "observation not found", http.StatusNotFound)
+		return
+	}
+
+	// Calculate score components
+	components := scoreCalculator.CalculateComponents(obs, time.Now())
+
+	writeJSON(w, map[string]interface{}{
+		"id":         id,
+		"components": components,
+		"config":     scoreCalculator.GetConfig(),
+	})
+}
+
+// handleUpdateConceptWeight updates a concept weight.
+// PUT /api/scoring/concepts/{concept}
+func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Request) {
+	concept := chi.URLParam(r, "concept")
+	if concept == "" {
+		http.Error(w, "concept is required", http.StatusBadRequest)
+		return
+	}
+
+	var req struct {
+		Weight float64 `json:"weight"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "invalid request body", http.StatusBadRequest)
+		return
+	}
+
+	// Validate weight
+	if req.Weight < 0 || req.Weight > 1 {
+		http.Error(w, "weight must be between 0 and 1", http.StatusBadRequest)
+		return
+	}
+
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	recalculator := s.recalculator
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Update in database
+	if err := observationStore.UpdateConceptWeight(r.Context(), concept, req.Weight); err != nil {
+		http.Error(w, "failed to update concept weight", http.StatusInternalServerError)
+		return
+	}
+
+	// Refresh concept weights in recalculator
+	if recalculator != nil {
+		if err := recalculator.RefreshConceptWeights(r.Context()); err != nil {
+			// Log but don't fail - weight was saved
+		}
+	}
+
+	writeJSON(w, map[string]interface{}{
+		"status":  "ok",
+		"concept": concept,
+		"weight":  req.Weight,
+	})
+}
+
+// handleGetConceptWeights returns all concept weights.
+// GET /api/scoring/concepts
+func (s *Service) handleGetConceptWeights(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	observationStore := s.observationStore
+	s.initMu.RUnlock()
+
+	if observationStore == nil {
+		http.Error(w, "service not ready", http.StatusServiceUnavailable)
+		return
+	}
+
+	weights, err := observationStore.GetConceptWeights(r.Context())
+	if err != nil {
+		http.Error(w, "failed to get concept weights", http.StatusInternalServerError)
+		return
+	}
+
+	writeJSON(w, weights)
+}
+
+// handleTriggerRecalculation triggers an immediate score recalculation.
+// POST /api/scoring/recalculate
+func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Request) {
+	s.initMu.RLock()
+	recalculator := s.recalculator
+	s.initMu.RUnlock()
+
+	if recalculator == nil {
+		http.Error(w, "recalculator not available", http.StatusServiceUnavailable)
+		return
+	}
+
+	// Run recalculation in background, detached from the request context (which is cancelled once this handler returns)
+	go func(ctx context.Context) {
+		if err := recalculator.RecalculateNow(ctx); err != nil {
+			// Log error but don't block response
+		}
+	}(context.WithoutCancel(r.Context()))
+
+	writeJSON(w, map[string]string{"status": "recalculation triggered"})
+}
+
+// parseIntParam parses an integer query parameter with a default value.
+func parseIntParam(r *http.Request, name string, defaultVal int) int {
+	if val := r.URL.Query().Get(name); val != "" {
+		if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
+			return parsed
+		}
+	}
+	return defaultVal
+}
+
+// incrementRetrievalCounts increments retrieval counts for observations.
+// Called after search results are returned to track popularity.
+func (s *Service) incrementRetrievalCounts(ids []int64) {
+	if len(ids) == 0 {
+		return
+	}
+
+	s.initMu.RLock()
+	store := s.observationStore
+	s.initMu.RUnlock()
+
+	if store == nil {
+		return
+	}
+
+	// Increment in background to not block response
+	go func() {
+		// Create a new context with timeout for the background operation
+		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+
+		if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
+			// Log but don't fail - this is a background operation
+		}
+	}()
+}
+490
-1
@@ -15,6 +15,10 @@ import (
|
|||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/pattern"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||||
|
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||||
@@ -64,6 +68,12 @@ type Service struct {
|
|||||||
observationStore *sqlite.ObservationStore
|
observationStore *sqlite.ObservationStore
|
||||||
summaryStore *sqlite.SummaryStore
|
summaryStore *sqlite.SummaryStore
|
||||||
promptStore *sqlite.PromptStore
|
promptStore *sqlite.PromptStore
|
||||||
|
conflictStore *sqlite.ConflictStore
|
||||||
|
patternStore *sqlite.PatternStore
|
||||||
|
relationStore *sqlite.RelationStore
|
||||||
|
|
||||||
|
// Pattern detection
|
||||||
|
patternDetector *pattern.Detector
|
||||||
|
|
||||||
// Domain services
|
// Domain services
|
||||||
sessionManager *session.Manager
|
sessionManager *session.Manager
|
||||||
@@ -75,6 +85,16 @@ type Service struct {
 	vectorClient *sqlitevec.Client
 	vectorSync   *sqlitevec.Sync
+
+	// Cross-encoder reranking (for improved search relevance)
+	reranker *reranking.Service
+
+	// Query expansion (for improved search recall)
+	queryExpander *expansion.Expander
+
+	// Importance scoring
+	scoreCalculator *scoring.Calculator
+	recalculator    *scoring.Recalculator
 
 	// HTTP server
 	router *chi.Mux
 	server *http.Server
@@ -177,6 +197,15 @@ func (s *Service) initializeAsync() {
 	observationStore := sqlite.NewObservationStore(store)
 	summaryStore := sqlite.NewSummaryStore(store)
 	promptStore := sqlite.NewPromptStore(store)
+	conflictStore := sqlite.NewConflictStore(store)
+	patternStore := sqlite.NewPatternStore(store)
+	relationStore := sqlite.NewRelationStore(store)
+
+	// Enable conflict detection by linking stores
+	observationStore.SetConflictStore(conflictStore)
+
+	// Enable relation detection by linking stores
+	observationStore.SetRelationStore(relationStore)
 
 	// Create session manager
 	sessionManager := session.NewManager(sessionStore)
@@ -186,6 +215,8 @@ func (s *Service) initializeAsync() {
 	var vectorClient *sqlitevec.Client
 	var vectorSync *sqlitevec.Sync
+
+	var reranker *reranking.Service
 
 	emb, err := embedding.NewService()
 	if err != nil {
 		log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
@@ -200,8 +231,32 @@ func (s *Service) initializeAsync() {
 		} else {
 			vectorClient = client
 			vectorSync = sqlitevec.NewSync(client)
-			log.Info().Msg("sqlite-vec vector search enabled")
+			log.Info().
+				Str("model", embedSvc.Version()).
+				Msg("sqlite-vec vector search enabled")
 		}
+
+		// Create cross-encoder reranking service if enabled
+		if s.config.RerankingEnabled {
+			rerankCfg := reranking.DefaultConfig()
+			if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
+				rerankCfg.Alpha = s.config.RerankingAlpha
+			}
+
+			ranker, err := reranking.NewService(rerankCfg)
+			if err != nil {
+				log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed - reranking disabled")
+			} else {
+				reranker = ranker
+				log.Info().
+					Float64("alpha", rerankCfg.Alpha).
+					Msg("Cross-encoder reranking enabled")
+			}
+		}
+
+		// Create query expander for improved search recall
+		s.queryExpander = expansion.NewExpander(embedSvc)
+		log.Info().Msg("Query expansion enabled")
 	}
 
 	// Create SDK processor (optional - will be nil if Claude CLI not available)
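The hunk above only overrides the reranking alpha when the configured value lies in (0, 1]. The sketch below illustrates that guard in isolation. The default of 0.7 and the linear blending formula are assumptions made for illustration; the diff does not show what `reranking.DefaultConfig()` returns or how the reranker combines scores.

```go
package main

import "fmt"

// defaultAlpha is a hypothetical default; the real value comes from
// reranking.DefaultConfig(), which is not shown in this diff.
const defaultAlpha = 0.7

// effectiveAlpha mirrors the guard in the diff: the configured value is
// used only when it lies in (0, 1]; otherwise the default is kept.
func effectiveAlpha(configured float64) float64 {
	if configured > 0 && configured <= 1 {
		return configured
	}
	return defaultAlpha
}

// blendScores is an assumed linear interpolation between a cross-encoder
// score and a bi-encoder score. It is not taken from the repository.
func blendScores(alpha, crossScore, biScore float64) float64 {
	return alpha*crossScore + (1-alpha)*biScore
}

func main() {
	alpha := effectiveAlpha(0.5)
	fmt.Printf("alpha=%.2f blended=%.2f\n", alpha, blendScores(alpha, 1.0, 0.0))
	fmt.Printf("out-of-range value falls back to %.2f\n", effectiveAlpha(1.5))
}
```

One consequence of the guard is that a configured alpha of exactly 0 cannot be expressed: it is treated as "unset" and replaced by the default.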
@@ -225,11 +280,38 @@ func (s *Service) initializeAsync() {
 	s.observationStore = observationStore
 	s.summaryStore = summaryStore
 	s.promptStore = promptStore
+	s.conflictStore = conflictStore
+	s.patternStore = patternStore
+	s.relationStore = relationStore
 	s.sessionManager = sessionManager
 	s.processor = processor
 	s.embedSvc = embedSvc
 	s.vectorClient = vectorClient
 	s.vectorSync = vectorSync
+	s.reranker = reranker
+	s.initMu.Unlock()
+
+	// Initialize pattern detector
+	patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
+
+	// Set pattern sync callback if vector sync is available
+	if vectorSync != nil {
+		patternDetector.SetSyncFunc(func(p *models.Pattern) {
+			if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
+				log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
+			}
+		})
+
+		// Set cleanup callback for pattern deletions
+		patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
+			if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
+				log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
+			}
+		})
+	}
+
+	s.initMu.Lock()
+	s.patternDetector = patternDetector
 	s.initMu.Unlock()
 
 	// Set vector sync callbacks on processor if both are available
@@ -238,6 +320,22 @@ func (s *Service) initializeAsync() {
 			if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
 				log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
 			}
+			// Trigger pattern detection for the new observation
+			if patternDetector != nil {
+				go func(observation *models.Observation) {
+					detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+					defer cancel()
+					if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
+						log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
+					} else if result.MatchedPattern != nil {
+						log.Debug().
+							Int64("pattern_id", result.MatchedPattern.ID).
+							Str("pattern_name", result.MatchedPattern.Name).
+							Bool("is_new", result.IsNewPattern).
+							Msg("Pattern matched for observation")
+					}
+				}(obs)
+			}
 		})
 		processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
 			if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
@@ -282,6 +380,37 @@ func (s *Service) initializeAsync() {
 		})
 	})
+
+	// Initialize importance scoring system
+	scoringConfig := models.DefaultScoringConfig()
+
+	// Load concept weights from database if available
+	if weights, err := observationStore.GetConceptWeights(s.ctx); err == nil && len(weights) > 0 {
+		scoringConfig.ConceptWeights = weights
+		log.Info().Int("count", len(weights)).Msg("Loaded concept weights from database")
+	}
+
+	scoreCalculator := scoring.NewCalculator(scoringConfig)
+	recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger)
+
+	s.initMu.Lock()
+	s.scoreCalculator = scoreCalculator
+	s.recalculator = recalculator
+	s.initMu.Unlock()
+
+	// Start background recalculator
+	s.wg.Add(1)
+	go func() {
+		defer s.wg.Done()
+		recalculator.Start(s.ctx)
+	}()
+	log.Info().Msg("Importance scoring system initialized")
+
+	// Start pattern detector background analysis
+	if patternDetector != nil {
+		patternDetector.Start()
+		log.Info().Msg("Pattern recognition engine started")
+	}
 
 	// Mark as ready
 	s.ready.Store(true)
 	log.Info().Msg("Async initialization complete - service ready")
@@ -294,6 +423,27 @@ func (s *Service) initializeAsync() {
 
 	// Start file watchers for auto-recreation on deletion
 	s.startWatchers()
+
+	// Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild
+	if vectorClient != nil && vectorSync != nil {
+		needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx)
+		if needsRebuild {
+			log.Info().
+				Str("reason", reason).
+				Str("model", vectorClient.ModelVersion()).
+				Msg("Vector rebuild required")
+
+			if reason == "empty" {
+				// Full rebuild - vectors table is empty
+				s.wg.Add(1)
+				go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync)
+			} else {
+				// Granular rebuild - only rebuild vectors with mismatched model versions
+				s.wg.Add(1)
+				go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync)
+			}
+		}
+	}
 }
 
 // startWatchers initializes and starts file watchers for database and config.
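The rebuild check above reduces to a small decision: no rebuild, a full rebuild when the reason is `"empty"`, and a granular rebuild for any other reason. The sketch below isolates that branch. `rebuildMode` and `chooseRebuild` are hypothetical names, and `"model_mismatch"` is a made-up reason string, since the diff only shows the `"empty"` value.

```go
package main

import "fmt"

// rebuildMode names the strategies seen in the diff (hypothetical type).
type rebuildMode string

const (
	rebuildNone     rebuildMode = "none"
	rebuildFull     rebuildMode = "full"
	rebuildGranular rebuildMode = "granular"
)

// chooseRebuild mirrors the branch in initializeAsync: "empty" triggers a
// full rebuild, any other reason triggers a granular one.
func chooseRebuild(needsRebuild bool, reason string) rebuildMode {
	if !needsRebuild {
		return rebuildNone
	}
	if reason == "empty" {
		return rebuildFull
	}
	return rebuildGranular
}

func main() {
	// "model_mismatch" is an illustrative reason, not taken from the source.
	for _, reason := range []string{"empty", "model_mismatch"} {
		fmt.Printf("%s -> %s\n", reason, chooseRebuild(true, reason))
	}
}
```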
@@ -384,6 +534,15 @@ func (s *Service) reinitializeDatabase() {
 	observationStore := sqlite.NewObservationStore(store)
 	summaryStore := sqlite.NewSummaryStore(store)
 	promptStore := sqlite.NewPromptStore(store)
+	conflictStore := sqlite.NewConflictStore(store)
+	patternStore := sqlite.NewPatternStore(store)
+	relationStore := sqlite.NewRelationStore(store)
+
+	// Enable conflict detection by linking stores
+	observationStore.SetConflictStore(conflictStore)
+
+	// Enable relation detection by linking stores
+	observationStore.SetRelationStore(relationStore)
 
 	// Create new session manager
 	sessionManager := session.NewManager(sessionStore)
@@ -393,6 +552,8 @@ func (s *Service) reinitializeDatabase() {
 	var vectorClient *sqlitevec.Client
 	var vectorSync *sqlitevec.Sync
+
+	var reranker *reranking.Service
 
 	emb, err := embedding.NewService()
 	if err != nil {
 		log.Warn().Err(err).Msg("Embedding service creation failed after reinit")
@@ -408,6 +569,34 @@ func (s *Service) reinitializeDatabase() {
 			vectorSync = sqlitevec.NewSync(client)
 			log.Info().Msg("sqlite-vec reconnected after reinit")
 		}
+
+		// Recreate cross-encoder reranking service if enabled
+		if s.config.RerankingEnabled {
+			rerankCfg := reranking.DefaultConfig()
+			if s.config.RerankingAlpha > 0 && s.config.RerankingAlpha <= 1 {
+				rerankCfg.Alpha = s.config.RerankingAlpha
+			}
+
+			ranker, err := reranking.NewService(rerankCfg)
+			if err != nil {
+				log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed after reinit")
+			} else {
+				reranker = ranker
+				log.Info().Msg("Cross-encoder reranking reconnected after reinit")
+			}
+		}
+
+		// Recreate query expander
+		s.queryExpander = expansion.NewExpander(embedSvc)
+		log.Info().Msg("Query expansion reconnected after reinit")
+	}
+
+	// Close old reranker if exists
+	s.initMu.RLock()
+	oldReranker := s.reranker
+	s.initMu.RUnlock()
+	if oldReranker != nil {
+		_ = oldReranker.Close()
 	}
 
 	// Recreate SDK processor with new stores
@@ -422,6 +611,30 @@ func (s *Service) reinitializeDatabase() {
 		})
 	}
+
+	// Stop old pattern detector if it exists
+	if s.patternDetector != nil {
+		s.patternDetector.Stop()
+	}
+
+	// Create new pattern detector
+	patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig())
+
+	// Set pattern sync callback if vector sync is available
+	if vectorSync != nil {
+		patternDetector.SetSyncFunc(func(p *models.Pattern) {
+			if err := vectorSync.SyncPattern(s.ctx, p); err != nil {
+				log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec")
+			}
+		})
+
+		// Set cleanup callback for pattern deletions
+		patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
+			if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil {
+				log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec")
+			}
+		})
+	}
 
 	// Atomically swap all components
 	s.initMu.Lock()
 	s.store = store
@@ -429,20 +642,44 @@ func (s *Service) reinitializeDatabase() {
 	s.observationStore = observationStore
 	s.summaryStore = summaryStore
 	s.promptStore = promptStore
+	s.conflictStore = conflictStore
+	s.patternStore = patternStore
+	s.relationStore = relationStore
+	s.patternDetector = patternDetector
 	s.sessionManager = sessionManager
 	s.processor = processor
 	s.embedSvc = embedSvc
 	s.vectorClient = vectorClient
 	s.vectorSync = vectorSync
+	s.reranker = reranker
 	s.initError = nil
 	s.initMu.Unlock()
 
+	// Start pattern detector
+	patternDetector.Start()
+
 	// Set vector sync callbacks on processor if both are available
 	if processor != nil && vectorSync != nil {
 		processor.SetSyncObservationFunc(func(obs *models.Observation) {
 			if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
 				log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
 			}
+			// Trigger pattern detection for the new observation
+			if patternDetector != nil {
+				go func(observation *models.Observation) {
+					detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+					defer cancel()
+					if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil {
+						log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed")
+					} else if result.MatchedPattern != nil {
+						log.Debug().
+							Int64("pattern_id", result.MatchedPattern.ID).
+							Str("pattern_name", result.MatchedPattern.Name).
+							Bool("is_new", result.IsNewPattern).
+							Msg("Pattern matched for observation")
+					}
+				}(obs)
+			}
 		})
 		processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
 			if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
@@ -565,6 +802,210 @@ func (s *Service) processStaleQueue() {
 	}
 }
+
+// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts.
+// Called when the vectors table is empty (e.g., after migration 20 drops all vectors).
+func (s *Service) rebuildAllVectors(
+	observationStore *sqlite.ObservationStore,
+	summaryStore *sqlite.SummaryStore,
+	promptStore *sqlite.PromptStore,
+	vectorSync *sqlitevec.Sync,
+) {
+	defer s.wg.Done()
+
+	log.Info().Msg("Starting full vector rebuild...")
+	start := time.Now()
+
+	var totalSynced int
+	var syncErrors int
+
+	// Rebuild observations
+	observations, err := observationStore.GetAllObservations(s.ctx)
+	if err != nil {
+		log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild")
+	} else {
+		for _, obs := range observations {
+			if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
+				log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
+				syncErrors++
+			} else {
+				totalSynced++
+			}
+		}
+		log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors")
+	}
+
+	// Rebuild summaries
+	summaries, err := summaryStore.GetAllSummaries(s.ctx)
+	if err != nil {
+		log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild")
+	} else {
+		for _, summary := range summaries {
+			if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
+				log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
+				syncErrors++
+			} else {
+				totalSynced++
+			}
+		}
+		log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors")
+	}
+
+	// Rebuild user prompts
+	prompts, err := promptStore.GetAllPrompts(s.ctx)
+	if err != nil {
+		log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild")
+	} else {
+		for _, prompt := range prompts {
+			if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
+				log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
+				syncErrors++
+			} else {
+				totalSynced++
+			}
+		}
+		log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors")
+	}
+
+	elapsed := time.Since(start)
+	log.Info().
+		Int("total_synced", totalSynced).
+		Int("errors", syncErrors).
+		Dur("elapsed", elapsed).
+		Msg("Full vector rebuild complete")
+}
+
+// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions.
+// This is more efficient than rebuilding all vectors when only some need updating.
+func (s *Service) rebuildStaleVectors(
+	observationStore *sqlite.ObservationStore,
+	summaryStore *sqlite.SummaryStore,
+	promptStore *sqlite.PromptStore,
+	vectorClient *sqlitevec.Client,
+	vectorSync *sqlitevec.Sync,
+) {
+	defer s.wg.Done()
+
+	log.Info().Msg("Starting granular vector rebuild for stale vectors...")
+	start := time.Now()
+
+	// Get all stale vectors
+	staleVectors, err := vectorClient.GetStaleVectors(s.ctx)
+	if err != nil {
+		log.Error().Err(err).Msg("Failed to get stale vectors")
+		return
+	}
+
+	if len(staleVectors) == 0 {
+		log.Info().Msg("No stale vectors found")
+		return
+	}
+
+	log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild")
+
+	// Group stale vectors by doc_type and sqlite_id for efficient lookup
+	staleObsIDs := make(map[int64]bool)
+	staleSummaryIDs := make(map[int64]bool)
+	stalePromptIDs := make(map[int64]bool)
+	staleDocIDs := make([]string, 0, len(staleVectors))
+
+	for _, sv := range staleVectors {
+		staleDocIDs = append(staleDocIDs, sv.DocID)
+		switch sv.DocType {
+		case "observation":
+			staleObsIDs[sv.SQLiteID] = true
+		case "summary":
+			staleSummaryIDs[sv.SQLiteID] = true
+		case "prompt":
+			stalePromptIDs[sv.SQLiteID] = true
+		}
+	}
+
+	// Delete stale vectors before re-syncing
+	if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil {
+		log.Error().Err(err).Msg("Failed to delete stale vectors")
+		return
+	}
+
+	var totalSynced int
+	var syncErrors int
+
+	// Rebuild stale observations
+	if len(staleObsIDs) > 0 {
+		ids := make([]int64, 0, len(staleObsIDs))
+		for id := range staleObsIDs {
+			ids = append(ids, id)
+		}
+
+		observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0)
+		if err != nil {
+			log.Error().Err(err).Msg("Failed to fetch observations for rebuild")
+		} else {
+			for _, obs := range observations {
+				if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
+					log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
+					syncErrors++
+				} else {
+					totalSynced++
+				}
+			}
+			log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors")
+		}
+	}
+
+	// Rebuild stale summaries
+	if len(staleSummaryIDs) > 0 {
+		ids := make([]int64, 0, len(staleSummaryIDs))
+		for id := range staleSummaryIDs {
+			ids = append(ids, id)
+		}
+
+		summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0)
+		if err != nil {
+			log.Error().Err(err).Msg("Failed to fetch summaries for rebuild")
+		} else {
+			for _, summary := range summaries {
+				if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
+					log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
+					syncErrors++
+				} else {
+					totalSynced++
+				}
+			}
+			log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors")
+		}
+	}
+
+	// Rebuild stale prompts
+	if len(stalePromptIDs) > 0 {
+		ids := make([]int64, 0, len(stalePromptIDs))
+		for id := range stalePromptIDs {
+			ids = append(ids, id)
+		}
+
+		prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0)
+		if err != nil {
+			log.Error().Err(err).Msg("Failed to fetch prompts for rebuild")
+		} else {
+			for _, prompt := range prompts {
+				if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
+					log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
+					syncErrors++
+				} else {
+					totalSynced++
+				}
+			}
+			log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors")
+		}
+	}
+
+	elapsed := time.Since(start)
+	log.Info().
+		Int("total_synced", totalSynced).
+		Int("errors", syncErrors).
+		Dur("elapsed", elapsed).
+		Msg("Granular vector rebuild complete")
+}
 
 // verifyStaleObservation verifies a single stale observation in the background.
 func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
 	// Wait for service to be ready
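The granular rebuild above first collects every stale doc ID for deletion, then de-duplicates row IDs per document type before re-fetching the source rows. The sketch below shows that grouping step on its own. The `staleVector` struct only mirrors the three fields the diff reads (`DocID`, `DocType`, `SQLiteID`), `groupStale` is a hypothetical helper, and the sample doc IDs are invented. Unlike the original `switch`, this sketch groups any doc type rather than only the three known ones.

```go
package main

import (
	"fmt"
	"sort"
)

// staleVector mirrors the fields read in rebuildStaleVectors; the real type
// lives in the sqlitevec package and is not shown in the diff.
type staleVector struct {
	DocID    string
	DocType  string
	SQLiteID int64
}

// groupStale returns every doc ID (for deletion) plus the unique, sorted
// row IDs per doc type (for re-fetching and re-syncing).
func groupStale(stale []staleVector) (docIDs []string, byType map[string][]int64) {
	seen := make(map[string]map[int64]bool)
	byType = make(map[string][]int64)
	for _, sv := range stale {
		docIDs = append(docIDs, sv.DocID)
		if seen[sv.DocType] == nil {
			seen[sv.DocType] = make(map[int64]bool)
		}
		if !seen[sv.DocType][sv.SQLiteID] {
			seen[sv.DocType][sv.SQLiteID] = true
			byType[sv.DocType] = append(byType[sv.DocType], sv.SQLiteID)
		}
	}
	for t := range byType {
		ids := byType[t]
		sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })
	}
	return docIDs, byType
}

func main() {
	// Two vectors belonging to the same observation row, plus one summary.
	stale := []staleVector{
		{DocID: "obs_1_title", DocType: "observation", SQLiteID: 1},
		{DocID: "obs_1_narrative", DocType: "observation", SQLiteID: 1},
		{DocID: "summary_7", DocType: "summary", SQLiteID: 7},
	}
	docIDs, byType := groupStale(stale)
	fmt.Println(len(docIDs), byType["observation"], byType["summary"])
}
```

De-duplicating by row ID matters when one source row produces several vectors: each row is re-synced once even if all of its vectors were stale.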
@@ -667,11 +1108,42 @@ func (s *Service) setupRoutes() {
 		r.Get("/api/stats", s.handleGetStats)
 		r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
 		r.Get("/api/types", s.handleGetTypes)
+		r.Get("/api/models", s.handleGetModels)
+
+		// Observation scoring and feedback routes
+		r.Post("/api/observations/{id}/feedback", s.handleObservationFeedback)
+		r.Get("/api/observations/{id}/score", s.handleExplainScore)
+		r.Get("/api/observations/top", s.handleGetTopObservations)
+		r.Get("/api/observations/most-retrieved", s.handleGetMostRetrieved)
+
+		// Scoring configuration routes
+		r.Get("/api/scoring/stats", s.handleGetScoringStats)
+		r.Get("/api/scoring/concepts", s.handleGetConceptWeights)
+		r.Put("/api/scoring/concepts/{concept}", s.handleUpdateConceptWeight)
+		r.Post("/api/scoring/recalculate", s.handleTriggerRecalculation)
 
 		// Context injection
 		r.Get("/api/context/count", s.handleContextCount)
 		r.Get("/api/context/inject", s.handleContextInject)
 		r.Get("/api/context/search", s.handleSearchByPrompt)
+
+		// Pattern routes
+		r.Get("/api/patterns", s.handleGetPatterns)
+		r.Get("/api/patterns/stats", s.handleGetPatternStats)
+		r.Get("/api/patterns/search", s.handleSearchPatterns)
+		r.Get("/api/patterns/by-name", s.handleGetPatternByName)
+		r.Get("/api/patterns/{id}", s.handleGetPatternByID)
+		r.Get("/api/patterns/{id}/insight", s.handleGetPatternInsight)
+		r.Delete("/api/patterns/{id}", s.handleDeletePattern)
+		r.Post("/api/patterns/{id}/deprecate", s.handleDeprecatePattern)
+		r.Post("/api/patterns/merge", s.handleMergePatterns)
+
+		// Relation routes (knowledge graph)
+		r.Get("/api/relations/stats", s.handleGetRelationStats)
+		r.Get("/api/relations/type/{type}", s.handleGetRelationsByType)
+		r.Get("/api/observations/{id}/relations", s.handleGetRelations)
+		r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph)
+		r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations)
 	})
 }
 
@@ -894,6 +1366,16 @@ func (s *Service) Shutdown(ctx context.Context) error {
 		_ = s.configWatcher.Stop()
 	}
+
+	// Stop background recalculator
+	if s.recalculator != nil {
+		s.recalculator.Stop()
+	}
+
+	// Stop pattern detector
+	if s.patternDetector != nil {
+		s.patternDetector.Stop()
+	}
 
 	// Shutdown all sessions
 	s.sessionManager.ShutdownAll(ctx)
 
@@ -904,6 +1386,13 @@ func (s *Service) Shutdown(ctx context.Context) error {
 		}
 	}
+
+	// Close reranking service
+	if s.reranker != nil {
+		if err := s.reranker.Close(); err != nil {
+			log.Error().Err(err).Msg("Reranking service close error")
+		}
+	}
 
 	// Close embedding service
 	if s.embedSvc != nil {
 		if err := s.embedSvc.Close(); err != nil {
@@ -0,0 +1,258 @@
+// Package models contains domain models for claude-mnemonic.
+package models
+
+import (
+	"regexp"
+	"strings"
+	"time"
+)
+
+// ConflictType represents the type of conflict between observations.
+type ConflictType string
+
+const (
+	// ConflictSuperseded means newer observation supersedes older one (same topic, updated info).
+	ConflictSuperseded ConflictType = "superseded"
+	// ConflictContradicts means observations contain contradictory information.
+	ConflictContradicts ConflictType = "contradicts"
+	// ConflictOutdatedPattern means an outdated pattern/practice was identified.
+	ConflictOutdatedPattern ConflictType = "outdated_pattern"
+)
+
+// ConflictResolution indicates which observation to prefer.
+type ConflictResolution string
+
+const (
+	// ResolutionPreferNewer means prefer the newer observation.
+	ResolutionPreferNewer ConflictResolution = "prefer_newer"
+	// ResolutionPreferOlder means prefer the older observation (rare).
+	ResolutionPreferOlder ConflictResolution = "prefer_older"
+	// ResolutionManual means manual review is needed.
+	ResolutionManual ConflictResolution = "manual"
+)
+
+// ObservationConflict tracks conflicting observations.
+type ObservationConflict struct {
+	ID              int64              `db:"id" json:"id"`
+	NewerObsID      int64              `db:"newer_obs_id" json:"newer_obs_id"`
+	OlderObsID      int64              `db:"older_obs_id" json:"older_obs_id"`
+	ConflictType    ConflictType       `db:"conflict_type" json:"conflict_type"`
+	Resolution      ConflictResolution `db:"resolution" json:"resolution"`
+	Reason          string             `db:"reason" json:"reason"`
+	DetectedAt      string             `db:"detected_at" json:"detected_at"`
+	DetectedAtEpoch int64              `db:"detected_at_epoch" json:"detected_at_epoch"`
+	Resolved        bool               `db:"resolved" json:"resolved"`
+	ResolvedAt      *string            `db:"resolved_at" json:"resolved_at,omitempty"`
+}
+
+// ConflictDetectionResult contains the result of conflict detection.
+type ConflictDetectionResult struct {
+	HasConflict bool
+	Type        ConflictType
+	Resolution  ConflictResolution
+	Reason      string
+	OlderObsIDs []int64 // IDs of observations that conflict with the new one
+}
+
+// NewObservationConflict creates a new conflict record.
+func NewObservationConflict(newerID, olderID int64, conflictType ConflictType, resolution ConflictResolution, reason string) *ObservationConflict {
+	now := time.Now()
+	return &ObservationConflict{
+		NewerObsID:      newerID,
+		OlderObsID:      olderID,
+		ConflictType:    conflictType,
+		Resolution:      resolution,
+		Reason:          reason,
+		DetectedAt:      now.Format(time.RFC3339),
+		DetectedAtEpoch: now.UnixMilli(),
+		Resolved:        false,
+	}
+}
+
+// CorrectionPatterns contains regex patterns that indicate explicit corrections.
+var CorrectionPatterns = []*regexp.Regexp{
+	regexp.MustCompile(`(?i)\bactually[,\s]+that\s+was\s+wrong\b`),
+	regexp.MustCompile(`(?i)\bactually[,\s]+that's\s+(wrong|incorrect|not\s+right)\b`),
+	regexp.MustCompile(`(?i)\bpreviously\s+(said|mentioned|noted)\s+.*\s+but\b`),
+	regexp.MustCompile(`(?i)\bcorrection:\s*`),
+	regexp.MustCompile(`(?i)\bignore\s+(the\s+)?(previous|earlier)\b`),
+	regexp.MustCompile(`(?i)\bdisregard\s+(the\s+)?(previous|earlier)\b`),
+	regexp.MustCompile(`(?i)\bwas\s+(wrong|incorrect|mistaken)\b`),
+	regexp.MustCompile(`(?i)\bturns\s+out\s+.*(wrong|incorrect|not\s+the\s+case)\b`),
+	regexp.MustCompile(`(?i)\b(supersedes|replaces|overrides)\s+(the\s+)?(previous|earlier|old)\b`),
+	regexp.MustCompile(`(?i)\b(don't|do\s+not)\s+use\s+.*\s+anymore\b`),
+	regexp.MustCompile(`(?i)\bno\s+longer\s+(valid|applicable|correct|recommended)\b`),
+	regexp.MustCompile(`(?i)\bdeprecated\s+(approach|method|pattern|way)\b`),
+	regexp.MustCompile(`(?i)\bshould\s+have\s+(been|used)\b.*instead\b`),
+	regexp.MustCompile(`(?i)\bbetter\s+(approach|way|method|solution)\s+is\b`),
+}
|
|
||||||
|
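// OpposingChangePatterns maps an action word to its opposite (add/remove, enable/disable, ...).
// It is used to detect observations that describe opposing changes.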
|
||||||
|
var OpposingChangePatterns = map[string]string{
|
||||||
|
"add": "remove",
|
||||||
|
"added": "removed",
|
||||||
|
"create": "delete",
|
||||||
|
"created": "deleted",
|
||||||
|
"enable": "disable",
|
||||||
|
"enabled": "disabled",
|
||||||
|
"include": "exclude",
|
||||||
|
"allow": "deny",
|
||||||
|
"permit": "block",
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectExplicitCorrection checks if text contains explicit correction language.
|
||||||
|
func DetectExplicitCorrection(text string) (bool, string) {
|
||||||
|
for _, pattern := range CorrectionPatterns {
|
||||||
|
if match := pattern.FindString(text); match != "" {
|
||||||
|
return true, "Explicit correction detected: " + match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, ""
|
||||||
|
}
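A minimal sketch of the first-match-wins loop above, using a reduced subset of three patterns copied from `CorrectionPatterns`. The lower-case names are local stand-ins chosen for this example. It illustrates that the returned reason carries the matched text in its original casing, even though matching is case-insensitive.

```go
package main

import (
	"fmt"
	"regexp"
)

// correctionPatterns is an illustrative subset of the patterns in the diff.
var correctionPatterns = []*regexp.Regexp{
	regexp.MustCompile(`(?i)\bactually[,\s]+that\s+was\s+wrong\b`),
	regexp.MustCompile(`(?i)\bcorrection:\s*`),
	regexp.MustCompile(`(?i)\bno\s+longer\s+(valid|applicable|correct|recommended)\b`),
}

// detectExplicitCorrection returns on the first pattern that matches.
func detectExplicitCorrection(text string) (bool, string) {
	for _, p := range correctionPatterns {
		if m := p.FindString(text); m != "" {
			return true, "Explicit correction detected: " + m
		}
	}
	return false, ""
}

func main() {
	samples := []string{
		"Actually, that was wrong - use the new client",
		"This method is no longer valid after the refactor",
		"This is a normal observation about the code",
	}
	for _, s := range samples {
		ok, reason := detectExplicitCorrection(s)
		fmt.Printf("%v %q\n", ok, reason)
	}
}
```

Because the loop returns on the first hit, the order of the slice decides which phrase is reported when a text matches several patterns.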
|
||||||
|
|
||||||
|
// DetectOpposingFileChanges checks if two observations have opposing changes on the same file.
|
||||||
|
func DetectOpposingFileChanges(newer, older *Observation) (bool, string) {
|
||||||
|
// Check for overlapping modified files
|
||||||
|
newerFiles := make(map[string]bool)
|
||||||
|
for _, f := range newer.FilesModified {
|
||||||
|
newerFiles[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var overlappingFiles []string
|
||||||
|
for _, f := range older.FilesModified {
|
||||||
|
if newerFiles[f] {
|
||||||
|
overlappingFiles = append(overlappingFiles, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(overlappingFiles) == 0 {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for opposing action words in titles/narratives
|
||||||
|
newerText := strings.ToLower(newer.Title.String + " " + newer.Narrative.String)
|
||||||
|
olderText := strings.ToLower(older.Title.String + " " + older.Narrative.String)
|
||||||
|
|
||||||
|
for action, opposite := range OpposingChangePatterns {
|
||||||
|
if (strings.Contains(newerText, action) && strings.Contains(olderText, opposite)) ||
|
||||||
|
(strings.Contains(newerText, opposite) && strings.Contains(olderText, action)) {
|
||||||
|
return true, "Opposing changes on files: " + strings.Join(overlappingFiles, ", ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectConceptTagMismatch reports a likely update when two observations share at least one
// concept and at least one modified file. It does not compare the recommendations themselves.
|
||||||
|
func DetectConceptTagMismatch(newer, older *Observation) (bool, string) {
|
||||||
|
// Find overlapping concepts
|
||||||
|
newerConcepts := make(map[string]bool)
|
||||||
|
for _, c := range newer.Concepts {
|
||||||
|
newerConcepts[c] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var overlapping []string
|
||||||
|
for _, c := range older.Concepts {
|
||||||
|
if newerConcepts[c] {
|
||||||
|
overlapping = append(overlapping, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(overlapping) == 0 {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if same file was modified and concepts overlap
|
||||||
|
// This suggests the newer observation may update the approach
|
||||||
|
newerFiles := make(map[string]bool)
|
||||||
|
for _, f := range newer.FilesModified {
|
||||||
|
newerFiles[f] = true
|
||||||
|
}
|
||||||
|
for _, f := range older.FilesModified {
|
||||||
|
if newerFiles[f] {
|
||||||
|
// Same file modified with same concepts - likely an update
|
||||||
|
return true, "Same concepts (" + strings.Join(overlapping, ", ") + ") with overlapping file changes"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
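// DetectConflict compares a new observation against a single existing one and returns the
// first conflict found. Checks run in order: explicit correction language in the newer
// observation's narrative or title, opposing changes on shared files, then overlapping
// concepts on shared files.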
|
||||||
|
func DetectConflict(newer, older *Observation) *ConflictDetectionResult {
|
||||||
|
result := &ConflictDetectionResult{
|
||||||
|
HasConflict: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Check for explicit correction language in newer observation
|
||||||
|
if newer.Narrative.Valid {
|
||||||
|
if isCorrection, reason := DetectExplicitCorrection(newer.Narrative.String); isCorrection {
|
||||||
|
result.HasConflict = true
|
||||||
|
result.Type = ConflictContradicts
|
||||||
|
result.Resolution = ResolutionPreferNewer
|
||||||
|
result.Reason = reason
|
||||||
|
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check title as well
|
||||||
|
if newer.Title.Valid {
|
||||||
|
if isCorrection, reason := DetectExplicitCorrection(newer.Title.String); isCorrection {
|
||||||
|
result.HasConflict = true
|
||||||
|
result.Type = ConflictContradicts
|
||||||
|
result.Resolution = ResolutionPreferNewer
|
||||||
|
result.Reason = reason
|
||||||
|
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check for opposing file changes
|
||||||
|
if isOpposing, reason := DetectOpposingFileChanges(newer, older); isOpposing {
|
||||||
|
result.HasConflict = true
|
||||||
|
result.Type = ConflictSuperseded
|
||||||
|
result.Resolution = ResolutionPreferNewer
|
||||||
|
result.Reason = reason
|
||||||
|
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check for concept tag mismatches with same files
|
||||||
|
if isMismatch, reason := DetectConceptTagMismatch(newer, older); isMismatch {
|
||||||
|
result.HasConflict = true
|
||||||
|
result.Type = ConflictSuperseded
|
||||||
|
result.Resolution = ResolutionPreferNewer
|
||||||
|
result.Reason = reason
|
||||||
|
result.OlderObsIDs = append(result.OlderObsIDs, older.ID)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectConflictsWithExisting checks a new observation against a list of existing observations.
|
||||||
|
// Returns all detected conflicts.
|
||||||
|
func DetectConflictsWithExisting(newer *Observation, existing []*Observation) []*ConflictDetectionResult {
|
||||||
|
var results []*ConflictDetectionResult
|
||||||
|
|
||||||
|
for _, older := range existing {
|
||||||
|
// Skip self-comparison
|
||||||
|
if older.ID == newer.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only compare within the same project, unless at least one observation is global
|
||||||
|
if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := DetectConflict(newer, older)
|
||||||
|
if result.HasConflict {
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
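The project/scope filter in the loop above can be read as a small predicate: a pair is skipped only when the projects differ and neither observation is global. The sketch below isolates that condition. The `scope` type and its string values are placeholders, since the real `Scope` type and the value behind `ScopeGlobal` are not shown in this diff, and `shouldCompare` is a hypothetical name.

```go
package main

import "fmt"

// scope and its two values are placeholders for the package's real type.
type scope string

const (
	scopeProject scope = "project"
	scopeGlobal  scope = "global"
)

// shouldCompare is the negation of the skip condition used in the loop.
func shouldCompare(newProject string, newScope scope, oldProject string, oldScope scope) bool {
	skip := newProject != oldProject && newScope != scopeGlobal && oldScope != scopeGlobal
	return !skip
}

func main() {
	fmt.Println(shouldCompare("a", scopeProject, "a", scopeProject)) // true: same project
	fmt.Println(shouldCompare("a", scopeProject, "b", scopeProject)) // false: different projects, neither global
	fmt.Println(shouldCompare("a", scopeGlobal, "b", scopeProject))  // true: one side is global
}
```

Note that a single global observation on either side is enough for the comparison to proceed.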
|
||||||
@@ -0,0 +1,421 @@
|
|||||||
|
// Package models contains domain models for claude-mnemonic.
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConflictSuite is a test suite for conflict detection operations.
|
||||||
|
type ConflictSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConflictSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ConflictSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConflictTypeConstants tests conflict type constants.
|
||||||
|
func (s *ConflictSuite) TestConflictTypeConstants() {
|
||||||
|
s.Equal(ConflictType("superseded"), ConflictSuperseded)
|
||||||
|
s.Equal(ConflictType("contradicts"), ConflictContradicts)
|
||||||
|
s.Equal(ConflictType("outdated_pattern"), ConflictOutdatedPattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResolutionConstants tests resolution constants.
|
||||||
|
func (s *ConflictSuite) TestResolutionConstants() {
|
||||||
|
s.Equal(ConflictResolution("prefer_newer"), ResolutionPreferNewer)
|
||||||
|
s.Equal(ConflictResolution("prefer_older"), ResolutionPreferOlder)
|
||||||
|
s.Equal(ConflictResolution("manual"), ResolutionManual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewObservationConflict tests conflict creation.
|
||||||
|
func (s *ConflictSuite) TestNewObservationConflict() {
|
||||||
|
conflict := NewObservationConflict(2, 1, ConflictSuperseded, ResolutionPreferNewer, "Test reason")
|
||||||
|
|
||||||
|
s.Equal(int64(2), conflict.NewerObsID)
|
||||||
|
s.Equal(int64(1), conflict.OlderObsID)
|
||||||
|
s.Equal(ConflictSuperseded, conflict.ConflictType)
|
||||||
|
s.Equal(ResolutionPreferNewer, conflict.Resolution)
|
||||||
|
s.Equal("Test reason", conflict.Reason)
|
||||||
|
s.False(conflict.Resolved)
|
||||||
|
s.NotEmpty(conflict.DetectedAt)
|
||||||
|
s.Greater(conflict.DetectedAtEpoch, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectExplicitCorrection_TableDriven tests explicit correction detection.
|
||||||
|
func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
text string
|
||||||
|
expectMatch bool
|
||||||
|
expectPattern string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "actually that was wrong",
|
||||||
|
text: "Actually, that was wrong - we should use a different approach",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "actually, that was wrong",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correction prefix",
|
||||||
|
text: "Correction: the previous implementation had a bug",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "correction:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignore previous",
|
||||||
|
text: "Please ignore the previous recommendation",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "ignore",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disregard earlier",
|
||||||
|
text: "Disregard the earlier suggestion, it was flawed",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "disregard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "was wrong",
|
||||||
|
text: "The original approach was wrong",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "was wrong",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no longer valid",
|
||||||
|
text: "This method is no longer valid after the refactor",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "no longer valid",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deprecated approach",
|
||||||
|
text: "This is a deprecated approach that should not be used",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "deprecated approach",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "better approach is",
|
||||||
|
text: "A better approach is to use the new API",
|
||||||
|
expectMatch: true,
|
||||||
|
expectPattern: "better approach is",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal text - no correction",
|
||||||
|
text: "This is a normal observation about the code",
|
||||||
|
expectMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty text",
|
||||||
|
text: "",
|
||||||
|
expectMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
found, reason := DetectExplicitCorrection(tt.text)
|
||||||
|
s.Equal(tt.expectMatch, found)
|
||||||
|
if tt.expectMatch {
|
||||||
|
s.Contains(reason, "Explicit correction detected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectOpposingFileChanges_TableDriven tests opposing file change detection.
|
||||||
|
func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newerObs *Observation
|
||||||
|
olderObs *Observation
|
||||||
|
expectConflict bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "add then remove - conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Remove authentication middleware", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Removed the auth middleware from handlers", Valid: true},
|
||||||
|
FilesModified: []string{"middleware.go", "handler.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Add authentication middleware", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Added auth middleware to secure endpoints", Valid: true},
|
||||||
|
FilesModified: []string{"middleware.go", "handler.go"},
|
||||||
|
},
|
||||||
|
expectConflict: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enable then disable - conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Disable caching feature", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Disabled caching due to issues", Valid: true},
|
||||||
|
FilesModified: []string{"cache.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Enable caching feature", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Enabled caching for performance", Valid: true},
|
||||||
|
FilesModified: []string{"cache.go"},
|
||||||
|
},
|
||||||
|
expectConflict: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different files - no conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Remove old code", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Removed deprecated functions", Valid: true},
|
||||||
|
FilesModified: []string{"old.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Add new feature", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Added new functions", Valid: true},
|
||||||
|
FilesModified: []string{"new.go"},
|
||||||
|
},
|
||||||
|
expectConflict: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same files but no opposing keywords - no conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Update handler logic", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Updated the handler implementation", Valid: true},
|
||||||
|
FilesModified: []string{"handler.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Title: sql.NullString{String: "Fix handler bug", Valid: true},
|
||||||
|
Narrative: sql.NullString{String: "Fixed a bug in handler", Valid: true},
|
||||||
|
FilesModified: []string{"handler.go"},
|
||||||
|
},
|
||||||
|
expectConflict: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
found, _ := DetectOpposingFileChanges(tt.newerObs, tt.olderObs)
|
||||||
|
s.Equal(tt.expectConflict, found)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection.
|
||||||
|
func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newerObs *Observation
|
||||||
|
olderObs *Observation
|
||||||
|
expectConflict bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "same concepts and files - conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Concepts: []string{"security", "authentication"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Concepts: []string{"security", "authentication"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
},
|
||||||
|
expectConflict: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping concepts different files - no conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"new_auth.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"old_auth.go"},
|
||||||
|
},
|
||||||
|
expectConflict: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different concepts same files - no conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Concepts: []string{"performance"},
|
||||||
|
FilesModified: []string{"handler.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Concepts: []string{"testing"},
|
||||||
|
FilesModified: []string{"handler.go"},
|
||||||
|
},
|
||||||
|
expectConflict: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no overlapping concepts or files - no conflict",
|
||||||
|
newerObs: &Observation{
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
},
|
||||||
|
olderObs: &Observation{
|
||||||
|
Concepts: []string{"testing"},
|
||||||
|
FilesModified: []string{"test.go"},
|
||||||
|
},
|
||||||
|
expectConflict: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
found, _ := DetectConceptTagMismatch(tt.newerObs, tt.olderObs)
|
||||||
|
s.Equal(tt.expectConflict, found)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectConflict tests comprehensive conflict detection.
|
||||||
|
func (s *ConflictSuite) TestDetectConflict() {
|
||||||
|
// Test explicit correction takes precedence
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Project: "test",
|
||||||
|
Narrative: sql.NullString{String: "Actually, that was wrong. We should use a different approach.", Valid: true},
|
||||||
|
}
|
||||||
|
older := &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Project: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
result := DetectConflict(newer, older)
|
||||||
|
s.True(result.HasConflict)
|
||||||
|
s.Equal(ConflictContradicts, result.Type)
|
||||||
|
s.Equal(ResolutionPreferNewer, result.Resolution)
|
||||||
|
s.Contains(result.Reason, "Explicit correction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectConflictsWithExisting tests conflict detection against multiple observations.
|
||||||
|
func (s *ConflictSuite) TestDetectConflictsWithExisting() {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 3,
|
||||||
|
Project: "test",
|
||||||
|
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := []*Observation{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Project: "test",
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Project: "test",
|
||||||
|
Concepts: []string{"testing"},
|
||||||
|
FilesModified: []string{"test.go"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 3, // Same as newer - should be skipped
|
||||||
|
Project: "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := DetectConflictsWithExisting(newer, existing)
|
||||||
|
|
||||||
|
// The correction language in newer flags every same-project observation it is compared
// against, so obs 1 and obs 2 are both reported; only obs 3 (same ID) is skipped
|
||||||
|
s.GreaterOrEqual(len(results), 1)
|
||||||
|
|
||||||
|
// At least one result should reference obs 1
|
||||||
|
foundObs1 := false
|
||||||
|
for _, r := range results {
|
||||||
|
for _, id := range r.OlderObsIDs {
|
||||||
|
if id == 1 {
|
||||||
|
foundObs1 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.True(foundObs1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectConflictsWithExisting_DifferentProjects tests that different projects don't conflict.
|
||||||
|
func (s *ConflictSuite) TestDetectConflictsWithExisting_DifferentProjects() {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Project: "project-a",
|
||||||
|
Scope: ScopeProject,
|
||||||
|
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := []*Observation{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Project: "project-b",
|
||||||
|
Scope: ScopeProject,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := DetectConflictsWithExisting(newer, existing)
|
||||||
|
s.Empty(results) // Different projects should not conflict
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetectConflictsWithExisting_GlobalScope tests that global observations can conflict across projects.
|
||||||
|
func (s *ConflictSuite) TestDetectConflictsWithExisting_GlobalScope() {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Project: "project-a",
|
||||||
|
Scope: ScopeGlobal,
|
||||||
|
Narrative: sql.NullString{String: "Actually, that was wrong", Valid: true},
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := []*Observation{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Project: "project-b",
|
||||||
|
Scope: ScopeGlobal, // Global scope allows cross-project conflict detection
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := DetectConflictsWithExisting(newer, existing)
|
||||||
|
s.GreaterOrEqual(len(results), 1) // Global scope allows conflict detection
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCorrectionPatterns_Compiled ensures all patterns compile correctly.
|
||||||
|
func TestCorrectionPatterns_Compiled(t *testing.T) {
|
||||||
|
// regexp.MustCompile panics during package initialization if a pattern is invalid,
// so this test only confirms the slice is populated and holds no nil entries
|
||||||
|
assert.NotEmpty(t, CorrectionPatterns)
|
||||||
|
for i, pattern := range CorrectionPatterns {
|
||||||
|
assert.NotNil(t, pattern, "Pattern %d should not be nil", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpposingChangePatterns tests the opposing change pattern map.
|
||||||
|
func TestOpposingChangePatterns(t *testing.T) {
|
||||||
|
assert.NotEmpty(t, OpposingChangePatterns)
|
||||||
|
assert.Equal(t, "remove", OpposingChangePatterns["add"])
|
||||||
|
assert.Equal(t, "removed", OpposingChangePatterns["added"])
|
||||||
|
assert.Equal(t, "delete", OpposingChangePatterns["create"])
|
||||||
|
assert.Equal(t, "disable", OpposingChangePatterns["enable"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestObservationConflict_Fields tests field access.
|
||||||
|
func TestObservationConflict_Fields(t *testing.T) {
|
||||||
|
conflict := &ObservationConflict{
|
||||||
|
ID: 1,
|
||||||
|
NewerObsID: 10,
|
||||||
|
OlderObsID: 5,
|
||||||
|
ConflictType: ConflictSuperseded,
|
||||||
|
Resolution: ResolutionPreferNewer,
|
||||||
|
Reason: "Test reason",
|
||||||
|
DetectedAt: "2024-01-01T00:00:00Z",
|
||||||
|
DetectedAtEpoch: 1704067200000,
|
||||||
|
Resolved: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, int64(1), conflict.ID)
|
||||||
|
assert.Equal(t, int64(10), conflict.NewerObsID)
|
||||||
|
assert.Equal(t, int64(5), conflict.OlderObsID)
|
||||||
|
assert.Equal(t, ConflictSuperseded, conflict.ConflictType)
|
||||||
|
assert.Equal(t, ResolutionPreferNewer, conflict.Resolution)
|
||||||
|
assert.False(t, conflict.Resolved)
|
||||||
|
}
|
||||||
@@ -139,6 +139,16 @@ type Observation struct {
|
|||||||
CreatedAt string `db:"created_at" json:"created_at"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
IsStale bool `db:"-" json:"is_stale,omitempty"`
|
||||||
|
|
||||||
|
// Importance scoring fields
|
||||||
|
ImportanceScore float64 `db:"importance_score" json:"importance_score"`
|
||||||
|
UserFeedback int `db:"user_feedback" json:"user_feedback"`
|
||||||
|
RetrievalCount int `db:"retrieval_count" json:"retrieval_count"`
|
||||||
|
LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
|
||||||
|
ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`
|
||||||
|
|
||||||
|
// Conflict detection fields
|
||||||
|
IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
|
||||||
}

// ParsedObservation represents an observation parsed from SDK response XML.
|
||||||
@@ -205,6 +215,16 @@ type ObservationJSON struct {
|
|||||||
CreatedAt string `json:"created_at"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
IsStale bool `json:"is_stale,omitempty"`
|
||||||
|
|
||||||
|
// Importance scoring fields
|
||||||
|
ImportanceScore float64 `json:"importance_score"`
|
||||||
|
UserFeedback int `json:"user_feedback"`
|
||||||
|
RetrievalCount int `json:"retrieval_count"`
|
||||||
|
LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"`
|
||||||
|
ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"`
|
||||||
|
|
||||||
|
// Conflict detection fields
|
||||||
|
IsSuperseded bool `json:"is_superseded,omitempty"`
|
||||||
}

// MarshalJSON implements json.Marshaler for Observation.
|
||||||
@@ -225,6 +245,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
|
|||||||
CreatedAt: o.CreatedAt,
CreatedAtEpoch: o.CreatedAtEpoch,
IsStale: o.IsStale,
|
||||||
|
// Importance scoring fields
|
||||||
|
ImportanceScore: o.ImportanceScore,
|
||||||
|
UserFeedback: o.UserFeedback,
|
||||||
|
RetrievalCount: o.RetrievalCount,
|
||||||
|
// Conflict detection fields
|
||||||
|
IsSuperseded: o.IsSuperseded,
|
||||||
}
if o.Title.Valid {
j.Title = o.Title.String
|
||||||
@@ -238,6 +264,12 @@ func (o *Observation) MarshalJSON() ([]byte, error) {
|
|||||||
if o.PromptNumber.Valid {
j.PromptNumber = o.PromptNumber.Int64
}
|
||||||
|
if o.LastRetrievedAt.Valid {
|
||||||
|
j.LastRetrievedAt = o.LastRetrievedAt.Int64
|
||||||
|
}
|
||||||
|
if o.ScoreUpdatedAt.Valid {
|
||||||
|
j.ScoreUpdatedAt = o.ScoreUpdatedAt.Int64
|
||||||
|
}
|
||||||
return json.Marshal(j)
}
|
||||||
|
|
||||||
@@ -268,9 +300,28 @@ func NewObservation(sdkSessionID, project string, parsed *ParsedObservation, pro
|
|||||||
DiscoveryTokens: discoveryTokens,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
// Importance scoring: new observations start with score 1.0
|
||||||
|
ImportanceScore: 1.0,
|
||||||
|
UserFeedback: 0,
|
||||||
|
RetrievalCount: 0,
|
||||||
}
}
|
||||||
|
|
||||||
|
// ToMap converts the observation to a map for JSON response building.
|
||||||
|
// This allows adding extra fields like similarity scores.
|
||||||
|
func (o *Observation) ToMap() map[string]interface{} {
|
||||||
|
// Marshal to JSON then unmarshal to map (uses MarshalJSON for proper conversion)
|
||||||
|
data, err := json.Marshal(o)
|
||||||
|
if err != nil {
|
||||||
|
return map[string]interface{}{"id": o.ID, "error": err.Error()}
|
||||||
|
}
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
return map[string]interface{}{"id": o.ID, "error": err.Error()}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
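`ToMap` above round-trips the observation through JSON so that struct tags and the custom `MarshalJSON` apply. A minimal sketch of the same technique on a two-field stand-in type follows. The `record` type and the `similarity` key are assumptions made for this example. One consequence worth knowing is that `encoding/json` decodes every JSON number in a `map[string]any` as `float64`, including integer IDs.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// record is a hypothetical stand-in for Observation with just two fields.
type record struct {
	ID    int64  `json:"id"`
	Title string `json:"title,omitempty"`
}

// toMap marshals v and unmarshals the bytes into a generic map,
// so tags and any custom MarshalJSON are honoured.
func toMap(v any) (map[string]any, error) {
	data, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	var m map[string]any
	if err := json.Unmarshal(data, &m); err != nil {
		return nil, err
	}
	return m, nil
}

func main() {
	m, err := toMap(record{ID: 42})
	if err != nil {
		panic(err)
	}
	m["similarity"] = 0.87 // extra field added by the caller

	fmt.Printf("%T\n", m["id"]) // float64
	_, hasTitle := m["title"]
	fmt.Println(hasTitle) // false: omitempty dropped the empty title
}
```

Callers that read `id` back out of the map need a `float64` type assertion rather than `int64`.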
|
||||||
|
|
||||||
// CheckStaleness checks if an observation is stale based on current file mtimes.
// Returns true if any tracked file has been modified since the observation was created.
func (o *Observation) CheckStaleness(currentMtimes map[string]int64) bool {
|
||||||
|
|||||||
@@ -0,0 +1,398 @@
|
|||||||
|
// Package models contains domain models for claude-mnemonic.
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PatternType represents the category of detected pattern.
|
||||||
|
type PatternType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PatternTypeBug represents recurring bug patterns (e.g., "nil handling oversight").
|
||||||
|
PatternTypeBug PatternType = "bug"
|
||||||
|
// PatternTypeRefactor represents recurring refactoring approaches (e.g., "interface extraction").
|
||||||
|
PatternTypeRefactor PatternType = "refactor"
|
||||||
|
// PatternTypeArchitecture represents consistent architectural patterns.
|
||||||
|
PatternTypeArchitecture PatternType = "architecture"
|
||||||
|
// PatternTypeAntiPattern represents identified anti-patterns to avoid.
|
||||||
|
PatternTypeAntiPattern PatternType = "anti-pattern"
|
||||||
|
// PatternTypeBestPractice represents best practices that work consistently.
|
||||||
|
PatternTypeBestPractice PatternType = "best-practice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PatternStatus represents the lifecycle status of a pattern.
|
||||||
|
type PatternStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PatternStatusActive means the pattern is actively being tracked and can be referenced.
|
||||||
|
PatternStatusActive PatternStatus = "active"
|
||||||
|
// PatternStatusDeprecated means the pattern has been superseded or is no longer relevant.
|
||||||
|
PatternStatusDeprecated PatternStatus = "deprecated"
|
||||||
|
// PatternStatusMerged means this pattern was merged into another pattern.
|
||||||
|
PatternStatusMerged PatternStatus = "merged"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pattern represents a recurring pattern detected across observations.
// This enables Claude to reference historical insights: "I've encountered this pattern 12 times."
type Pattern struct {
	ID             int64           `db:"id" json:"id"`
	Name           string          `db:"name" json:"name"`                       // e.g., "State Management Anti-Pattern"
	Type           PatternType     `db:"type" json:"type"`                       // bug, refactor, architecture, etc.
	Description    sql.NullString  `db:"description" json:"description"`         // Detailed description
	Signature      JSONStringArray `db:"signature" json:"signature"`             // Keyword clusters for detection
	Recommendation sql.NullString  `db:"recommendation" json:"recommendation"`   // What works for this pattern
	Frequency      int             `db:"frequency" json:"frequency"`             // How many times encountered
	Projects       JSONStringArray `db:"projects" json:"projects"`               // Projects where this pattern was seen
	ObservationIDs JSONInt64Array  `db:"observation_ids" json:"observation_ids"` // Source observation IDs
	Status         PatternStatus   `db:"status" json:"status"`                   // active, deprecated, merged
	MergedIntoID   sql.NullInt64   `db:"merged_into_id" json:"merged_into_id,omitempty"`
	Confidence     float64         `db:"confidence" json:"confidence"`     // Detection confidence (0.0-1.0)
	LastSeenAt     string          `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
	LastSeenEpoch  int64           `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
	CreatedAt      string          `db:"created_at" json:"created_at"`
	CreatedAtEpoch int64           `db:"created_at_epoch" json:"created_at_epoch"`
}
|
||||||
|
|
||||||
|
// JSONInt64Array is a custom type for handling JSON int64 arrays in SQLite.
|
||||||
|
type JSONInt64Array []int64
|
||||||
|
|
||||||
|
// Scan implements sql.Scanner for JSONInt64Array.
|
||||||
|
func (j *JSONInt64Array) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*j = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []byte
|
||||||
|
switch v := src.(type) {
|
||||||
|
case string:
|
||||||
|
data = []byte(v)
|
||||||
|
case []byte:
|
||||||
|
data = v
|
||||||
|
default:
|
||||||
|
// Unsupported source types are silently ignored; *j is left unchanged.
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
*j = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(data, j)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements driver.Valuer for JSONInt64Array.
|
||||||
|
func (j JSONInt64Array) Value() (driver.Value, error) {
|
||||||
|
if j == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return json.Marshal(j)
|
||||||
|
}
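
The Scan/Value pair above stores an int64 slice as JSON text in a single column. The following is a minimal standalone sketch of that round trip, not the project's code: the type name `Int64List` and the helper `roundTrip` are hypothetical, and unlike the `Scan` shown above, this sketch reports unsupported source types as an error instead of ignoring them.

```go
package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

// Int64List is a hypothetical stand-in for JSONInt64Array.
type Int64List []int64

// Scan decodes a JSON array from a string or []byte column value.
func (l *Int64List) Scan(src interface{}) error {
	var data []byte
	switch v := src.(type) {
	case nil:
		*l = nil
		return nil
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		// Design choice of this sketch: surface unexpected driver types.
		return fmt.Errorf("unsupported source type %T", src)
	}
	if len(data) == 0 {
		*l = nil
		return nil
	}
	return json.Unmarshal(data, l)
}

// Value encodes the slice as JSON; a nil slice is stored as SQL NULL.
func (l Int64List) Value() (driver.Value, error) {
	if l == nil {
		return nil, nil
	}
	return json.Marshal(l)
}

// roundTrip encodes a list with Value and decodes it again with Scan.
func roundTrip(in Int64List) (Int64List, error) {
	v, err := in.Value()
	if err != nil {
		return nil, err
	}
	var out Int64List
	if err := out.Scan(v); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	out, err := roundTrip(Int64List{1, 2, 3})
	fmt.Println(out, err) // prints "[1 2 3] <nil>"
}
```

Returning an error for unknown types makes a driver mismatch visible at query time; whether that is preferable to the lenient behaviour above depends on how the column is populated.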
|
||||||
|
|
||||||
|
// PatternJSON is a JSON-friendly representation of Pattern.
|
||||||
|
type PatternJSON struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type PatternType `json:"type"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Signature []string `json:"signature,omitempty"`
|
||||||
|
Recommendation string `json:"recommendation,omitempty"`
|
||||||
|
Frequency int `json:"frequency"`
|
||||||
|
Projects []string `json:"projects,omitempty"`
|
||||||
|
ObservationIDs []int64 `json:"observation_ids,omitempty"`
|
||||||
|
Status PatternStatus `json:"status"`
|
||||||
|
MergedIntoID int64 `json:"merged_into_id,omitempty"`
|
||||||
|
Confidence float64 `json:"confidence"`
|
||||||
|
LastSeenAt string `json:"last_seen_at"`
|
||||||
|
LastSeenEpoch int64 `json:"last_seen_at_epoch"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler for Pattern.
|
||||||
|
func (p *Pattern) MarshalJSON() ([]byte, error) {
|
||||||
|
j := PatternJSON{
|
||||||
|
ID: p.ID,
|
||||||
|
Name: p.Name,
|
||||||
|
Type: p.Type,
|
||||||
|
Signature: p.Signature,
|
||||||
|
Frequency: p.Frequency,
|
||||||
|
Projects: p.Projects,
|
||||||
|
ObservationIDs: p.ObservationIDs,
|
||||||
|
Status: p.Status,
|
||||||
|
Confidence: p.Confidence,
|
||||||
|
LastSeenAt: p.LastSeenAt,
|
||||||
|
LastSeenEpoch: p.LastSeenEpoch,
|
||||||
|
CreatedAt: p.CreatedAt,
|
||||||
|
CreatedAtEpoch: p.CreatedAtEpoch,
|
||||||
|
}
|
||||||
|
if p.Description.Valid {
|
||||||
|
j.Description = p.Description.String
|
||||||
|
}
|
||||||
|
if p.Recommendation.Valid {
|
||||||
|
j.Recommendation = p.Recommendation.String
|
||||||
|
}
|
||||||
|
if p.MergedIntoID.Valid {
|
||||||
|
j.MergedIntoID = p.MergedIntoID.Int64
|
||||||
|
}
|
||||||
|
return json.Marshal(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPattern creates a new pattern from detected data.
|
||||||
|
func NewPattern(name string, patternType PatternType, description string, signature []string, project string, observationID int64) *Pattern {
|
||||||
|
now := time.Now()
|
||||||
|
return &Pattern{
|
||||||
|
Name: name,
|
||||||
|
Type: patternType,
|
||||||
|
Description: sql.NullString{String: description, Valid: description != ""},
|
||||||
|
Signature: signature,
|
||||||
|
Frequency: 1,
|
||||||
|
Projects: []string{project},
|
||||||
|
ObservationIDs: []int64{observationID},
|
||||||
|
Status: PatternStatusActive,
|
||||||
|
Confidence: 0.5, // Initial confidence
|
||||||
|
LastSeenAt: now.Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: now.UnixMilli(),
|
||||||
|
CreatedAt: now.Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOccurrence records a new occurrence of this pattern.
|
||||||
|
func (p *Pattern) AddOccurrence(project string, observationID int64) {
|
||||||
|
p.Frequency++
|
||||||
|
|
||||||
|
// Add project if not already tracked
|
||||||
|
found := false
|
||||||
|
for _, proj := range p.Projects {
|
||||||
|
if proj == project {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
p.Projects = append(p.Projects, project)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add observation ID if not already tracked
|
||||||
|
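	// A repeated observation ID is not appended again, but the confidence and
	// last-seen timestamp are still refreshed below.
	seen := false
	for _, id := range p.ObservationIDs {
		if id == observationID {
			seen = true
			break
		}
	}
	if !seen {
		p.ObservationIDs = append(p.ObservationIDs, observationID)
	}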
|
||||||
|
|
||||||
|
// Update confidence based on frequency and cross-project occurrence
|
||||||
|
p.updateConfidence()
|
||||||
|
|
||||||
|
// Update last seen timestamp
|
||||||
|
now := time.Now()
|
||||||
|
p.LastSeenAt = now.Format(time.RFC3339)
|
||||||
|
p.LastSeenEpoch = now.UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateConfidence adjusts confidence based on frequency and cross-project validation.
|
||||||
|
func (p *Pattern) updateConfidence() {
|
||||||
|
// Base confidence from frequency (linear scaling, capped at 10 occurrences)
|
||||||
|
freqConfidence := 0.3 + (0.4 * (float64(min(p.Frequency, 10)) / 10.0))
|
||||||
|
|
||||||
|
// Cross-project bonus: patterns seen across multiple projects are more reliable
|
||||||
|
projectBonus := 0.0
|
||||||
|
if len(p.Projects) >= 2 {
|
||||||
|
projectBonus = 0.1
|
||||||
|
}
|
||||||
|
if len(p.Projects) >= 5 {
|
||||||
|
projectBonus = 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Confidence = min(1.0, freqConfidence+projectBonus)
|
||||||
|
}
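
To make the arithmetic of updateConfidence easier to follow, here is a small sketch that restates it as a pure function and prints a few values. The function name `confidence` and its parameters are hypothetical; the constants are taken from the method above.

```go
package main

import "fmt"

// confidence restates the updateConfidence arithmetic as a pure function
// (hypothetical helper, not part of the package).
func confidence(frequency, projects int) float64 {
	capped := frequency
	if capped > 10 {
		capped = 10 // frequency stops contributing after 10 occurrences
	}
	score := 0.3 + (0.4 * (float64(capped) / 10.0))

	// Cross-project bonus: +0.1 from 2 projects, +0.2 from 5 projects.
	switch {
	case projects >= 5:
		score += 0.2
	case projects >= 2:
		score += 0.1
	}
	if score > 1.0 {
		score = 1.0
	}
	return score
}

func main() {
	cases := []struct{ freq, projects int }{{2, 1}, {3, 2}, {10, 1}, {10, 5}}
	for _, c := range cases {
		fmt.Printf("frequency=%d projects=%d confidence=%.2f\n",
			c.freq, c.projects, confidence(c.freq, c.projects))
	}
}
```

Under this formula a pattern with two occurrences in a single project scores 0.38, which is below the 0.5 that NewPattern assigns initially, and the maximum reachable value is 0.9.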
|
||||||
|
|
||||||
|
// PatternMatch represents a match between an observation and a potential pattern.
|
||||||
|
type PatternMatch struct {
|
||||||
|
PatternID int64 `json:"pattern_id"`
|
||||||
|
Score float64 `json:"score"` // Match score (0.0-1.0)
|
||||||
|
MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
|
||||||
|
IsNew bool `json:"is_new"` // Whether this would create a new pattern
|
||||||
|
SuggestedName string `json:"suggested_name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatternSignatureKeywords are common keywords used in pattern detection.
|
||||||
|
var PatternSignatureKeywords = map[PatternType][]string{
|
||||||
|
PatternTypeBug: {
|
||||||
|
"nil", "null", "undefined", "panic", "crash", "error handling",
|
||||||
|
"race condition", "deadlock", "memory leak", "overflow",
|
||||||
|
"off-by-one", "boundary", "timeout", "concurrency",
|
||||||
|
},
|
||||||
|
PatternTypeRefactor: {
|
||||||
|
"extract", "inline", "rename", "move", "split", "merge",
|
||||||
|
"interface", "abstraction", "decouple", "simplify",
|
||||||
|
"consolidate", "modularize", "encapsulate",
|
||||||
|
},
|
||||||
|
PatternTypeArchitecture: {
|
||||||
|
"layer", "service", "repository", "controller", "handler",
|
||||||
|
"middleware", "dependency injection", "factory", "singleton",
|
||||||
|
"observer", "strategy", "adapter", "facade", "builder",
|
||||||
|
},
|
||||||
|
PatternTypeAntiPattern: {
|
||||||
|
"god class", "spaghetti", "copy paste", "magic number",
|
||||||
|
"hardcoded", "circular dependency", "premature optimization",
|
||||||
|
"over-engineering", "feature envy", "data clump",
|
||||||
|
},
|
||||||
|
PatternTypeBestPractice: {
|
||||||
|
"test", "validation", "logging", "monitoring", "documentation",
|
||||||
|
"error handling", "retry", "timeout", "circuit breaker",
|
||||||
|
"graceful shutdown", "health check", "metrics",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectPatternType analyzes concepts and content to determine pattern type.
|
||||||
|
func DetectPatternType(concepts []string, title, narrative string) PatternType {
|
||||||
|
// Check concepts first
|
||||||
|
for _, concept := range concepts {
|
||||||
|
switch concept {
|
||||||
|
case "anti-pattern":
|
||||||
|
return PatternTypeAntiPattern
|
||||||
|
case "best-practice":
|
||||||
|
return PatternTypeBestPractice
|
||||||
|
case "architecture":
|
||||||
|
return PatternTypeArchitecture
|
||||||
|
case "refactor":
|
||||||
|
return PatternTypeRefactor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for bug-related patterns in content
|
||||||
|
content := title + " " + narrative
|
||||||
|
for _, keyword := range PatternSignatureKeywords[PatternTypeBug] {
|
||||||
|
if containsIgnoreCase(content, keyword) {
|
||||||
|
return PatternTypeBug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to refactor for other patterns
|
||||||
|
return PatternTypeRefactor
|
||||||
|
}
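
DetectPatternType relies on case-insensitive substring matching, so a short keyword such as "nil" also matches inside longer words. The sketch below contrasts that behaviour with a stricter whole-token check. Both helper names are hypothetical, the token-based variant is an alternative technique rather than what the code above does, and it only handles single-word keywords (multi-word entries such as "race condition" would need phrase matching).

```go
package main

import (
	"fmt"
	"strings"
)

// substringMatch mirrors the substring approach using the strings package.
func substringMatch(text, keyword string) bool {
	return strings.Contains(strings.ToLower(text), strings.ToLower(keyword))
}

// wordMatch is a hypothetical stricter check: the keyword must equal a
// whole alphanumeric token of the text.
func wordMatch(text, keyword string) bool {
	tokens := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		isLetter := r >= 'a' && r <= 'z'
		isDigit := r >= '0' && r <= '9'
		return !isLetter && !isDigit
	})
	kw := strings.ToLower(keyword)
	for _, tok := range tokens {
		if tok == kw {
			return true
		}
	}
	return false
}

func main() {
	title := "Vanilla JS cleanup"
	fmt.Println(substringMatch(title, "nil"))        // true: "vanilla" contains "nil"
	fmt.Println(wordMatch(title, "nil"))             // false
	fmt.Println(wordMatch("nil pointer bug", "nil")) // true
}
```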
|
||||||
|
|
||||||
|
// containsIgnoreCase checks if text contains substr (case-insensitive).
|
||||||
|
func containsIgnoreCase(text, substr string) bool {
|
||||||
|
textLower := toLower(text)
|
||||||
|
substrLower := toLower(substr)
|
||||||
|
return contains(textLower, substrLower)
|
||||||
|
}
|
||||||
|
|
||||||
|
// toLower and contains are simple, ASCII-only implementations that avoid a strings package dependency in this file.
|
||||||
|
func toLower(s string) string {
|
||||||
|
b := make([]byte, len(s))
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
if 'A' <= c && c <= 'Z' {
|
||||||
|
c += 'a' - 'A'
|
||||||
|
}
|
||||||
|
b[i] = c
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
if len(substr) > len(s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractSignature creates a signature from observation content.
|
||||||
|
func ExtractSignature(concepts []string, title, narrative string) []string {
|
||||||
|
var signature []string
|
||||||
|
|
||||||
|
// Add all concepts
|
||||||
|
signature = append(signature, concepts...)
|
||||||
|
|
||||||
|
// Extract key terms from title (simple word extraction)
|
||||||
|
for _, word := range splitWords(title) {
|
||||||
|
if len(word) > 3 && isSignificantWord(word) {
|
||||||
|
signature = append(signature, toLower(word))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueStrings(signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitWords is a simple word splitter.
|
||||||
|
func splitWords(s string) []string {
|
||||||
|
var words []string
|
||||||
|
word := ""
|
||||||
|
for _, r := range s {
|
||||||
|
if r == ' ' || r == '-' || r == '_' || r == '.' || r == ',' {
|
||||||
|
if word != "" {
|
||||||
|
words = append(words, word)
|
||||||
|
word = ""
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
word += string(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if word != "" {
|
||||||
|
words = append(words, word)
|
||||||
|
}
|
||||||
|
return words
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSignificantWord filters out common stop words.
|
||||||
|
func isSignificantWord(word string) bool {
|
||||||
|
stopWords := map[string]bool{
|
||||||
|
"the": true, "and": true, "for": true, "with": true, "that": true,
|
||||||
|
"this": true, "from": true, "have": true, "not": true, "are": true,
|
||||||
|
"was": true, "but": true, "all": true, "can": true, "had": true,
|
||||||
|
"were": true, "been": true, "will": true, "when": true, "what": true,
|
||||||
|
}
|
||||||
|
return !stopWords[toLower(word)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// uniqueStrings returns a slice with duplicate strings removed.
|
||||||
|
func uniqueStrings(s []string) []string {
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
var result []string
|
||||||
|
for _, v := range s {
|
||||||
|
if !seen[v] {
|
||||||
|
seen[v] = true
|
||||||
|
result = append(result, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateMatchScore computes similarity between two signatures.
|
||||||
|
func CalculateMatchScore(sig1, sig2 []string) float64 {
|
||||||
|
if len(sig1) == 0 || len(sig2) == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
set1 := make(map[string]bool)
|
||||||
|
for _, s := range sig1 {
|
||||||
|
set1[toLower(s)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := 0
|
||||||
|
for _, s := range sig2 {
|
||||||
|
if set1[toLower(s)] {
|
||||||
|
matches++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Jaccard similarity: |intersection| / |union| (exact only when neither signature contains duplicates)
|
||||||
|
unionSize := len(sig1) + len(sig2) - matches
|
||||||
|
if unionSize == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
return float64(matches) / float64(unionSize)
|
||||||
|
}
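
CalculateMatchScore derives the union size from the slice lengths, so the result is a true Jaccard index only when each signature is free of duplicates. For example, sig1 = ["a", "a"] against sig2 = ["a"] yields 1 / (2 + 1 - 1) = 0.5 there. The sketch below is a set-based variant that is insensitive to duplicates; the function name `jaccard` is hypothetical and it uses the standard strings package.

```go
package main

import (
	"fmt"
	"strings"
)

// toSet lowercases and deduplicates a signature.
func toSet(sig []string) map[string]struct{} {
	set := make(map[string]struct{}, len(sig))
	for _, s := range sig {
		set[strings.ToLower(s)] = struct{}{}
	}
	return set
}

// jaccard returns |A ∩ B| / |A ∪ B| over the deduplicated signatures.
func jaccard(a, b []string) float64 {
	setA, setB := toSet(a), toSet(b)
	if len(setA) == 0 || len(setB) == 0 {
		return 0.0
	}
	intersection := 0
	for s := range setB {
		if _, ok := setA[s]; ok {
			intersection++
		}
	}
	union := len(setA) + len(setB) - intersection
	return float64(intersection) / float64(union)
}

func main() {
	fmt.Println(jaccard([]string{"a", "b", "c"}, []string{"a", "b", "d"})) // 0.5
	fmt.Println(jaccard([]string{"a", "b"}, []string{"a", "b", "c", "d"})) // 0.5
	fmt.Println(jaccard([]string{"a", "a"}, []string{"a"}))                // 1
}
```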
|
||||||
@@ -0,0 +1,357 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPattern(t *testing.T) {
|
||||||
|
pattern := NewPattern(
|
||||||
|
"Test Pattern",
|
||||||
|
PatternTypeBug,
|
||||||
|
"A test pattern description",
|
||||||
|
[]string{"nil", "error", "handling"},
|
||||||
|
"test-project",
|
||||||
|
123,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pattern.Name != "Test Pattern" {
|
||||||
|
t.Errorf("Expected name 'Test Pattern', got '%s'", pattern.Name)
|
||||||
|
}
|
||||||
|
if pattern.Type != PatternTypeBug {
|
||||||
|
t.Errorf("Expected type PatternTypeBug, got '%s'", pattern.Type)
|
||||||
|
}
|
||||||
|
if !pattern.Description.Valid || pattern.Description.String != "A test pattern description" {
|
||||||
|
t.Errorf("Description not set correctly")
|
||||||
|
}
|
||||||
|
if len(pattern.Signature) != 3 {
|
||||||
|
t.Errorf("Expected 3 signature elements, got %d", len(pattern.Signature))
|
||||||
|
}
|
||||||
|
if pattern.Frequency != 1 {
|
||||||
|
t.Errorf("Expected frequency 1, got %d", pattern.Frequency)
|
||||||
|
}
|
||||||
|
if len(pattern.Projects) != 1 || pattern.Projects[0] != "test-project" {
|
||||||
|
t.Errorf("Projects not set correctly")
|
||||||
|
}
|
||||||
|
if len(pattern.ObservationIDs) != 1 || pattern.ObservationIDs[0] != 123 {
|
||||||
|
t.Errorf("ObservationIDs not set correctly")
|
||||||
|
}
|
||||||
|
if pattern.Status != PatternStatusActive {
|
||||||
|
t.Errorf("Expected status Active, got '%s'", pattern.Status)
|
||||||
|
}
|
||||||
|
if pattern.Confidence != 0.5 {
|
||||||
|
t.Errorf("Expected initial confidence 0.5, got %f", pattern.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPattern_AddOccurrence(t *testing.T) {
|
||||||
|
pattern := NewPattern("Test", PatternTypeBug, "desc", []string{"test"}, "project1", 1)
|
||||||
|
|
||||||
|
// Add same project occurrence
|
||||||
|
pattern.AddOccurrence("project1", 2)
|
||||||
|
if pattern.Frequency != 2 {
|
||||||
|
t.Errorf("Expected frequency 2, got %d", pattern.Frequency)
|
||||||
|
}
|
||||||
|
if len(pattern.Projects) != 1 {
|
||||||
|
t.Errorf("Expected 1 project (no duplicates), got %d", len(pattern.Projects))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add different project occurrence
|
||||||
|
pattern.AddOccurrence("project2", 3)
|
||||||
|
if pattern.Frequency != 3 {
|
||||||
|
t.Errorf("Expected frequency 3, got %d", pattern.Frequency)
|
||||||
|
}
|
||||||
|
if len(pattern.Projects) != 2 {
|
||||||
|
t.Errorf("Expected 2 projects, got %d", len(pattern.Projects))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add duplicate observation ID - should not duplicate
|
||||||
|
pattern.AddOccurrence("project2", 3)
|
||||||
|
if len(pattern.ObservationIDs) != 3 {
|
||||||
|
t.Errorf("Expected 3 observation IDs (no duplicate), got %d", len(pattern.ObservationIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check confidence increased
|
||||||
|
if pattern.Confidence <= 0.5 {
|
||||||
|
t.Errorf("Expected confidence to increase above 0.5, got %f", pattern.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPattern_ConfidenceCalculation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
frequency int
|
||||||
|
projectCount int
|
||||||
|
minConfidence float64
|
||||||
|
maxConfidence float64
|
||||||
|
}{
|
||||||
|
{"low_frequency", 2, 1, 0.3, 0.5},
|
||||||
|
{"high_frequency", 10, 1, 0.6, 0.8},
|
||||||
|
{"multi_project", 3, 3, 0.4, 0.7},
|
||||||
|
{"high_freq_multi_proj", 10, 5, 0.7, 1.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pattern := NewPattern("Test", PatternTypeBug, "", []string{}, "proj1", 1)
|
||||||
|
|
||||||
|
// Simulate occurrences
|
||||||
|
for i := 1; i < tt.frequency; i++ {
|
||||||
|
// Cycle through proj1..projN so that all projectCount projects are exercised.
projIdx := i%tt.projectCount + 1
|
||||||
|
pattern.AddOccurrence("proj"+string(rune('0'+projIdx)), int64(i+1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if pattern.Confidence < tt.minConfidence || pattern.Confidence > tt.maxConfidence {
|
||||||
|
t.Errorf("Expected confidence between %f and %f, got %f",
|
||||||
|
tt.minConfidence, tt.maxConfidence, pattern.Confidence)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternType_Detection(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
concepts []string
|
||||||
|
title string
|
||||||
|
narrative string
|
||||||
|
expected PatternType
|
||||||
|
}{
|
||||||
|
{[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern},
|
||||||
|
{[]string{"best-practice"}, "", "", PatternTypeBestPractice},
|
||||||
|
{[]string{"architecture"}, "", "", PatternTypeArchitecture},
|
||||||
|
{[]string{"refactor"}, "", "", PatternTypeRefactor},
|
||||||
|
{[]string{}, "nil pointer bug", "", PatternTypeBug},
|
||||||
|
{[]string{}, "Deadlock in concurrent code", "", PatternTypeBug},
|
||||||
|
{[]string{}, "Extract interface", "", PatternTypeRefactor},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.title+"_"+tt.expected.String(), func(t *testing.T) {
|
||||||
|
result := DetectPatternType(tt.concepts, tt.title, tt.narrative)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pt PatternType) String() string {
|
||||||
|
return string(pt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSignature(t *testing.T) {
|
||||||
|
concepts := []string{"error-handling", "security"}
|
||||||
|
title := "Nil Pointer Validation Pattern"
|
||||||
|
narrative := "Always validate before dereferencing"
|
||||||
|
|
||||||
|
signature := ExtractSignature(concepts, title, narrative)
|
||||||
|
|
||||||
|
// Should contain concepts
|
||||||
|
found := false
|
||||||
|
for _, s := range signature {
|
||||||
|
if s == "error-handling" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected signature to contain concepts, got %v", signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain significant words from title
|
||||||
|
found = false
|
||||||
|
for _, s := range signature {
|
||||||
|
if s == "validation" || s == "pattern" || s == "pointer" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected signature to contain title keywords, got %v", signature)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateMatchScore(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sig1 []string
|
||||||
|
sig2 []string
|
||||||
|
minScore float64
|
||||||
|
maxScore float64
|
||||||
|
}{
|
||||||
|
{"identical", []string{"a", "b", "c"}, []string{"a", "b", "c"}, 1.0, 1.0},
|
||||||
|
{"partial", []string{"a", "b", "c"}, []string{"a", "b", "d"}, 0.4, 0.6},
|
||||||
|
{"no_match", []string{"a", "b", "c"}, []string{"x", "y", "z"}, 0.0, 0.0},
|
||||||
|
{"empty", []string{}, []string{"a", "b"}, 0.0, 0.0},
|
||||||
|
{"subset", []string{"a", "b"}, []string{"a", "b", "c", "d"}, 0.4, 0.6},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
score := CalculateMatchScore(tt.sig1, tt.sig2)
|
||||||
|
if score < tt.minScore || score > tt.maxScore {
|
||||||
|
t.Errorf("Expected score between %f and %f, got %f",
|
||||||
|
tt.minScore, tt.maxScore, score)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPattern_MarshalJSON(t *testing.T) {
|
||||||
|
pattern := &Pattern{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test Pattern",
|
||||||
|
Type: PatternTypeBug,
|
||||||
|
Description: sql.NullString{String: "A description", Valid: true},
|
||||||
|
Signature: []string{"a", "b"},
|
||||||
|
Recommendation: sql.NullString{String: "Do this", Valid: true},
|
||||||
|
Frequency: 5,
|
||||||
|
Projects: []string{"proj1", "proj2"},
|
||||||
|
ObservationIDs: []int64{1, 2, 3},
|
||||||
|
Status: PatternStatusActive,
|
||||||
|
MergedIntoID: sql.NullInt64{Int64: 0, Valid: false},
|
||||||
|
Confidence: 0.8,
|
||||||
|
LastSeenAt: time.Now().Format(time.RFC3339),
|
||||||
|
LastSeenEpoch: time.Now().UnixMilli(),
|
||||||
|
CreatedAt: time.Now().Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal pattern: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result PatternJSON
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal pattern: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Name != pattern.Name {
|
||||||
|
t.Errorf("Expected name %s, got %s", pattern.Name, result.Name)
|
||||||
|
}
|
||||||
|
if result.Description != pattern.Description.String {
|
||||||
|
t.Errorf("Expected description %s, got %s", pattern.Description.String, result.Description)
|
||||||
|
}
|
||||||
|
if result.Frequency != pattern.Frequency {
|
||||||
|
t.Errorf("Expected frequency %d, got %d", pattern.Frequency, result.Frequency)
|
||||||
|
}
|
||||||
|
if result.MergedIntoID != 0 {
|
||||||
|
t.Errorf("Expected merged_into_id 0 for invalid NullInt64, got %d", result.MergedIntoID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONInt64Array_Scan(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected JSONInt64Array
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"string_array", "[1, 2, 3]", JSONInt64Array{1, 2, 3}, false},
|
||||||
|
{"bytes_array", []byte("[4, 5, 6]"), JSONInt64Array{4, 5, 6}, false},
|
||||||
|
{"nil", nil, nil, false},
|
||||||
|
{"empty_string", "", nil, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var arr JSONInt64Array
|
||||||
|
err := arr.Scan(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(arr) != len(tt.expected) {
|
||||||
|
t.Errorf("Expected length %d, got %d", len(tt.expected), len(arr))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONInt64Array_Value(t *testing.T) {
|
||||||
|
arr := JSONInt64Array{1, 2, 3}
|
||||||
|
val, err := arr.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, ok := val.([]byte)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected []byte, got %T", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []int64
|
||||||
|
if err := json.Unmarshal(bytes, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) != 3 || result[0] != 1 || result[1] != 2 || result[2] != 3 {
|
||||||
|
t.Errorf("Expected [1, 2, 3], got %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternSignatureKeywords(t *testing.T) {
|
||||||
|
// Verify keywords exist for each type
|
||||||
|
types := []PatternType{
|
||||||
|
PatternTypeBug,
|
||||||
|
PatternTypeRefactor,
|
||||||
|
PatternTypeArchitecture,
|
||||||
|
PatternTypeAntiPattern,
|
||||||
|
PatternTypeBestPractice,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pt := range types {
|
||||||
|
keywords := PatternSignatureKeywords[pt]
|
||||||
|
if len(keywords) == 0 {
|
||||||
|
t.Errorf("No keywords defined for pattern type %s", pt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniqueStrings(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input []string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{[]string{"a", "b", "c"}, 3},
|
||||||
|
{[]string{"a", "a", "b"}, 2},
|
||||||
|
{[]string{"a", "a", "a"}, 1},
|
||||||
|
{[]string{}, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := uniqueStrings(tt.input)
|
||||||
|
if len(result) != tt.expected {
|
||||||
|
t.Errorf("uniqueStrings(%v) = %v (len=%d), expected len=%d",
|
||||||
|
tt.input, result, len(result), tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContainsIgnoreCase(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
text string
|
||||||
|
substr string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"Hello World", "hello", true},
|
||||||
|
{"Hello World", "WORLD", true},
|
||||||
|
{"Hello World", "xyz", false},
|
||||||
|
{"", "a", false},
|
||||||
|
{"a", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := containsIgnoreCase(tt.text, tt.substr)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("containsIgnoreCase(%q, %q) = %v, expected %v",
|
||||||
|
tt.text, tt.substr, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,489 @@
|
|||||||
|
// Package models contains domain models for claude-mnemonic.
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RelationType represents the type of relationship between observations.
|
||||||
|
type RelationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RelationCauses means source observation caused target observation.
|
||||||
|
// Example: "This architectural decision caused this bug"
|
||||||
|
RelationCauses RelationType = "causes"
|
||||||
|
// RelationFixes means source observation fixes target observation.
|
||||||
|
// Example: "This bugfix addresses that discovered issue"
|
||||||
|
RelationFixes RelationType = "fixes"
|
||||||
|
// RelationSupersedes means source observation supersedes target observation.
|
||||||
|
// Example: "This new approach replaces the old workaround"
|
||||||
|
RelationSupersedes RelationType = "supersedes"
|
||||||
|
// RelationDependsOn means source observation depends on target observation.
|
||||||
|
// Example: "This feature relies on that architectural decision"
|
||||||
|
RelationDependsOn RelationType = "depends_on"
|
||||||
|
// RelationRelatesTo means the observations are related but have no causal relationship.
|
||||||
|
// Example: "Both deal with authentication"
|
||||||
|
RelationRelatesTo RelationType = "relates_to"
|
||||||
|
// RelationEvolvesFrom means source observation evolved from target observation.
|
||||||
|
// Example: "This refined pattern evolved from that initial discovery"
|
||||||
|
RelationEvolvesFrom RelationType = "evolves_from"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AllRelationTypes is the list of all valid relation types.
|
||||||
|
var AllRelationTypes = []RelationType{
|
||||||
|
RelationCauses,
|
||||||
|
RelationFixes,
|
||||||
|
RelationSupersedes,
|
||||||
|
RelationDependsOn,
|
||||||
|
RelationRelatesTo,
|
||||||
|
RelationEvolvesFrom,
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelationDetectionSource indicates how a relationship was detected.
|
||||||
|
type RelationDetectionSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DetectionSourceFileOverlap means relationship was detected via shared file references.
|
||||||
|
DetectionSourceFileOverlap RelationDetectionSource = "file_overlap"
|
||||||
|
// DetectionSourceEmbeddingSimilarity means relationship was detected via vector similarity.
|
||||||
|
DetectionSourceEmbeddingSimilarity RelationDetectionSource = "embedding_similarity"
|
||||||
|
// DetectionSourceTemporalProximity means relationship was detected via close timestamps.
|
||||||
|
DetectionSourceTemporalProximity RelationDetectionSource = "temporal_proximity"
|
||||||
|
// DetectionSourceNarrativeMention means relationship was detected via explicit mentions.
|
||||||
|
DetectionSourceNarrativeMention RelationDetectionSource = "narrative_mention"
|
||||||
|
// DetectionSourceConceptOverlap means relationship was detected via shared concepts.
|
||||||
|
DetectionSourceConceptOverlap RelationDetectionSource = "concept_overlap"
|
||||||
|
// DetectionSourceTypeProgression means relationship was detected via type progression pattern.
|
||||||
|
DetectionSourceTypeProgression RelationDetectionSource = "type_progression"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ObservationRelation represents a directed relationship between two observations.
|
||||||
|
type ObservationRelation struct {
|
||||||
|
ID int64 `db:"id" json:"id"`
|
||||||
|
SourceID int64 `db:"source_id" json:"source_id"`
|
||||||
|
TargetID int64 `db:"target_id" json:"target_id"`
|
||||||
|
RelationType RelationType `db:"relation_type" json:"relation_type"`
|
||||||
|
Confidence float64 `db:"confidence" json:"confidence"`
|
||||||
|
DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"`
|
||||||
|
Reason string `db:"reason" json:"reason,omitempty"`
|
||||||
|
CreatedAt string `db:"created_at" json:"created_at"`
|
||||||
|
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewObservationRelation creates a new observation relation.
|
||||||
|
func NewObservationRelation(sourceID, targetID int64, relType RelationType, confidence float64, source RelationDetectionSource, reason string) *ObservationRelation {
|
||||||
|
now := time.Now()
|
||||||
|
return &ObservationRelation{
|
||||||
|
SourceID: sourceID,
|
||||||
|
TargetID: targetID,
|
||||||
|
RelationType: relType,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: source,
|
||||||
|
Reason: reason,
|
||||||
|
CreatedAt: now.Format(time.RFC3339),
|
||||||
|
CreatedAtEpoch: now.UnixMilli(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelationDetectionResult contains the result of relation detection.
|
||||||
|
type RelationDetectionResult struct {
|
||||||
|
SourceID int64
|
||||||
|
TargetID int64
|
||||||
|
RelationType RelationType
|
||||||
|
Confidence float64
|
||||||
|
DetectionSource RelationDetectionSource
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectFileOverlapRelation checks if observations share file references and determines relationship type.
|
||||||
|
func DetectFileOverlapRelation(newer, older *Observation) *RelationDetectionResult {
|
||||||
|
// Check for overlapping modified files
|
||||||
|
newerModified := make(map[string]bool)
|
||||||
|
for _, f := range newer.FilesModified {
|
||||||
|
newerModified[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
olderModified := make(map[string]bool)
|
||||||
|
for _, f := range older.FilesModified {
|
||||||
|
olderModified[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Files modified by both
|
||||||
|
var sharedModified []string
|
||||||
|
for f := range newerModified {
|
||||||
|
if olderModified[f] {
|
||||||
|
sharedModified = append(sharedModified, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Files that newer reads which older modified
|
||||||
|
var newerReadsOlderModified []string
|
||||||
|
for _, f := range newer.FilesRead {
|
||||||
|
if olderModified[f] {
|
||||||
|
newerReadsOlderModified = append(newerReadsOlderModified, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate overlap score
|
||||||
|
overlap := len(sharedModified) + len(newerReadsOlderModified)
|
||||||
|
if overlap == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine relationship type based on observation types and file overlap
|
||||||
|
relType := RelationRelatesTo
|
||||||
|
confidence := 0.5 + float64(overlap)*0.1 // Base 0.5, +0.1 per overlapping file
|
||||||
|
|
||||||
|
// Type-based relationship inference
|
||||||
|
switch {
|
||||||
|
case newer.Type == ObsTypeBugfix && (older.Type == ObsTypeDecision || older.Type == ObsTypeFeature):
|
||||||
|
relType = RelationFixes
|
||||||
|
confidence += 0.2
|
||||||
|
case newer.Type == ObsTypeRefactor && older.Type == ObsTypeDiscovery:
|
||||||
|
relType = RelationEvolvesFrom
|
||||||
|
confidence += 0.15
|
||||||
|
case newer.Type == older.Type && len(sharedModified) > 0:
|
||||||
|
relType = RelationSupersedes
|
||||||
|
confidence += 0.1
|
||||||
|
case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
|
||||||
|
relType = RelationDependsOn
|
||||||
|
confidence += 0.15
|
||||||
|
}
|
||||||
|
|
||||||
|
if confidence > 1.0 {
|
||||||
|
confidence = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
reason := buildFileOverlapReason(sharedModified, newerReadsOlderModified)
|
||||||
|
|
||||||
|
return &RelationDetectionResult{
|
||||||
|
SourceID: newer.ID,
|
||||||
|
TargetID: older.ID,
|
||||||
|
RelationType: relType,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: DetectionSourceFileOverlap,
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
}
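The confidence arithmetic in `DetectFileOverlapRelation` can be checked by hand with a small sketch. The helper names `fileOverlapConfidence` and `approxEqual` are hypothetical and not part of the package; only the constants (base 0.5, +0.1 per overlapping file, type bonus, cap at 1.0) come from the function above.

```go
package main

import (
	"fmt"
	"math"
)

// fileOverlapConfidence is an illustrative stand-in for the scoring above:
// base 0.5, plus 0.1 per overlapping file, plus a type-based bonus,
// capped at 1.0. It returns 0 when nothing overlaps, where the real
// detector returns nil.
func fileOverlapConfidence(overlap int, typeBonus float64) float64 {
	if overlap <= 0 {
		return 0
	}
	return math.Min(0.5+float64(overlap)*0.1+typeBonus, 1.0)
}

// approxEqual compares floats with a small tolerance, because sums such as
// 0.5 + 0.1 + 0.2 are not exactly representable in binary floating point.
func approxEqual(a, b float64) bool {
	return math.Abs(a-b) < 1e-9
}

func main() {
	// One shared file, bugfix touching a feature's file (+0.2 bonus).
	fmt.Printf("%.2f\n", fileOverlapConfidence(1, 0.2))
	// Six overlapping files with the same bonus would exceed 1.0, so the cap applies.
	fmt.Printf("%.2f\n", fileOverlapConfidence(6, 0.2))
}
```

Because the score grows by 0.1 per file, five overlapping files alone already reach the cap, so large overlaps are indistinguishable from each other.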
|
||||||
|
|
||||||
|
// buildFileOverlapReason creates a human-readable reason for file overlap relation.
|
||||||
|
func buildFileOverlapReason(shared, readsModified []string) string {
|
||||||
|
parts := []string{}
|
||||||
|
if len(shared) > 0 {
|
||||||
|
parts = append(parts, "both modified: "+strings.Join(truncateList(shared, 3), ", "))
|
||||||
|
}
|
||||||
|
if len(readsModified) > 0 {
|
||||||
|
parts = append(parts, "reads files modified by older: "+strings.Join(truncateList(readsModified, 3), ", "))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "; ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectConceptOverlapRelation checks if observations share concepts.
|
||||||
|
func DetectConceptOverlapRelation(newer, older *Observation) *RelationDetectionResult {
|
||||||
|
newerConcepts := make(map[string]bool)
|
||||||
|
for _, c := range newer.Concepts {
|
||||||
|
newerConcepts[c] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var shared []string
|
||||||
|
for _, c := range older.Concepts {
|
||||||
|
if newerConcepts[c] {
|
||||||
|
shared = append(shared, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(shared) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate confidence based on overlap ratio
|
||||||
|
totalUniqueConcepts := len(newerConcepts)
|
||||||
|
for _, c := range older.Concepts {
|
||||||
|
if !newerConcepts[c] {
|
||||||
|
totalUniqueConcepts++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
overlapRatio := float64(len(shared)) / float64(totalUniqueConcepts)
|
||||||
|
confidence := 0.3 + overlapRatio*0.5 // Base 0.3, scale with overlap
|
||||||
|
|
||||||
|
// Boost for important concepts
|
||||||
|
for _, c := range shared {
|
||||||
|
if isHighValueConcept(c) {
|
||||||
|
confidence += 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if confidence > 1.0 {
|
||||||
|
confidence = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RelationDetectionResult{
|
||||||
|
SourceID: newer.ID,
|
||||||
|
TargetID: older.ID,
|
||||||
|
RelationType: RelationRelatesTo,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: DetectionSourceConceptOverlap,
|
||||||
|
Reason: "shared concepts: " + strings.Join(truncateList(shared, 5), ", "),
|
||||||
|
}
|
||||||
|
}
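The overlap ratio used above is effectively a Jaccard index: shared concepts divided by distinct concepts across both observations. A minimal sketch follows, assuming both lists are treated as sets. The helper `jaccard` is hypothetical, and unlike the loop above it deduplicates repeated concepts.

```go
package main

import "fmt"

// jaccard returns |A ∩ B| / |A ∪ B| for two string lists, treating each
// list as a set. It returns 0 when both lists are empty.
func jaccard(a, b []string) float64 {
	setA := make(map[string]bool, len(a))
	for _, s := range a {
		setA[s] = true
	}
	setB := make(map[string]bool, len(b))
	for _, s := range b {
		setB[s] = true
	}
	shared := 0
	for s := range setB {
		if setA[s] {
			shared++
		}
	}
	union := len(setA) + len(setB) - shared
	if union == 0 {
		return 0
	}
	return float64(shared) / float64(union)
}

func main() {
	newer := []string{"security", "auth"}
	older := []string{"security", "validation"}
	ratio := jaccard(newer, older) // 1 shared out of 3 distinct concepts
	// Base 0.3, scaled by the ratio, plus 0.1 because "security" is high-value.
	confidence := 0.3 + ratio*0.5 + 0.1
	fmt.Printf("ratio=%.3f confidence=%.3f\n", ratio, confidence)
}
```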
|
||||||
|
|
||||||
|
// isHighValueConcept returns true for concepts that strongly indicate relationships.
|
||||||
|
func isHighValueConcept(concept string) bool {
|
||||||
|
highValue := map[string]bool{
|
||||||
|
"security": true,
|
||||||
|
"architecture": true,
|
||||||
|
"gotcha": true,
|
||||||
|
"anti-pattern": true,
|
||||||
|
"best-practice": true,
|
||||||
|
"error-handling": true,
|
||||||
|
}
|
||||||
|
return highValue[concept]
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectTypeProgressionRelation checks for natural type progressions.
|
||||||
|
// Example: discovery -> decision -> feature -> bugfix
|
||||||
|
func DetectTypeProgressionRelation(newer, older *Observation) *RelationDetectionResult {
|
||||||
|
// Define natural type progressions
|
||||||
|
progressions := map[ObservationType][]ObservationType{
|
||||||
|
ObsTypeBugfix: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeDecision},
|
||||||
|
ObsTypeFeature: {ObsTypeDiscovery, ObsTypeDecision},
|
||||||
|
ObsTypeRefactor: {ObsTypeDiscovery, ObsTypeFeature, ObsTypeBugfix},
|
||||||
|
ObsTypeDecision: {ObsTypeDiscovery},
|
||||||
|
ObsTypeChange: {ObsTypeDiscovery, ObsTypeDecision},
|
||||||
|
}
|
||||||
|
|
||||||
|
validPredecessors, ok := progressions[newer.Type]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isValidProgression := false
|
||||||
|
for _, pred := range validPredecessors {
|
||||||
|
if older.Type == pred {
|
||||||
|
isValidProgression = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidProgression {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine relationship type based on progression
|
||||||
|
var relType RelationType
|
||||||
|
confidence := 0.4
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeDiscovery:
|
||||||
|
relType = RelationFixes
|
||||||
|
confidence = 0.6
|
||||||
|
case newer.Type == ObsTypeBugfix && older.Type == ObsTypeFeature:
|
||||||
|
relType = RelationFixes
|
||||||
|
confidence = 0.5
|
||||||
|
case newer.Type == ObsTypeFeature && older.Type == ObsTypeDecision:
|
||||||
|
relType = RelationDependsOn
|
||||||
|
confidence = 0.6
|
||||||
|
case newer.Type == ObsTypeRefactor:
|
||||||
|
relType = RelationEvolvesFrom
|
||||||
|
confidence = 0.5
|
||||||
|
default:
|
||||||
|
relType = RelationRelatesTo
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RelationDetectionResult{
|
||||||
|
SourceID: newer.ID,
|
||||||
|
TargetID: older.ID,
|
||||||
|
RelationType: relType,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: DetectionSourceTypeProgression,
|
||||||
|
Reason: string(older.Type) + " -> " + string(newer.Type) + " progression",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectTemporalProximityRelation checks if observations are temporally close (same session).
|
||||||
|
func DetectTemporalProximityRelation(newer, older *Observation) *RelationDetectionResult {
|
||||||
|
// Only relate observations from the same session
|
||||||
|
if newer.SDKSessionID != older.SDKSessionID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check temporal proximity (within 5 minutes)
|
||||||
|
timeDiffMs := newer.CreatedAtEpoch - older.CreatedAtEpoch
|
||||||
|
if timeDiffMs < 0 {
|
||||||
|
timeDiffMs = -timeDiffMs
|
||||||
|
}
|
||||||
|
|
||||||
|
fiveMinutesMs := int64(5 * 60 * 1000)
|
||||||
|
if timeDiffMs > fiveMinutesMs {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate confidence based on temporal proximity
|
||||||
|
// Closer = higher confidence
|
||||||
|
proximityRatio := 1.0 - (float64(timeDiffMs) / float64(fiveMinutesMs))
|
||||||
|
confidence := 0.3 + proximityRatio*0.4
|
||||||
|
|
||||||
|
return &RelationDetectionResult{
|
||||||
|
SourceID: newer.ID,
|
||||||
|
TargetID: older.ID,
|
||||||
|
RelationType: RelationRelatesTo,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: DetectionSourceTemporalProximity,
|
||||||
|
Reason: "same session, close timestamps",
|
||||||
|
}
|
||||||
|
}
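The temporal score decays linearly across the five-minute window: 0.7 for identical timestamps, down to 0.3 at exactly five minutes, and no relation beyond that. A small sketch of the arithmetic only. `temporalConfidence` and `windowMs` are hypothetical names, and the same-session check is left out.

```go
package main

import "fmt"

// windowMs is the five-minute window from the detector above, in milliseconds.
const windowMs = int64(5 * 60 * 1000)

// temporalConfidence returns the confidence for two epoch-millisecond
// timestamps and false when they are more than five minutes apart.
func temporalConfidence(newerMs, olderMs int64) (float64, bool) {
	diff := newerMs - olderMs
	if diff < 0 {
		diff = -diff
	}
	if diff > windowMs {
		return 0, false
	}
	// 1.0 when simultaneous, 0.0 at the edge of the window.
	proximity := 1.0 - float64(diff)/float64(windowMs)
	return 0.3 + proximity*0.4, true
}

func main() {
	base := int64(1700000000000)
	for _, gap := range []int64{0, 60_000, 300_000, 600_000} {
		c, ok := temporalConfidence(base+gap, base)
		fmt.Printf("gap=%dms related=%v confidence=%.2f\n", gap, ok, c)
	}
}
```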
|
||||||
|
|
||||||
|
// NarrativeMentionPatterns are patterns that indicate explicit relationships in narratives.
|
||||||
|
var NarrativeMentionPatterns = []struct {
|
||||||
|
Pattern string
|
||||||
|
RelationType RelationType
|
||||||
|
ConfBoost float64
|
||||||
|
}{
|
||||||
|
{" caused ", RelationCauses, 0.3},
|
||||||
|
{" causes ", RelationCauses, 0.3},
|
||||||
|
{" because of ", RelationCauses, 0.25},
|
||||||
|
{" due to ", RelationCauses, 0.2},
|
||||||
|
{" fixes ", RelationFixes, 0.3},
|
||||||
|
{" fixed ", RelationFixes, 0.3},
|
||||||
|
{" resolves ", RelationFixes, 0.3},
|
||||||
|
{" addresses ", RelationFixes, 0.25},
|
||||||
|
{" replaces ", RelationSupersedes, 0.3},
|
||||||
|
{" supersedes ", RelationSupersedes, 0.35},
|
||||||
|
{" instead of ", RelationSupersedes, 0.25},
|
||||||
|
{" depends on ", RelationDependsOn, 0.3},
|
||||||
|
{" requires ", RelationDependsOn, 0.25},
|
||||||
|
{" builds on ", RelationDependsOn, 0.25},
|
||||||
|
{" based on ", RelationDependsOn, 0.2},
|
||||||
|
{" related to ", RelationRelatesTo, 0.2},
|
||||||
|
{" similar to ", RelationRelatesTo, 0.2},
|
||||||
|
{" evolved from ", RelationEvolvesFrom, 0.3},
|
||||||
|
{" improved from ", RelationEvolvesFrom, 0.25},
|
||||||
|
{" refined from ", RelationEvolvesFrom, 0.25},
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectNarrativeMentionRelation checks whether the newer observation's narrative contains relationship language; it does not check that the narrative refers to older.
|
||||||
|
func DetectNarrativeMentionRelation(newer, older *Observation) *RelationDetectionResult {
|
||||||
|
if !newer.Narrative.Valid || newer.Narrative.String == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
narrative := strings.ToLower(newer.Narrative.String)
|
||||||
|
|
||||||
|
// Check for patterns
|
||||||
|
for _, p := range NarrativeMentionPatterns {
|
||||||
|
if strings.Contains(narrative, p.Pattern) {
|
||||||
|
// Found a pattern - this is a potential relationship
|
||||||
|
confidence := 0.4 + p.ConfBoost
|
||||||
|
if confidence > 1.0 {
|
||||||
|
confidence = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RelationDetectionResult{
|
||||||
|
SourceID: newer.ID,
|
||||||
|
TargetID: older.ID,
|
||||||
|
RelationType: p.RelationType,
|
||||||
|
Confidence: confidence,
|
||||||
|
DetectionSource: DetectionSourceNarrativeMention,
|
||||||
|
Reason: "narrative contains '" + strings.TrimSpace(p.Pattern) + "' language",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
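Because every pattern carries a leading and trailing space, a keyword at the very start or end of a narrative does not match. The sketch below reproduces the check and contrasts it with a padded variant. `containsPhrasePadded` is a hypothetical alternative shown for comparison, not something the package does.

```go
package main

import (
	"fmt"
	"strings"
)

// containsPhrase applies the same check as the detector above:
// lowercase the narrative, then look for the space-delimited pattern.
func containsPhrase(narrative, pattern string) bool {
	return strings.Contains(strings.ToLower(narrative), pattern)
}

// containsPhrasePadded is a hypothetical variant that pads the narrative
// with spaces so a pattern can also match at either end of the text.
func containsPhrasePadded(narrative, pattern string) bool {
	return strings.Contains(" "+strings.ToLower(narrative)+" ", pattern)
}

func main() {
	mid := "This change fixes the issue with authentication"
	start := "Fixed security issue in auth module"

	fmt.Println(containsPhrase(mid, " fixes "))         // keyword mid-sentence
	fmt.Println(containsPhrase(start, " fixed "))       // keyword at start, no leading space
	fmt.Println(containsPhrasePadded(start, " fixed ")) // padding lets it match
}
```

Padding still keeps the word-boundary behaviour that the spaces provide, so a word such as "prefixes" does not trigger the " fixes " pattern in either version.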
|
||||||
|
|
||||||
|
// DetectRelationsWithExisting checks a new observation against existing ones and returns detected relations.
|
||||||
|
// This is the main entry point for relation detection.
|
||||||
|
func DetectRelationsWithExisting(newer *Observation, existing []*Observation, minConfidence float64) []*RelationDetectionResult {
|
||||||
|
var results []*RelationDetectionResult
|
||||||
|
seen := make(map[int64]bool)
|
||||||
|
|
||||||
|
for _, older := range existing {
|
||||||
|
// Skip self
|
||||||
|
if older.ID == newer.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if already superseded
|
||||||
|
if older.IsSuperseded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only compare within the same project, unless either observation is global
|
||||||
|
if newer.Project != older.Project && newer.Scope != ScopeGlobal && older.Scope != ScopeGlobal {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run all detection methods and keep highest confidence result per target
|
||||||
|
var bestResult *RelationDetectionResult
|
||||||
|
|
||||||
|
// 1. File overlap detection
|
||||||
|
if result := DetectFileOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||||
|
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Concept overlap detection
|
||||||
|
if result := DetectConceptOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||||
|
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Type progression detection
|
||||||
|
if result := DetectTypeProgressionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||||
|
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Temporal proximity detection
|
||||||
|
if result := DetectTemporalProximityRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||||
|
// Only use temporal proximity if no better detection found
|
||||||
|
if bestResult == nil {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Narrative mention detection (can upgrade relation type)
|
||||||
|
if result := DetectNarrativeMentionRelation(newer, older); result != nil && result.Confidence >= minConfidence {
|
||||||
|
if bestResult == nil || result.Confidence > bestResult.Confidence {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add best result if found and not already seen
|
||||||
|
if bestResult != nil && !seen[older.ID] {
|
||||||
|
results = append(results, bestResult)
|
||||||
|
seen[older.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
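The selection rule inside the loop above keeps one result per target: the highest-confidence detector wins, except that temporal proximity is only accepted when nothing else has matched yet. A compact sketch of that rule follows. The `candidate` type, the `pickBest` helper, and the confidence values in `main` are illustrative assumptions.

```go
package main

import "fmt"

// candidate is a hypothetical, trimmed-down detection result.
type candidate struct {
	source     string
	confidence float64
}

// pickBest walks candidates in detector order and keeps the strongest one
// at or above minConfidence. Temporal proximity is a fallback only: it is
// used when no earlier detector produced a result.
func pickBest(cands []candidate, minConfidence float64) (candidate, bool) {
	var best candidate
	found := false
	for _, c := range cands {
		if c.confidence < minConfidence {
			continue
		}
		if c.source == "temporal_proximity" {
			if !found {
				best, found = c, true
			}
			continue
		}
		if !found || c.confidence > best.confidence {
			best, found = c, true
		}
	}
	return best, found
}

func main() {
	cands := []candidate{
		{"file_overlap", 0.60},
		{"concept_overlap", 0.65},
		{"type_progression", 0.60},
		{"temporal_proximity", 0.70}, // higher, but only a fallback
	}
	best, ok := pickBest(cands, 0.4)
	fmt.Println(best.source, best.confidence, ok)
}
```

Note that signals are not combined: several agreeing detectors do not raise the reported confidence above the best single score.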
|
||||||
|
|
||||||
|
// truncateList truncates a list to maxLen items.
|
||||||
|
func truncateList(items []string, maxLen int) []string {
|
||||||
|
if len(items) <= maxLen {
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
result := append([]string(nil), items[:maxLen]...) // copy so the caller's slice is not overwritten
|
||||||
|
return append(result, "...")
|
||||||
|
}
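One Go detail worth keeping in mind for a helper like `truncateList`: appending to `items[:maxLen]` reuses the caller's backing array whenever capacity allows, so the element at index `maxLen` in the original slice gets overwritten. The sketch below contrasts the two behaviours. Both function names are hypothetical.

```go
package main

import "fmt"

// truncateAliased appends to a subslice of items. Since items[:maxLen]
// still has spare capacity, append writes "..." into the caller's array.
func truncateAliased(items []string, maxLen int) []string {
	if len(items) <= maxLen {
		return items
	}
	return append(items[:maxLen], "...")
}

// truncateCopied copies the kept prefix into a fresh slice first, so the
// caller's data is left untouched.
func truncateCopied(items []string, maxLen int) []string {
	if len(items) <= maxLen {
		return items
	}
	out := make([]string, 0, maxLen+1)
	out = append(out, items[:maxLen]...)
	return append(out, "...")
}

func main() {
	a := []string{"a.go", "b.go", "c.go", "d.go"}
	_ = truncateAliased(a, 3)
	fmt.Println(a[3]) // the original fourth element has been replaced by "..."

	b := []string{"a.go", "b.go", "c.go", "d.go"}
	_ = truncateCopied(b, 3)
	fmt.Println(b[3]) // still "d.go"
}
```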
|
||||||
|
|
||||||
|
// RelationWithDetails contains a relation with its observation details.
|
||||||
|
type RelationWithDetails struct {
|
||||||
|
Relation *ObservationRelation `json:"relation"`
|
||||||
|
SourceTitle string `json:"source_title"`
|
||||||
|
TargetTitle string `json:"target_title"`
|
||||||
|
SourceType ObservationType `json:"source_type"`
|
||||||
|
TargetType ObservationType `json:"target_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelationGraph represents a graph of related observations.
|
||||||
|
type RelationGraph struct {
|
||||||
|
CenterID int64 `json:"center_id"`
|
||||||
|
Relations []*RelationWithDetails `json:"relations"`
|
||||||
|
}
|
||||||
@@ -0,0 +1,473 @@
|
|||||||
|
// Package models contains domain models for claude-mnemonic.
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDetectFileOverlapRelation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newer *Observation
|
||||||
|
older *Observation
|
||||||
|
wantRelation bool
|
||||||
|
wantRelType RelationType
|
||||||
|
wantMinConfid float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no file overlap",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
FilesModified: []string{"file1.go", "file2.go"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
FilesModified: []string{"file3.go", "file4.go"},
|
||||||
|
},
|
||||||
|
wantRelation: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shared modified files",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: ObsTypeRefactor,
|
||||||
|
FilesModified: []string{"shared.go", "file2.go"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Type: ObsTypeRefactor,
|
||||||
|
FilesModified: []string{"shared.go", "file4.go"},
|
||||||
|
},
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationSupersedes,
|
||||||
|
wantMinConfid: 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bugfix on feature file",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: ObsTypeBugfix,
|
||||||
|
FilesModified: []string{"feature.go"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Type: ObsTypeFeature,
|
||||||
|
FilesModified: []string{"feature.go"},
|
||||||
|
},
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationFixes,
|
||||||
|
wantMinConfid: 0.6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "newer reads older modified",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Type: ObsTypeChange,
|
||||||
|
FilesRead: []string{"dep.go"},
|
||||||
|
FilesModified: []string{"caller.go"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Type: ObsTypeDecision,
|
||||||
|
FilesModified: []string{"dep.go"},
|
||||||
|
},
|
||||||
|
wantRelation: true,
|
||||||
|
wantMinConfid: 0.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := DetectFileOverlapRelation(tt.newer, tt.older)
|
||||||
|
|
||||||
|
if tt.wantRelation {
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected relation, got nil")
|
||||||
|
}
|
||||||
|
if tt.wantRelType != "" && result.RelationType != tt.wantRelType {
|
||||||
|
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||||
|
}
|
||||||
|
if result.Confidence < tt.wantMinConfid {
|
||||||
|
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
|
||||||
|
}
|
||||||
|
if result.DetectionSource != DetectionSourceFileOverlap {
|
||||||
|
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceFileOverlap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected no relation, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectConceptOverlapRelation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newer *Observation
|
||||||
|
older *Observation
|
||||||
|
wantRelation bool
|
||||||
|
wantMinConfid float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no concept overlap",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Concepts: []string{"auth", "api"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Concepts: []string{"database", "caching"},
|
||||||
|
},
|
||||||
|
wantRelation: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shared concepts",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Concepts: []string{"security", "auth"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Concepts: []string{"security", "validation"},
|
||||||
|
},
|
||||||
|
wantRelation: true,
|
||||||
|
wantMinConfid: 0.4, // security is a high-value concept
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple shared concepts",
|
||||||
|
newer: &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Concepts: []string{"auth", "api", "validation"},
|
||||||
|
},
|
||||||
|
older: &Observation{
|
||||||
|
ID: 2,
|
||||||
|
Concepts: []string{"auth", "api", "database"},
|
||||||
|
},
|
||||||
|
wantRelation: true,
|
||||||
|
wantMinConfid: 0.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := DetectConceptOverlapRelation(tt.newer, tt.older)
|
||||||
|
|
||||||
|
if tt.wantRelation {
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected relation, got nil")
|
||||||
|
}
|
||||||
|
if result.Confidence < tt.wantMinConfid {
|
||||||
|
t.Errorf("confidence = %v, want at least %v", result.Confidence, tt.wantMinConfid)
|
||||||
|
}
|
||||||
|
if result.DetectionSource != DetectionSourceConceptOverlap {
|
||||||
|
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceConceptOverlap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected no relation, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectTypeProgressionRelation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newerType ObservationType
|
||||||
|
olderType ObservationType
|
||||||
|
wantRelation bool
|
||||||
|
wantRelType RelationType
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bugfix fixes discovery",
|
||||||
|
newerType: ObsTypeBugfix,
|
||||||
|
olderType: ObsTypeDiscovery,
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationFixes,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bugfix fixes feature",
|
||||||
|
newerType: ObsTypeBugfix,
|
||||||
|
olderType: ObsTypeFeature,
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationFixes,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "feature depends on decision",
|
||||||
|
newerType: ObsTypeFeature,
|
||||||
|
olderType: ObsTypeDecision,
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationDependsOn,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "refactor evolves from discovery",
|
||||||
|
newerType: ObsTypeRefactor,
|
||||||
|
olderType: ObsTypeDiscovery,
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationEvolvesFrom,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no progression discovery to bugfix",
|
||||||
|
newerType: ObsTypeDiscovery,
|
||||||
|
olderType: ObsTypeBugfix,
|
||||||
|
wantRelation: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
newer := &Observation{ID: 1, Type: tt.newerType}
|
||||||
|
older := &Observation{ID: 2, Type: tt.olderType}
|
||||||
|
result := DetectTypeProgressionRelation(newer, older)
|
||||||
|
|
||||||
|
if tt.wantRelation {
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected relation, got nil")
|
||||||
|
}
|
||||||
|
if result.RelationType != tt.wantRelType {
|
||||||
|
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||||
|
}
|
||||||
|
if result.DetectionSource != DetectionSourceTypeProgression {
|
||||||
|
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTypeProgression)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected no relation, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectTemporalProximityRelation(t *testing.T) {
|
||||||
|
baseTime := int64(1700000000000) // some base epoch ms
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
newerSession string
|
||||||
|
olderSession string
|
||||||
|
newerTime int64
|
||||||
|
olderTime int64
|
||||||
|
wantRelation bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "same session close time",
|
||||||
|
newerSession: "session-1",
|
||||||
|
olderSession: "session-1",
|
||||||
|
newerTime: baseTime + 60000, // 1 minute later
|
||||||
|
olderTime: baseTime,
|
||||||
|
wantRelation: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same session far apart",
|
||||||
|
newerSession: "session-1",
|
||||||
|
olderSession: "session-1",
|
||||||
|
newerTime: baseTime + 600000, // 10 minutes later
|
||||||
|
olderTime: baseTime,
|
||||||
|
wantRelation: false, // > 5 minutes
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different sessions close time",
|
||||||
|
newerSession: "session-1",
|
||||||
|
olderSession: "session-2",
|
||||||
|
newerTime: baseTime + 30000,
|
||||||
|
olderTime: baseTime,
|
||||||
|
wantRelation: false, // different sessions
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 1,
|
||||||
|
SDKSessionID: tt.newerSession,
|
||||||
|
CreatedAtEpoch: tt.newerTime,
|
||||||
|
}
|
||||||
|
older := &Observation{
|
||||||
|
ID: 2,
|
||||||
|
SDKSessionID: tt.olderSession,
|
||||||
|
CreatedAtEpoch: tt.olderTime,
|
||||||
|
}
|
||||||
|
result := DetectTemporalProximityRelation(newer, older)
|
||||||
|
|
||||||
|
if tt.wantRelation {
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected relation, got nil")
|
||||||
|
}
|
||||||
|
if result.DetectionSource != DetectionSourceTemporalProximity {
|
||||||
|
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceTemporalProximity)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected no relation, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectNarrativeMentionRelation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
narrative string
|
||||||
|
wantRelation bool
|
||||||
|
wantRelType RelationType
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "fixes language",
|
||||||
|
narrative: "This change fixes the issue with authentication",
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationFixes,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "causes language",
|
||||||
|
narrative: "This decision caused unexpected side effects",
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationCauses,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "supersedes language",
|
||||||
|
narrative: "This approach supersedes the previous workaround",
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationSupersedes,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "depends on language",
|
||||||
|
narrative: "This feature depends on the authentication module",
|
||||||
|
wantRelation: true,
|
||||||
|
wantRelType: RelationDependsOn,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no relationship language",
|
||||||
|
narrative: "Added new feature for user management",
|
||||||
|
wantRelation: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 1,
|
||||||
|
Narrative: sql.NullString{String: tt.narrative, Valid: true},
|
||||||
|
}
|
||||||
|
older := &Observation{ID: 2}
|
||||||
|
result := DetectNarrativeMentionRelation(newer, older)
|
||||||
|
|
||||||
|
if tt.wantRelation {
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected relation, got nil")
|
||||||
|
}
|
||||||
|
if result.RelationType != tt.wantRelType {
|
||||||
|
t.Errorf("relation type = %v, want %v", result.RelationType, tt.wantRelType)
|
||||||
|
}
|
||||||
|
if result.DetectionSource != DetectionSourceNarrativeMention {
|
||||||
|
t.Errorf("source = %v, want %v", result.DetectionSource, DetectionSourceNarrativeMention)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected no relation, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectRelationsWithExisting(t *testing.T) {
|
||||||
|
newer := &Observation{
|
||||||
|
ID: 1,
|
||||||
|
SDKSessionID: "session-1",
|
||||||
|
Project: "test-project",
|
||||||
|
Type: ObsTypeBugfix,
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
Concepts: []string{"security", "auth"},
|
||||||
|
Narrative: sql.NullString{String: "Fixed security issue in auth module", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := []*Observation{
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
SDKSessionID: "session-1",
|
||||||
|
Project: "test-project",
|
||||||
|
Type: ObsTypeDiscovery,
|
||||||
|
FilesModified: []string{"auth.go"},
|
||||||
|
Concepts: []string{"security"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 3,
|
||||||
|
SDKSessionID: "session-2",
|
||||||
|
Project: "test-project",
|
||||||
|
Type: ObsTypeFeature,
|
||||||
|
FilesModified: []string{"other.go"},
|
||||||
|
Concepts: []string{"api"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 4,
|
||||||
|
SDKSessionID: "session-1",
|
||||||
|
Project: "other-project", // different project
|
||||||
|
Type: ObsTypeDiscovery,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
results := DetectRelationsWithExisting(newer, existing, 0.4)
|
||||||
|
|
||||||
|
// Should find relation with observation 2 (file overlap + concept overlap + type progression)
|
||||||
|
// Observation 3 shares no files or concepts, but the bugfix -> feature type progression (0.5) still clears minConfidence, so its absence is not asserted
|
||||||
|
// Should not find relation with observation 4 (different project)
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
t.Fatal("expected at least one relation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we found relation with observation 2
|
||||||
|
foundObs2 := false
|
||||||
|
for _, r := range results {
|
||||||
|
if r.TargetID == 2 {
|
||||||
|
foundObs2 = true
|
||||||
|
// Only the single best detector result is kept per target; it should be at least 0.5 here
|
||||||
|
if r.Confidence < 0.5 {
|
||||||
|
t.Errorf("expected higher confidence for obs 2, got %v", r.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Should not find relation with obs 4
|
||||||
|
if r.TargetID == 4 {
|
||||||
|
t.Error("should not find relation with different project")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundObs2 {
|
||||||
|
t.Error("expected to find relation with observation 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewObservationRelation(t *testing.T) {
|
||||||
|
rel := NewObservationRelation(1, 2, RelationFixes, 0.8, DetectionSourceFileOverlap, "test reason")
|
||||||
|
|
||||||
|
if rel.SourceID != 1 {
|
||||||
|
t.Errorf("SourceID = %v, want 1", rel.SourceID)
|
||||||
|
}
|
||||||
|
if rel.TargetID != 2 {
|
||||||
|
t.Errorf("TargetID = %v, want 2", rel.TargetID)
|
||||||
|
}
|
||||||
|
if rel.RelationType != RelationFixes {
|
||||||
|
t.Errorf("RelationType = %v, want %v", rel.RelationType, RelationFixes)
|
||||||
|
}
|
||||||
|
if rel.Confidence != 0.8 {
|
||||||
|
t.Errorf("Confidence = %v, want 0.8", rel.Confidence)
|
||||||
|
}
|
||||||
|
if rel.DetectionSource != DetectionSourceFileOverlap {
|
||||||
|
t.Errorf("DetectionSource = %v, want %v", rel.DetectionSource, DetectionSourceFileOverlap)
|
||||||
|
}
|
||||||
|
if rel.Reason != "test reason" {
|
||||||
|
t.Errorf("Reason = %v, want 'test reason'", rel.Reason)
|
||||||
|
}
|
||||||
|
if rel.CreatedAt == "" {
|
||||||
|
t.Error("CreatedAt should be set")
|
||||||
|
}
|
||||||
|
if rel.CreatedAtEpoch == 0 {
|
||||||
|
t.Error("CreatedAtEpoch should be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
// Package models contains domain models for claude-mnemonic.
|
||||||
|
package models
|
||||||
|
|
||||||
|
// ConceptWeight represents a configurable weight for a concept.
|
||||||
|
type ConceptWeight struct {
|
||||||
|
Concept string `db:"concept" json:"concept"`
|
||||||
|
Weight float64 `db:"weight" json:"weight"`
|
||||||
|
UpdatedAt string `db:"updated_at" json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserFeedbackType represents the type of user feedback.
|
||||||
|
type UserFeedbackType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FeedbackNegative represents a thumbs down.
|
||||||
|
FeedbackNegative UserFeedbackType = -1
|
||||||
|
// FeedbackNeutral represents no feedback.
|
||||||
|
FeedbackNeutral UserFeedbackType = 0
|
||||||
|
// FeedbackPositive represents a thumbs up.
|
||||||
|
FeedbackPositive UserFeedbackType = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultConceptWeights contains the default weights for concepts.
|
||||||
|
// Higher weights indicate more important concepts.
|
||||||
|
var DefaultConceptWeights = map[string]float64{
|
||||||
|
// Critical concepts (0.30)
|
||||||
|
"security": 0.30, // Security issues are most critical
|
||||||
|
|
||||||
|
// High importance (0.20-0.25)
|
||||||
|
"gotcha": 0.25, // Gotchas prevent future mistakes
|
||||||
|
"best-practice": 0.20, // Best practices guide development
|
||||||
|
"anti-pattern": 0.20, // Anti-patterns prevent bad code
|
||||||
|
|
||||||
|
// Medium importance (0.10-0.15)
|
||||||
|
"architecture": 0.15, // Architectural decisions have lasting impact
|
||||||
|
"performance": 0.15, // Performance patterns matter for scale
|
||||||
|
"error-handling": 0.15, // Error handling prevents failures
|
||||||
|
"pattern": 0.10, // General patterns are useful
|
||||||
|
"testing": 0.10, // Testing knowledge helps quality
|
||||||
|
"debugging": 0.10, // Debugging tips save time
|
||||||
|
"problem-solution": 0.10, // Problem-solution pairs are actionable
|
||||||
|
"trade-off": 0.10, // Trade-offs inform decisions
|
||||||
|
|
||||||
|
// Lower importance (0.05)
|
||||||
|
"workflow": 0.05, // Workflow optimizations are nice-to-have
|
||||||
|
"tooling": 0.05, // Tooling preferences are subjective
|
||||||
|
"how-it-works": 0.05, // Understanding is foundational but less urgent
|
||||||
|
"why-it-exists": 0.05, // Context is helpful but less actionable
|
||||||
|
"what-changed": 0.05, // Changes are informational
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeBaseScores contains the base importance multipliers for each observation type.
|
||||||
|
// These are multiplied by the core score to weight different observation types.
|
||||||
|
var TypeBaseScores = map[ObservationType]float64{
|
||||||
|
ObsTypeBugfix: 1.3, // Bugfixes are valuable - prevent regressions
|
||||||
|
ObsTypeFeature: 1.2, // New features expand capabilities
|
||||||
|
ObsTypeDiscovery: 1.1, // Discoveries inform future work
|
||||||
|
ObsTypeDecision: 1.1, // Architectural decisions guide development
|
||||||
|
ObsTypeRefactor: 1.0, // Refactoring is neutral
|
||||||
|
ObsTypeChange: 0.9, // Minor changes are slightly less important
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScoringConfig contains all scoring weights and parameters.
|
||||||
|
type ScoringConfig struct {
|
||||||
|
// RecencyHalfLifeDays is the number of days after which the recency component of the score halves.
|
||||||
|
// With 7 days, a 7-day-old observation has 50% of a new observation's recency score.
|
||||||
|
RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
|
||||||
|
|
||||||
|
// FeedbackWeight scales the user feedback contribution to final score.
|
||||||
|
// With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30.
|
||||||
|
FeedbackWeight float64 `json:"feedback_weight"`
|
||||||
|
|
||||||
|
// ConceptWeight scales the concept boost contribution.
|
||||||
|
// The sum of matching concept weights is multiplied by this.
|
||||||
|
ConceptWeight float64 `json:"concept_weight"`
|
||||||
|
|
||||||
|
// RetrievalWeight scales the retrieval boost contribution.
|
||||||
|
// Popular observations get a logarithmic bonus.
|
||||||
|
RetrievalWeight float64 `json:"retrieval_weight"`
|
||||||
|
|
||||||
|
// ConceptWeights maps concept names to their importance weights.
|
||||||
|
ConceptWeights map[string]float64 `json:"concept_weights"`
|
||||||
|
|
||||||
|
// MinScore is the minimum allowed importance score.
|
||||||
|
// Prevents observations from completely disappearing.
|
||||||
|
MinScore float64 `json:"min_score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultScoringConfig returns the default scoring configuration.
|
||||||
|
func DefaultScoringConfig() *ScoringConfig {
|
||||||
|
conceptWeights := make(map[string]float64, len(DefaultConceptWeights))
|
||||||
|
for k, v := range DefaultConceptWeights {
|
||||||
|
conceptWeights[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ScoringConfig{
|
||||||
|
RecencyHalfLifeDays: 7.0, // Score halves every 7 days
|
||||||
|
FeedbackWeight: 0.30, // Feedback has moderate impact
|
||||||
|
ConceptWeight: 0.20, // Concept weights have smaller impact
|
||||||
|
RetrievalWeight: 0.15, // Retrieval has smallest impact
|
||||||
|
ConceptWeights: conceptWeights,
|
||||||
|
MinScore: 0.01, // Never completely disappear
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeBaseScore returns the base weight for an observation type.
|
||||||
|
func TypeBaseScore(t ObservationType) float64 {
|
||||||
|
if score, ok := TypeBaseScores[t]; ok {
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
return 1.0 // Default for unknown types
|
||||||
|
}
|
||||||
Executable
+121
@@ -0,0 +1,121 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Download BGE-small-en-v1.5 model for embedding
|
||||||
|
# Usage: ./download-bge-model.sh [--force]
|
||||||
|
# Use --force to re-download even if files exist
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
MODEL_NAME="bge-small-en-v1.5"
|
||||||
|
MODEL_REPO="BAAI/bge-small-en-v1.5"
|
||||||
|
ASSETS_DIR="internal/embedding/assets"
|
||||||
|
VERSION_FILE="${ASSETS_DIR}/.model_version"
|
||||||
|
FORCE_DOWNLOAD=false
|
||||||
|
|
||||||
|
# Check for --force flag
|
||||||
|
for arg in "$@"; do
|
||||||
|
if [ "$arg" = "--force" ]; then
|
||||||
|
FORCE_DOWNLOAD=true
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Temporary directory for downloads
|
||||||
|
TEMP_DIR=$(mktemp -d)
|
||||||
|
trap 'rm -rf "${TEMP_DIR}"' EXIT
|
||||||
|
|
||||||
|
# Check if model already exists
|
||||||
|
model_exists() {
|
||||||
|
[ -f "${ASSETS_DIR}/model.onnx" ] && [ -f "${ASSETS_DIR}/tokenizer.json" ]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get installed version
|
||||||
|
get_installed_version() {
|
||||||
|
if [ -f "$VERSION_FILE" ]; then
|
||||||
|
cat "$VERSION_FILE"
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Write version file
|
||||||
|
write_version_file() {
|
||||||
|
echo "${MODEL_NAME}" > "$VERSION_FILE"
|
||||||
|
}
|
||||||
|
|
||||||
|
download_model() {
|
||||||
|
echo "Downloading ${MODEL_NAME} from Hugging Face..."
|
||||||
|
|
||||||
|
# Create assets directory
|
||||||
|
mkdir -p "${ASSETS_DIR}"
|
||||||
|
|
||||||
|
# Download ONNX model
|
||||||
|
# BGE models have ONNX exports available in the repo
|
||||||
|
echo "Downloading ONNX model..."
|
||||||
|
curl -fsSL \
|
||||||
|
"https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \
|
||||||
|
-o "${TEMP_DIR}/model.onnx"
|
||||||
|
|
||||||
|
# Download tokenizer.json
|
||||||
|
echo "Downloading tokenizer..."
|
||||||
|
curl -fsSL \
|
||||||
|
"https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \
|
||||||
|
-o "${TEMP_DIR}/tokenizer.json"
|
||||||
|
|
||||||
|
# Verify files exist and have content
|
||||||
|
if [ ! -s "${TEMP_DIR}/model.onnx" ]; then
|
||||||
|
echo "Error: Failed to download model.onnx or file is empty"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -s "${TEMP_DIR}/tokenizer.json" ]; then
|
||||||
|
echo "Error: Failed to download tokenizer.json or file is empty"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Move to assets directory (backup old files first)
|
||||||
|
if [ -f "${ASSETS_DIR}/model.onnx" ]; then
|
||||||
|
mv "${ASSETS_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx.bak"
|
||||||
|
fi
|
||||||
|
if [ -f "${ASSETS_DIR}/tokenizer.json" ]; then
|
||||||
|
mv "${ASSETS_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json.bak"
|
||||||
|
fi
|
||||||
|
|
||||||
|
mv "${TEMP_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx"
|
||||||
|
mv "${TEMP_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json"
|
||||||
|
|
||||||
|
# Remove backups on success
|
||||||
|
rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak"
|
||||||
|
|
||||||
|
# Write version file
|
||||||
|
write_version_file
|
||||||
|
|
||||||
|
echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)"
|
||||||
|
echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "BGE Model Downloader - ${MODEL_NAME}"
|
||||||
|
echo "=================================="
|
||||||
|
|
||||||
|
need_download=false
|
||||||
|
reason=""
|
||||||
|
|
||||||
|
if [ "$FORCE_DOWNLOAD" = true ]; then
|
||||||
|
need_download=true
|
||||||
|
reason="forced"
|
||||||
|
elif ! model_exists; then
|
||||||
|
need_download=true
|
||||||
|
reason="not found"
|
||||||
|
elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then
|
||||||
|
need_download=true
|
||||||
|
reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$need_download" = true ]; then
|
||||||
|
if [ -n "$reason" ] && [ "$reason" != "not found" ]; then
|
||||||
|
echo "Re-downloading: ${reason}"
|
||||||
|
fi
|
||||||
|
download_model
|
||||||
|
echo "Done! ${MODEL_NAME} installed successfully."
|
||||||
|
else
|
||||||
|
echo "Model ${MODEL_NAME} already exists, skipping download."
|
||||||
|
echo "Use --force to re-download."
|
||||||
|
fi
|
||||||
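Review note: the download decision at the end of the script checks three conditions in priority order (forced, files missing, version mismatch). The same precedence can be expressed as a pure function, which makes it easy to unit test. This is a sketch only; `needDownload` and its parameters are hypothetical names and do not exist in the repository.

```go
package main

import "fmt"

// needDownload mirrors the script's if/elif chain: --force wins, then a
// missing model, then a version mismatch. It returns whether to download
// and the reason string the script would report.
func needDownload(force, modelExists bool, installed, required string) (bool, string) {
	switch {
	case force:
		return true, "forced"
	case !modelExists:
		return true, "not found"
	case installed != required:
		return true, fmt.Sprintf("version mismatch (installed: %s, required: %s)", installed, required)
	}
	return false, ""
}

func main() {
	const required = "bge-small-en-v1.5"
	ok, reason := needDownload(false, true, "", required)
	fmt.Println(ok, reason)
}
```

An empty installed version (no `.model_version` file) counts as a mismatch, which matches the script's behavior of re-downloading when the model files exist but the version file does not.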
Generated
+108
-2
@@ -1,13 +1,15 @@
|
|||||||
{
|
{
|
||||||
"name": "claude-mnemonic-dashboard",
|
"name": "claude-mnemonic-dashboard",
|
||||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
"version": "0ddacaa-dirty",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "claude-mnemonic-dashboard",
|
"name": "claude-mnemonic-dashboard",
|
||||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
"version": "0ddacaa-dirty",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"vis-data": "^7.1.9",
|
||||||
|
"vis-network": "^9.1.9",
|
||||||
"vue": "^3.5.13"
|
"vue": "^3.5.13"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
@@ -81,6 +83,19 @@
|
|||||||
"node": ">=6.9.0"
|
"node": ">=6.9.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@egjs/hammerjs": {
|
||||||
|
"version": "2.0.17",
|
||||||
|
"resolved": "https://registry.npmjs.org/@egjs/hammerjs/-/hammerjs-2.0.17.tgz",
|
||||||
|
"integrity": "sha512-XQsZgjm2EcVUiZQf11UBJQfmZeEmOW8DpI1gsFeln6w0ae0ii4dMQEQ0kjl6DspdWX1aGY1/loyXnP0JS06e/A==",
|
||||||
|
"license": "MIT",
|
||||||
|
"peer": true,
|
||||||
|
"dependencies": {
|
||||||
|
"@types/hammerjs": "^2.0.36"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=0.8.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@esbuild/aix-ppc64": {
|
"node_modules/@esbuild/aix-ppc64": {
|
||||||
"version": "0.25.12",
|
"version": "0.25.12",
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
|
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
|
||||||
@@ -924,6 +939,13 @@
|
|||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/@types/hammerjs": {
|
||||||
|
"version": "2.0.46",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/hammerjs/-/hammerjs-2.0.46.tgz",
|
||||||
|
"integrity": "sha512-ynRvcq6wvqexJ9brDMS4BnBLzmr0e14d6ZJTEShTBWKymQiHwlAyGu0ZPEFI2Fh1U53F7tN9ufClWM5KvqkKOw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"peer": true
|
||||||
|
},
|
||||||
"node_modules/@types/node": {
|
"node_modules/@types/node": {
|
||||||
"version": "22.19.2",
|
"version": "22.19.2",
|
||||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.2.tgz",
|
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.2.tgz",
|
||||||
@@ -1352,6 +1374,19 @@
|
|||||||
"node": ">= 6"
|
"node": ">= 6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/component-emitter": {
|
||||||
|
"version": "2.0.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/component-emitter/-/component-emitter-2.0.0.tgz",
|
||||||
|
"integrity": "sha512-4m5s3Me2xxlVKG9PkZpQqHQR7bgpnN7joDMJ4yvVkVXngjoITG76IaZmzmywSeRTeTpc6N6r3H3+KyUurV8OYw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"peer": true,
|
||||||
|
"engines": {
|
||||||
|
"node": ">=18"
|
||||||
|
},
|
||||||
|
"funding": {
|
||||||
|
"url": "https://github.com/sponsors/sindresorhus"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/cssesc": {
|
"node_modules/cssesc": {
|
||||||
"version": "3.0.0",
|
"version": "3.0.0",
|
||||||
"resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz",
|
||||||
@@ -1669,6 +1704,13 @@
|
|||||||
"jiti": "bin/jiti.js"
|
"jiti": "bin/jiti.js"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/keycharm": {
|
||||||
|
"version": "0.4.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/keycharm/-/keycharm-0.4.0.tgz",
|
||||||
|
"integrity": "sha512-TyQTtsabOVv3MeOpR92sIKk/br9wxS+zGj4BG7CR8YbK4jM3tyIBaF0zhzeBUMx36/Q/iQLOKKOT+3jOQtemRQ==",
|
||||||
|
"license": "(Apache-2.0 OR MIT)",
|
||||||
|
"peer": true
|
||||||
|
},
|
||||||
"node_modules/lilconfig": {
|
"node_modules/lilconfig": {
|
||||||
"version": "3.1.3",
|
"version": "3.1.3",
|
||||||
"resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz",
|
"resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz",
|
||||||
@@ -2412,6 +2454,70 @@
|
|||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/uuid": {
|
||||||
|
"version": "11.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz",
|
||||||
|
"integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==",
|
||||||
|
"funding": [
|
||||||
|
"https://github.com/sponsors/broofa",
|
||||||
|
"https://github.com/sponsors/ctavan"
|
||||||
|
],
|
||||||
|
"license": "MIT",
|
||||||
|
"peer": true,
|
||||||
|
"bin": {
|
||||||
|
"uuid": "dist/esm/bin/uuid"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/vis-data": {
|
||||||
|
"version": "7.1.10",
|
||||||
|
"resolved": "https://registry.npmjs.org/vis-data/-/vis-data-7.1.10.tgz",
|
||||||
|
"integrity": "sha512-23juM9tdCaHTX5vyIQ7XBzsfZU0Hny+gSTwniLrfFcmw9DOm7pi3+h9iEBsoZMp5rX6KNqWwc1MF0fkAmWVuoQ==",
|
||||||
|
"license": "(Apache-2.0 OR MIT)",
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/visjs"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
|
||||||
|
"vis-util": "^5.0.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/vis-network": {
|
||||||
|
"version": "9.1.13",
|
||||||
|
"resolved": "https://registry.npmjs.org/vis-network/-/vis-network-9.1.13.tgz",
|
||||||
|
"integrity": "sha512-HLeHd5KZS92qzO1kC59qMh1/FWAZxMUEwUWBwDMoj6RKj/Ajkrgy/heEYo0Zc8SZNQ2J+u6omvK2+a28GX1QuQ==",
|
||||||
|
"license": "(Apache-2.0 OR MIT)",
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/visjs"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@egjs/hammerjs": "^2.0.0",
|
||||||
|
"component-emitter": "^1.3.0 || ^2.0.0",
|
||||||
|
"keycharm": "^0.2.0 || ^0.3.0 || ^0.4.0",
|
||||||
|
"uuid": "^3.4.0 || ^7.0.0 || ^8.0.0 || ^9.0.0 || ^10.0.0 || ^11.0.0",
|
||||||
|
"vis-data": "^6.3.0 || ^7.0.0",
|
||||||
|
"vis-util": "^5.0.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/vis-util": {
|
||||||
|
"version": "5.0.7",
|
||||||
|
"resolved": "https://registry.npmjs.org/vis-util/-/vis-util-5.0.7.tgz",
|
||||||
|
"integrity": "sha512-E3L03G3+trvc/X4LXvBfih3YIHcKS2WrP0XTdZefr6W6Qi/2nNCqZfe4JFfJU6DcQLm6Gxqj2Pfl+02859oL5A==",
|
||||||
|
"license": "(Apache-2.0 OR MIT)",
|
||||||
|
"peer": true,
|
||||||
|
"engines": {
|
||||||
|
"node": ">=8"
|
||||||
|
},
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/visjs"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@egjs/hammerjs": "^2.0.0",
|
||||||
|
"component-emitter": "^1.3.0 || ^2.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/vite": {
|
"node_modules/vite": {
|
||||||
"version": "6.4.1",
|
"version": "6.4.1",
|
||||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
||||||
|
|||||||
+3
-1
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "claude-mnemonic-dashboard",
|
"name": "claude-mnemonic-dashboard",
|
||||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
"version": "0ddacaa-dirty",
|
||||||
"private": true,
|
"private": true,
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
@@ -10,6 +10,8 @@
|
|||||||
"type-check": "vue-tsc --noEmit"
|
"type-check": "vue-tsc --noEmit"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"vis-data": "^7.1.9",
|
||||||
|
"vis-network": "^9.1.9",
|
||||||
"vue": "^3.5.13"
|
"vue": "^3.5.13"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
+1
-1
@@ -30,7 +30,7 @@ const {
|
|||||||
// Pass currentProject ref to useStats for project-specific retrieval stats
|
// Pass currentProject ref to useStats for project-specific retrieval stats
|
||||||
const { stats } = useStats(currentProject)
|
const { stats } = useStats(currentProject)
|
||||||
|
|
||||||
// Note: Timeline refresh is handled by useTimeline's SSE watcher
|
// Note: Feedback is handled directly in ObservationCard component
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
|
|||||||
@@ -1,17 +1,66 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import type { ObservationFeedItem } from '@/types'
|
import type { ObservationFeedItem, RelationWithDetails } from '@/types'
|
||||||
import { TYPE_CONFIG, CONCEPT_CONFIG } from '@/types/observation'
|
import { TYPE_CONFIG, CONCEPT_CONFIG } from '@/types/observation'
|
||||||
|
import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
|
||||||
import { formatRelativeTime } from '@/utils/formatters'
|
import { formatRelativeTime } from '@/utils/formatters'
|
||||||
|
import { fetchObservationRelations } from '@/utils/api'
|
||||||
import Card from './Card.vue'
|
import Card from './Card.vue'
|
||||||
import IconBox from './IconBox.vue'
|
import IconBox from './IconBox.vue'
|
||||||
import Badge from './Badge.vue'
|
import Badge from './Badge.vue'
|
||||||
import { computed } from 'vue'
|
import RelationGraph from './RelationGraph.vue'
|
||||||
|
import { computed, ref, onMounted } from 'vue'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
observation: ObservationFeedItem
|
observation: ObservationFeedItem
|
||||||
highlight?: boolean
|
highlight?: boolean
|
||||||
|
showFeedback?: boolean
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
navigateToObservation: [id: number]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
// Local feedback and score state (optimistic updates)
|
||||||
|
const localFeedback = ref<number | null>(null)
|
||||||
|
const localScore = ref<number | null>(null)
|
||||||
|
const isSubmitting = ref(false)
|
||||||
|
|
||||||
|
const currentFeedback = computed(() =>
|
||||||
|
localFeedback.value !== null ? localFeedback.value : (props.observation.user_feedback || 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
const currentScore = computed(() =>
|
||||||
|
localScore.value !== null ? localScore.value : (props.observation.importance_score || 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
const submitFeedback = async (value: number) => {
|
||||||
|
if (isSubmitting.value) return
|
||||||
|
|
||||||
|
// Toggle off if clicking same button
|
||||||
|
const newValue = currentFeedback.value === value ? 0 : value
|
||||||
|
|
||||||
|
localFeedback.value = newValue
|
||||||
|
isSubmitting.value = true
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/observations/${props.observation.id}/feedback`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ feedback: newValue })
|
||||||
|
})
|
||||||
|
if (response.ok) {
|
||||||
|
const data = await response.json()
|
||||||
|
if (data.score !== undefined) {
|
||||||
|
localScore.value = data.score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error submitting feedback:', error)
|
||||||
|
} finally {
|
||||||
|
isSubmitting.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const config = computed(() => TYPE_CONFIG[props.observation.type] || TYPE_CONFIG.change)
|
const config = computed(() => TYPE_CONFIG[props.observation.type] || TYPE_CONFIG.change)
|
||||||
|
|
||||||
const concepts = computed(() => {
|
const concepts = computed(() => {
|
||||||
@@ -40,6 +89,60 @@ const filesModified = computed(() => {
|
|||||||
|
|
||||||
const hasFiles = computed(() => filesRead.value.length > 0 || filesModified.value.length > 0)
|
const hasFiles = computed(() => filesRead.value.length > 0 || filesModified.value.length > 0)
|
||||||
|
|
||||||
|
// Relations state
|
||||||
|
const relations = ref<RelationWithDetails[]>([])
|
||||||
|
const relationsLoading = ref(false)
|
||||||
|
const relationsExpanded = ref(false)
|
||||||
|
const showGraph = ref(false)
|
||||||
|
|
||||||
|
const hasRelations = computed(() => relations.value.length > 0)
|
||||||
|
const relationCount = computed(() => relations.value.length)
|
||||||
|
|
||||||
|
// Load relations on mount
|
||||||
|
const loadRelations = async () => {
|
||||||
|
relationsLoading.value = true
|
||||||
|
try {
|
||||||
|
relations.value = await fetchObservationRelations(props.observation.id)
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to load relations:', err)
|
||||||
|
} finally {
|
||||||
|
relationsLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
loadRelations()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Toggle relations expansion
|
||||||
|
const toggleRelations = () => {
|
||||||
|
relationsExpanded.value = !relationsExpanded.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open graph modal
|
||||||
|
const openGraph = (e: Event) => {
|
||||||
|
e.stopPropagation()
|
||||||
|
showGraph.value = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle navigation from graph
|
||||||
|
const handleNavigateTo = (id: number) => {
|
||||||
|
showGraph.value = false
|
||||||
|
emit('navigateToObservation', id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get relation display info (whether we're source or target)
|
||||||
|
const getRelationDisplay = (rel: RelationWithDetails) => {
|
||||||
|
const isSource = rel.relation.source_id === props.observation.id
|
||||||
|
return {
|
||||||
|
type: rel.relation.relation_type,
|
||||||
|
otherTitle: isSource ? rel.target_title : rel.source_title,
|
||||||
|
otherId: isSource ? rel.relation.target_id : rel.relation.source_id,
|
||||||
|
direction: isSource ? 'outgoing' : 'incoming',
|
||||||
|
confidence: rel.relation.confidence
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Split path into project root and relative path for styling
|
// Split path into project root and relative path for styling
|
||||||
// e.g., /Users/foo/project/src/file.go → { root: 'project', path: 'src/file.go' }
|
// e.g., /Users/foo/project/src/file.go → { root: 'project', path: 'src/file.go' }
|
||||||
const splitPath = (path: string, components = 3) => {
|
const splitPath = (path: string, components = 3) => {
|
||||||
@@ -140,7 +243,145 @@ const splitPath = (path: string, components = 3) => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Relations -->
|
||||||
|
<div v-if="hasRelations || relationsLoading" class="mt-3 pt-3 border-t border-slate-700/50">
|
||||||
|
<!-- Header with count and graph button -->
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<button
|
||||||
|
@click="toggleRelations"
|
||||||
|
class="flex items-center gap-2 text-xs text-slate-400 hover:text-slate-300 transition-colors"
|
||||||
|
:disabled="relationsLoading"
|
||||||
|
>
|
||||||
|
<i class="fas fa-diagram-project text-cyan-500/70" />
|
||||||
|
<span v-if="relationsLoading" class="text-slate-500">
|
||||||
|
<i class="fas fa-circle-notch fa-spin mr-1" />
|
||||||
|
Loading relations...
|
||||||
|
</span>
|
||||||
|
<span v-else>
|
||||||
|
{{ relationCount }} related observation{{ relationCount !== 1 ? 's' : '' }}
|
||||||
|
</span>
|
||||||
|
<i
|
||||||
|
v-if="!relationsLoading && hasRelations"
|
||||||
|
class="fas text-[10px] transition-transform"
|
||||||
|
:class="relationsExpanded ? 'fa-chevron-up' : 'fa-chevron-down'"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<!-- View Graph button -->
|
||||||
|
<button
|
||||||
|
v-if="hasRelations"
|
||||||
|
@click="openGraph"
|
||||||
|
class="flex items-center gap-1.5 px-2 py-1 text-xs text-cyan-400 hover:text-cyan-300 bg-cyan-500/10 hover:bg-cyan-500/20 rounded transition-colors"
|
||||||
|
title="View knowledge graph"
|
||||||
|
>
|
||||||
|
<i class="fas fa-project-diagram" />
|
||||||
|
<span>View Graph</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Expanded relations list -->
|
||||||
|
<div
|
||||||
|
v-if="relationsExpanded && hasRelations"
|
||||||
|
class="mt-2 space-y-1.5"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
v-for="rel in relations"
|
||||||
|
:key="rel.relation.id"
|
||||||
|
class="flex items-center gap-2 text-xs p-1.5 rounded bg-slate-800/30 hover:bg-slate-800/50 transition-colors group"
|
||||||
|
>
|
||||||
|
<!-- Relation type icon -->
|
||||||
|
<i
|
||||||
|
class="fas w-4 text-center"
|
||||||
|
:class="[
|
||||||
|
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.icon || 'fa-link',
|
||||||
|
RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.colorClass || 'text-slate-400'
|
||||||
|
]"
|
||||||
|
:title="RELATION_TYPE_CONFIG[getRelationDisplay(rel).type]?.label"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- Direction arrow -->
|
||||||
|
<i
|
||||||
|
class="fas text-[10px] text-slate-600"
|
||||||
|
:class="getRelationDisplay(rel).direction === 'outgoing' ? 'fa-arrow-right' : 'fa-arrow-left'"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- Related observation title -->
|
||||||
|
<span
|
||||||
|
class="flex-1 truncate text-slate-300 cursor-pointer hover:text-amber-300 transition-colors"
|
||||||
|
:title="getRelationDisplay(rel).otherTitle"
|
||||||
|
@click="emit('navigateToObservation', getRelationDisplay(rel).otherId)"
|
||||||
|
>
|
||||||
|
{{ getRelationDisplay(rel).otherTitle || 'Untitled' }}
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<!-- Confidence -->
|
||||||
|
<span
|
||||||
|
class="text-[10px] text-slate-500 font-mono"
|
||||||
|
:title="`${Math.round(getRelationDisplay(rel).confidence * 100)}% confidence`"
|
||||||
|
>
|
||||||
|
{{ Math.round(getRelationDisplay(rel).confidence * 100) }}%
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<!-- Detection source icon -->
|
||||||
|
<i
|
||||||
|
class="fas text-[10px] text-slate-600"
|
||||||
|
:class="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.icon || 'fa-question'"
|
||||||
|
:title="DETECTION_SOURCE_CONFIG[rel.relation.detection_source]?.label"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Feedback buttons (right side) -->
|
||||||
|
<div v-if="showFeedback" class="flex flex-col items-center gap-1 ml-2 flex-shrink-0">
|
||||||
|
<button
|
||||||
|
@click="submitFeedback(1)"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
:class="[
|
||||||
|
'p-1.5 rounded-lg transition-all duration-200',
|
||||||
|
currentFeedback === 1
|
||||||
|
? 'bg-green-500/30 text-green-300 shadow-green-500/20 shadow-sm'
|
||||||
|
: 'text-slate-500 hover:text-green-400 hover:bg-green-500/10'
|
||||||
|
]"
|
||||||
|
title="Helpful"
|
||||||
|
>
|
||||||
|
<i class="fas fa-thumbs-up text-sm" />
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<span
|
||||||
|
class="text-[10px] font-mono px-1.5 py-0.5 rounded bg-slate-800/50 text-slate-400 flex items-center gap-1 transition-all duration-300"
|
||||||
|
:class="{ 'text-green-400': localScore !== null && localScore > (observation.importance_score || 1), 'text-red-400': localScore !== null && localScore < (observation.importance_score || 1) }"
|
||||||
|
:title="`Importance Score: ${currentScore.toFixed(3)}\nRetrieval Count: ${observation.retrieval_count || 0}`"
|
||||||
|
>
|
||||||
|
<i class="fas fa-scale-balanced text-amber-500/60" />
|
||||||
|
{{ currentScore.toFixed(2) }}
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<button
|
||||||
|
@click="submitFeedback(-1)"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
:class="[
|
||||||
|
'p-1.5 rounded-lg transition-all duration-200',
|
||||||
|
currentFeedback === -1
|
||||||
|
? 'bg-red-500/30 text-red-300 shadow-red-500/20 shadow-sm'
|
||||||
|
: 'text-slate-500 hover:text-red-400 hover:bg-red-500/10'
|
||||||
|
]"
|
||||||
|
title="Not helpful"
|
||||||
|
>
|
||||||
|
<i class="fas fa-thumbs-down text-sm" />
|
||||||
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Relation Graph Modal -->
|
||||||
|
<RelationGraph
|
||||||
|
:observation-id="observation.id"
|
||||||
|
:observation-title="observation.title || 'Untitled'"
|
||||||
|
:show="showGraph"
|
||||||
|
@close="showGraph = false"
|
||||||
|
@navigate-to="handleNavigateTo"
|
||||||
|
/>
|
||||||
</Card>
|
</Card>
|
||||||
</template>
|
</template>
|
||||||
|
|||||||
@@ -0,0 +1,409 @@
+<script setup lang="ts">
+import { ref, onMounted, onUnmounted, watch, computed } from 'vue'
+import { Network, type Data, type Options } from 'vis-network'
+import type { RelationGraph, RelationWithDetails } from '@/types'
+import { RELATION_TYPE_CONFIG, DETECTION_SOURCE_CONFIG } from '@/types/relation'
+import { fetchObservationGraph } from '@/utils/api'
+
+const props = defineProps<{
+  observationId: number
+  observationTitle: string
+  show: boolean
+}>()
+
+const emit = defineEmits<{
+  close: []
+  navigateTo: [id: number]
+}>()
+
+const graphContainer = ref<HTMLElement | null>(null)
+const loading = ref(true)
+const error = ref<string | null>(null)
+const graphData = ref<RelationGraph | null>(null)
+const selectedRelation = ref<RelationWithDetails | null>(null)
+const depth = ref(2)
+
+let network: Network | null = null
+
+// Node colors based on observation type
+const getNodeColor = (type: string) => {
+  const colors: Record<string, { background: string; border: string; highlight: { background: string; border: string } }> = {
+    bugfix: { background: '#ef4444', border: '#dc2626', highlight: { background: '#f87171', border: '#ef4444' } },
+    feature: { background: '#a855f7', border: '#9333ea', highlight: { background: '#c084fc', border: '#a855f7' } },
+    refactor: { background: '#3b82f6', border: '#2563eb', highlight: { background: '#60a5fa', border: '#3b82f6' } },
+    discovery: { background: '#06b6d4', border: '#0891b2', highlight: { background: '#22d3ee', border: '#06b6d4' } },
+    decision: { background: '#eab308', border: '#ca8a04', highlight: { background: '#facc15', border: '#eab308' } },
+    change: { background: '#64748b', border: '#475569', highlight: { background: '#94a3b8', border: '#64748b' } },
+  }
+  return colors[type] || colors.change
+}
+
+// Edge colors based on relation type
+const getEdgeColor = (type: string) => {
+  const colors: Record<string, string> = {
+    causes: '#f97316',
+    fixes: '#22c55e',
+    supersedes: '#a855f7',
+    depends_on: '#3b82f6',
+    relates_to: '#64748b',
+    evolves_from: '#06b6d4',
+  }
+  return colors[type] || '#64748b'
+}
+
+const loadGraph = async () => {
+  loading.value = true
+  error.value = null
+
+  try {
+    graphData.value = await fetchObservationGraph(props.observationId, depth.value)
+    renderGraph()
+  } catch (err) {
+    error.value = err instanceof Error ? err.message : 'Failed to load graph'
+  } finally {
+    loading.value = false
+  }
+}
+
+const renderGraph = () => {
+  if (!graphContainer.value || !graphData.value) return
+
+  // Build nodes and edges from relations
+  const nodeMap = new Map<number, { id: number; label: string; type: string }>()
+  const edgesList: { from: number; to: number; label: string; color: string; arrows: string; relation: RelationWithDetails }[] = []
+
+  // Add center node
+  nodeMap.set(props.observationId, {
+    id: props.observationId,
+    label: truncateLabel(props.observationTitle),
+    type: 'center'
+  })
+
+  // Process relations
+  for (const rel of graphData.value.relations) {
+    // Add source node
+    if (!nodeMap.has(rel.relation.source_id)) {
+      nodeMap.set(rel.relation.source_id, {
+        id: rel.relation.source_id,
+        label: truncateLabel(rel.source_title),
+        type: rel.source_type
+      })
+    }
+
+    // Add target node
+    if (!nodeMap.has(rel.relation.target_id)) {
+      nodeMap.set(rel.relation.target_id, {
+        id: rel.relation.target_id,
+        label: truncateLabel(rel.target_title),
+        type: rel.target_type
+      })
+    }
+
+    // Add edge
+    edgesList.push({
+      from: rel.relation.source_id,
+      to: rel.relation.target_id,
+      label: rel.relation.relation_type.replace('_', ' '),
+      color: getEdgeColor(rel.relation.relation_type),
+      arrows: 'to',
+      relation: rel
+    })
+  }
+
+  // Create vis-network data using plain arrays (simpler type compatibility)
+  const nodes = Array.from(nodeMap.values()).map(node => ({
+    id: node.id,
+    label: node.label,
+    color: node.id === props.observationId
+      ? { background: '#f59e0b', border: '#d97706', highlight: { background: '#fbbf24', border: '#f59e0b' } }
+      : getNodeColor(node.type),
+    font: { color: '#fff', size: 12 },
+    shape: 'box' as const,
+    borderWidth: node.id === props.observationId ? 3 : 2,
+    margin: { top: 10, right: 10, bottom: 10, left: 10 },
+    shadow: true
+  }))
+
+  const edges = edgesList.map((edge, index) => ({
+    id: index,
+    from: edge.from,
+    to: edge.to,
+    label: edge.label,
+    color: { color: edge.color, highlight: edge.color },
+    font: { color: '#94a3b8', size: 10, strokeWidth: 0 },
+    arrows: edge.arrows,
+    width: 2,
+    smooth: { enabled: true, type: 'curvedCW' as const, roundness: 0.2 }
+  }))
+
+  // Cleanup existing network
+  if (network) {
+    network.destroy()
+  }
+
+  // Create network data
+  const data: Data = { nodes, edges }
+
+  const options: Options = {
+    physics: {
+      enabled: true,
+      solver: 'forceAtlas2Based',
+      forceAtlas2Based: {
+        gravitationalConstant: -50,
+        centralGravity: 0.01,
+        springLength: 150,
+        springConstant: 0.08
+      },
+      stabilization: { iterations: 100 }
+    },
+    interaction: {
+      hover: true,
+      tooltipDelay: 200,
+      zoomView: true,
+      dragView: true
+    },
+    layout: {
+      improvedLayout: true
+    }
+  }
+
+  // Create network
+  network = new Network(graphContainer.value, data, options)
+
+  // Handle edge click to show details
+  network.on('selectEdge', (params: { edges: (string | number)[] }) => {
+    if (params.edges.length > 0) {
+      const edgeId = params.edges[0] as number
+      selectedRelation.value = edgesList[edgeId]?.relation || null
+    }
+  })
+
+  // Handle node double-click to navigate
+  network.on('doubleClick', (params: { nodes: (string | number)[] }) => {
+    if (params.nodes.length > 0) {
+      const nodeId = params.nodes[0] as number
+      if (nodeId !== props.observationId) {
+        emit('navigateTo', nodeId)
+      }
+    }
+  })
+
+  // Clear selection when clicking background
+  network.on('click', (params: { nodes: (string | number)[]; edges: (string | number)[] }) => {
+    if (params.nodes.length === 0 && params.edges.length === 0) {
+      selectedRelation.value = null
+    }
+  })
+}
+
+const truncateLabel = (label: string, maxLen = 30) => {
+  if (label.length <= maxLen) return label
+  return label.substring(0, maxLen - 3) + '...'
+}
+
+const relationCount = computed(() => graphData.value?.relations.length || 0)
+
+const closeModal = () => {
+  emit('close')
+}
+
+// Watch for show prop changes
+watch(() => props.show, (newVal) => {
+  if (newVal) {
+    loadGraph()
+  } else {
+    selectedRelation.value = null
+    if (network) {
+      network.destroy()
+      network = null
+    }
+  }
+})
+
+// Watch for depth changes
+watch(depth, () => {
+  if (props.show) {
+    loadGraph()
+  }
+})
+
+onMounted(() => {
+  if (props.show) {
+    loadGraph()
+  }
+})
+
+onUnmounted(() => {
+  if (network) {
+    network.destroy()
+    network = null
+  }
+})
+</script>
+
+<template>
+  <!-- Modal backdrop -->
+  <Teleport to="body">
+    <div
+      v-if="show"
+      class="fixed inset-0 z-50 flex items-center justify-center p-4"
+      @click.self="closeModal"
+    >
+      <!-- Backdrop -->
+      <div class="absolute inset-0 bg-black/70 backdrop-blur-sm" @click="closeModal" />
+
+      <!-- Modal content -->
+      <div class="relative bg-slate-900 border border-slate-700 rounded-xl shadow-2xl w-full max-w-5xl max-h-[90vh] flex flex-col overflow-hidden">
+        <!-- Header -->
+        <div class="flex items-center justify-between p-4 border-b border-slate-700">
+          <div class="flex items-center gap-3">
+            <div class="p-2 rounded-lg bg-amber-500/20">
+              <i class="fas fa-diagram-project text-amber-400" />
+            </div>
+            <div>
+              <h2 class="text-lg font-semibold text-amber-100">Knowledge Graph</h2>
+              <p class="text-sm text-slate-400">
+                {{ relationCount }} relation{{ relationCount !== 1 ? 's' : '' }} for "{{ truncateLabel(observationTitle, 50) }}"
+              </p>
+            </div>
+          </div>
+
+          <!-- Controls -->
+          <div class="flex items-center gap-4">
+            <!-- Depth selector -->
+            <div class="flex items-center gap-2">
+              <label class="text-xs text-slate-400">Depth:</label>
+              <select
+                v-model="depth"
+                class="bg-slate-800 border border-slate-600 rounded px-2 py-1 text-sm text-slate-200 focus:outline-none focus:border-amber-500"
+              >
+                <option :value="1">1</option>
+                <option :value="2">2</option>
+                <option :value="3">3</option>
+              </select>
+            </div>
+
+            <!-- Close button -->
+            <button
+              @click="closeModal"
+              class="p-2 rounded-lg text-slate-400 hover:text-slate-200 hover:bg-slate-800 transition-colors"
+            >
+              <i class="fas fa-times" />
+            </button>
+          </div>
+        </div>
+
+        <!-- Graph container -->
+        <div class="relative" style="height: 60vh; min-height: 500px;">
+          <!-- Loading state -->
+          <div v-if="loading" class="absolute inset-0 flex items-center justify-center bg-slate-900/50">
+            <div class="flex items-center gap-3 text-amber-400">
+              <i class="fas fa-circle-notch fa-spin text-xl" />
+              <span>Loading graph...</span>
+            </div>
+          </div>
+
+          <!-- Error state -->
+          <div v-else-if="error" class="absolute inset-0 flex items-center justify-center">
+            <div class="text-center">
+              <i class="fas fa-exclamation-triangle text-3xl text-red-400 mb-2" />
+              <p class="text-red-300">{{ error }}</p>
+              <button
+                @click="loadGraph"
+                class="mt-3 px-4 py-2 bg-slate-800 hover:bg-slate-700 rounded-lg text-sm text-slate-200 transition-colors"
+              >
+                Retry
+              </button>
+            </div>
+          </div>
+
+          <!-- Empty state -->
+          <div v-else-if="graphData && graphData.relations.length === 0" class="absolute inset-0 flex items-center justify-center">
+            <div class="text-center">
+              <i class="fas fa-diagram-project text-4xl text-slate-600 mb-3" />
+              <p class="text-slate-400">No relations found for this observation</p>
+              <p class="text-sm text-slate-500 mt-1">Relations are detected automatically when observations share files, concepts, or patterns</p>
+            </div>
+          </div>
+
+          <!-- Graph -->
+          <div ref="graphContainer" class="absolute inset-0" />
+        </div>
+
+        <!-- Relation details panel -->
+        <div v-if="selectedRelation" class="border-t border-slate-700 p-4 bg-slate-800/50">
+          <div class="flex items-start gap-4">
+            <!-- Relation type icon -->
+            <div
+              class="p-3 rounded-lg"
+              :class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.bgClass"
+            >
+              <i
+                class="fas"
+                :class="[
+                  RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.icon,
+                  RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass
+                ]"
+              />
+            </div>
+
+            <!-- Details -->
+            <div class="flex-1 min-w-0">
+              <div class="flex items-center gap-2 mb-1">
+                <span class="font-medium" :class="RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.colorClass">
+                  {{ RELATION_TYPE_CONFIG[selectedRelation.relation.relation_type]?.label }}
+                </span>
+                <span class="text-xs text-slate-500">
+                  ({{ Math.round(selectedRelation.relation.confidence * 100) }}% confidence)
+                </span>
+              </div>
+
+              <div class="flex items-center gap-2 text-sm text-slate-300 mb-2">
+                <span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.source_title, 40) }}</span>
+                <i class="fas fa-arrow-right text-slate-500 text-xs" />
+                <span class="font-mono text-amber-400">{{ truncateLabel(selectedRelation.target_title, 40) }}</span>
+              </div>
+
+              <div class="flex items-center gap-4 text-xs text-slate-500">
+                <span class="flex items-center gap-1">
+                  <i :class="['fas', DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.icon]" />
+                  {{ DETECTION_SOURCE_CONFIG[selectedRelation.relation.detection_source]?.label }}
+                </span>
+                <span v-if="selectedRelation.relation.reason">
+                  {{ selectedRelation.relation.reason }}
+                </span>
+              </div>
+            </div>
+          </div>
+        </div>
+
+        <!-- Legend -->
+        <div class="border-t border-slate-700 p-3 bg-slate-800/30">
+          <div class="flex flex-wrap items-center justify-center gap-4 text-xs text-slate-400">
+            <span class="font-medium text-slate-300">Legend:</span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-amber-500" /> Center
+            </span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-red-500" /> Bugfix
+            </span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-purple-500" /> Feature
+            </span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-blue-500" /> Refactor
+            </span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-cyan-500" /> Discovery
+            </span>
+            <span class="flex items-center gap-1">
+              <span class="w-3 h-3 rounded bg-yellow-500" /> Decision
+            </span>
+            <span class="text-slate-500">|</span>
+            <span>Double-click node to navigate</span>
+          </div>
+        </div>
+      </div>
+    </div>
+  </Teleport>
+</template>
||||||
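The node and edge assembly inside `renderGraph` can be read as a pure transformation from a list of relations to a deduplicated node set plus one edge per relation. The following is a minimal sketch of that step in isolation, not the component's actual code: `RelationLite`, `GraphNode`, `GraphEdge`, and `buildGraphElements` are hypothetical names introduced here, and the types are trimmed-down stand-ins for those in `@/types/relation`. One deliberate difference: the sketch replaces every underscore in the relation type, whereas the component calls `replace('_', ' ')`, which only replaces the first one (equivalent for the six relation types currently defined).

```typescript
// Hypothetical, trimmed-down shape of RelationWithDetails (assumption for this sketch).
interface RelationLite {
  relation: { source_id: number; target_id: number; relation_type: string }
  source_title: string
  target_title: string
  source_type: string
  target_type: string
}

interface GraphNode { id: number; label: string; type: string }
interface GraphEdge { id: number; from: number; to: number; label: string }

// Same behaviour as truncateLabel in the component: cap at maxLen including the ellipsis.
function truncateLabel(label: string, maxLen = 30): string {
  if (label.length <= maxLen) return label
  return label.substring(0, maxLen - 3) + '...'
}

// Hypothetical helper: one node per distinct observation id, one edge per relation.
function buildGraphElements(
  centerId: number,
  centerTitle: string,
  relations: RelationLite[]
): { nodes: GraphNode[]; edges: GraphEdge[] } {
  const nodeMap = new Map<number, GraphNode>()

  // The center node is registered first so it is never overwritten by a relation endpoint.
  nodeMap.set(centerId, { id: centerId, label: truncateLabel(centerTitle), type: 'center' })

  const edges: GraphEdge[] = []
  relations.forEach((rel, index) => {
    const { source_id, target_id, relation_type } = rel.relation
    if (!nodeMap.has(source_id)) {
      nodeMap.set(source_id, { id: source_id, label: truncateLabel(rel.source_title), type: rel.source_type })
    }
    if (!nodeMap.has(target_id)) {
      nodeMap.set(target_id, { id: target_id, label: truncateLabel(rel.target_title), type: rel.target_type })
    }
    // The edge id is the array index, which is how the component maps a clicked edge back to its relation.
    edges.push({ id: index, from: source_id, to: target_id, label: relation_type.replace(/_/g, ' ') })
  })

  return { nodes: Array.from(nodeMap.values()), edges }
}
```

Keeping this step free of `vis-network` calls would make it straightforward to unit-test without a DOM container.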
@@ -31,6 +31,7 @@ defineProps<{
   v-if="item.itemType === 'observation'"
   :observation="item"
   :highlight="index === 0"
+  :show-feedback="true"
 />
 <PromptCard
   v-else-if="item.itemType === 'prompt'"
@@ -2,3 +2,4 @@ export * from './observation'
 export * from './prompt'
 export * from './summary'
 export * from './api'
+export * from './relation'
@@ -30,6 +30,12 @@ export interface Observation {
   created_at: string
   created_at_epoch: number
   is_stale?: boolean
+  // Importance scoring fields
+  importance_score: number
+  user_feedback: number // -1 (thumbs down), 0 (neutral), 1 (thumbs up)
+  retrieval_count: number
+  last_retrieved_at_epoch?: number
+  score_updated_at_epoch?: number
 }
 
 export const OBSERVATION_TYPES: ObservationType[] = ['bugfix', 'feature', 'refactor', 'discovery', 'decision', 'change']
@@ -0,0 +1,90 @@
+export type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'
+export type DetectionSource = 'file_overlap' | 'embedding_similarity' | 'temporal_proximity' | 'narrative_mention' | 'concept_overlap' | 'type_progression'
+
+export interface ObservationRelation {
+  id: number
+  source_id: number
+  target_id: number
+  relation_type: RelationType
+  confidence: number
+  detection_source: DetectionSource
+  reason: string
+  created_at: string
+  created_at_epoch: number
+}
+
+export interface RelationWithDetails {
+  relation: ObservationRelation
+  source_title: string
+  target_title: string
+  source_type: string
+  target_type: string
+}
+
+export interface RelationGraph {
+  center_id: number
+  relations: RelationWithDetails[]
+}
+
+export interface RelationStats {
+  total_count: number
+  high_confidence: number
+  by_type: Record<RelationType, number>
+  min_confidence_used: number
+}
+
+// Configuration for relation type display
+export const RELATION_TYPE_CONFIG: Record<RelationType, { icon: string; label: string; colorClass: string; bgClass: string; description: string }> = {
+  causes: {
+    icon: 'fa-arrow-right',
+    label: 'Causes',
+    colorClass: 'text-orange-300',
+    bgClass: 'bg-orange-500/20',
+    description: 'This observation caused the related issue'
+  },
+  fixes: {
+    icon: 'fa-wrench',
+    label: 'Fixes',
+    colorClass: 'text-green-300',
+    bgClass: 'bg-green-500/20',
+    description: 'This observation fixes the related issue'
+  },
+  supersedes: {
+    icon: 'fa-layer-group',
+    label: 'Supersedes',
+    colorClass: 'text-purple-300',
+    bgClass: 'bg-purple-500/20',
+    description: 'This observation replaces the older one'
+  },
+  depends_on: {
+    icon: 'fa-link',
+    label: 'Depends On',
+    colorClass: 'text-blue-300',
+    bgClass: 'bg-blue-500/20',
+    description: 'This observation depends on the related one'
+  },
+  relates_to: {
+    icon: 'fa-arrows-left-right',
+    label: 'Related',
+    colorClass: 'text-slate-300',
+    bgClass: 'bg-slate-500/20',
+    description: 'These observations are related'
+  },
+  evolves_from: {
+    icon: 'fa-code-branch',
+    label: 'Evolves From',
+    colorClass: 'text-cyan-300',
+    bgClass: 'bg-cyan-500/20',
+    description: 'This observation evolved from the related one'
+  }
+}
+
+// Configuration for detection source display
+export const DETECTION_SOURCE_CONFIG: Record<DetectionSource, { icon: string; label: string }> = {
+  file_overlap: { icon: 'fa-file-code', label: 'Shared files' },
+  embedding_similarity: { icon: 'fa-brain', label: 'Semantic similarity' },
+  temporal_proximity: { icon: 'fa-clock', label: 'Close in time' },
+  narrative_mention: { icon: 'fa-quote-left', label: 'Mentioned in text' },
+  concept_overlap: { icon: 'fa-tags', label: 'Shared concepts' },
+  type_progression: { icon: 'fa-diagram-next', label: 'Natural progression' }
+}
||||||
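As a rough illustration of how the two lookup tables in `relation.ts` are meant to be consumed, the sketch below turns a relation into a one-line summary, using the same `Math.round(confidence * 100)` rounding that the details panel in the graph modal applies. `describeRelation`, `RELATION_LABELS`, and `SOURCE_LABELS` are hypothetical names for this example only; the label strings are copied from `RELATION_TYPE_CONFIG` and `DETECTION_SOURCE_CONFIG` so the block stays self-contained.

```typescript
type RelationType = 'causes' | 'fixes' | 'supersedes' | 'depends_on' | 'relates_to' | 'evolves_from'
type DetectionSource = 'file_overlap' | 'embedding_similarity' | 'temporal_proximity' | 'narrative_mention' | 'concept_overlap' | 'type_progression'

// Labels copied from RELATION_TYPE_CONFIG (only the `label` field is needed here).
const RELATION_LABELS: Record<RelationType, string> = {
  causes: 'Causes',
  fixes: 'Fixes',
  supersedes: 'Supersedes',
  depends_on: 'Depends On',
  relates_to: 'Related',
  evolves_from: 'Evolves From',
}

// Labels copied from DETECTION_SOURCE_CONFIG.
const SOURCE_LABELS: Record<DetectionSource, string> = {
  file_overlap: 'Shared files',
  embedding_similarity: 'Semantic similarity',
  temporal_proximity: 'Close in time',
  narrative_mention: 'Mentioned in text',
  concept_overlap: 'Shared concepts',
  type_progression: 'Natural progression',
}

// Hypothetical helper: a plain-text equivalent of the details panel header.
function describeRelation(r: {
  relation_type: RelationType
  confidence: number // assumed to be in the range 0..1, as the percentage display implies
  detection_source: DetectionSource
}): string {
  const percent = Math.round(r.confidence * 100)
  return `${RELATION_LABELS[r.relation_type]} (${percent}% confidence, ${SOURCE_LABELS[r.detection_source]})`
}
```

Typing the tables as `Record<RelationType, ...>` means the compiler flags any relation type added to the union without a matching display entry.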
+18
-1
@@ -1,4 +1,4 @@
-import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem } from '@/types'
+import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem, RelationWithDetails, RelationGraph, RelationStats } from '@/types'
 
 const API_BASE = '/api'
 const DEFAULT_TIMEOUT = 10000 // 10 seconds
@@ -147,3 +147,20 @@ export function combineTimeline(
   return [...obsItems, ...promptItems, ...summaryItems]
     .sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime())
 }
+
+// Relation API functions
+export async function fetchObservationRelations(observationId: number, signal?: AbortSignal): Promise<RelationWithDetails[]> {
+  return fetchWithRetry<RelationWithDetails[]>(`${API_BASE}/observations/${observationId}/relations`, { signal })
+}
+
+export async function fetchObservationGraph(observationId: number, depth: number = 2, signal?: AbortSignal): Promise<RelationGraph> {
+  return fetchWithRetry<RelationGraph>(`${API_BASE}/observations/${observationId}/graph?depth=${depth}`, { signal })
+}
+
+export async function fetchRelatedObservations(observationId: number, minConfidence: number = 0.4, signal?: AbortSignal): Promise<Observation[]> {
+  return fetchWithRetry<Observation[]>(`${API_BASE}/observations/${observationId}/related?min_confidence=${minConfidence}`, { signal })
+}
+
+export async function fetchRelationStats(signal?: AbortSignal): Promise<RelationStats> {
+  return fetchWithRetry<RelationStats>(`${API_BASE}/relations/stats`, { signal })
+}
||||||
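The four relation helpers added to `api.ts` differ only in the path they request and in their default query parameters. The sketch below extracts just the path construction so the endpoints and defaults are visible at a glance. The function names (`relationsPath`, `graphPath`, `relatedPath`, `statsPath`) are hypothetical and exist only for this example; the paths, `depth = 2`, and `minConfidence = 0.4` are taken from the diff, and nothing here is a statement about how the server handles these parameters.

```typescript
const API_BASE = '/api' // same constant as in api.ts

// Path requested by fetchObservationRelations.
function relationsPath(observationId: number): string {
  return `${API_BASE}/observations/${observationId}/relations`
}

// Path requested by fetchObservationGraph; depth defaults to 2 as in the diff.
function graphPath(observationId: number, depth: number = 2): string {
  return `${API_BASE}/observations/${observationId}/graph?depth=${depth}`
}

// Path requested by fetchRelatedObservations; min_confidence defaults to 0.4 as in the diff.
function relatedPath(observationId: number, minConfidence: number = 0.4): string {
  return `${API_BASE}/observations/${observationId}/related?min_confidence=${minConfidence}`
}

// Path requested by fetchRelationStats.
function statsPath(): string {
  return `${API_BASE}/relations/stats`
}
```

In the graph modal, `graphPath(id, depth)` corresponds to the request issued each time the Depth selector changes between 1, 2, and 3.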
@@ -1 +1 @@
-{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
+{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}