mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
feat(leann-phase2): implement hybrid vector storage and graph-based search
- [x] Add AST-aware code chunking for Go, Python, and TypeScript using tree-sitter - [x] Implement LEANN-inspired hybrid vector storage with hub detection and selective embedding storage (60-80% savings) - [x] Add observation relationship graph with CSR format and edge detection (file overlap, semantic similarity, temporal, concept) - [x] Implement graph-aware search with two-level traversal and relationship-based ranking - [x] Add auto-tuning system for dynamic hub threshold adjustment based on query performance - [x] Add comprehensive metrics tracking for vector storage, queries, latency, and graph traversals - [x] Update configuration system with graph and hybrid storage settings - [x] Add graph stats and vector metrics endpoints to worker service - [x] Enhance UI sidebar with advanced metrics display and graph visualization - [x] Optimize struct field alignment throughout codebase for memory efficiency - [x] Update documentation with LEANN Phase 2 features and performance benefits - [x] Add tree-sitter dependency for AST parsing
This commit is contained in:
+27
-16
@@ -1,23 +1,34 @@
|
||||
# Project-specific golangci-lint configuration for claude-mnemonic
|
||||
# Inherits from global ~/.golangci.yml and adds project-specific exclusions
|
||||
linters-settings:
|
||||
govet:
|
||||
enable:
|
||||
- fieldalignment
|
||||
errcheck:
|
||||
# Ignore error checks in test files for common test helpers
|
||||
exclude-functions:
|
||||
- (io.Closer).Close
|
||||
- (*encoding/json.Encoder).Encode
|
||||
- (io.Writer).Write
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- errcheck
|
||||
- gosec
|
||||
- govet
|
||||
- gofmt
|
||||
- staticcheck
|
||||
- unused
|
||||
- ineffassign
|
||||
- typecheck
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
# Project-specific: Exclude unused warnings for public API functions in pkg/models
|
||||
# These detection functions are part of the public API
|
||||
- path: pkg/models/(conflict|relation)\.go
|
||||
linters:
|
||||
- unused
|
||||
text: "(Detect|New)"
|
||||
|
||||
# Project-specific: Test helper method used only in tests
|
||||
- path: internal/db/gorm/store\.go
|
||||
linters:
|
||||
- unused
|
||||
text: "GetDB"
|
||||
|
||||
exclude-dirs:
|
||||
- vendor
|
||||
# Exclude some linters from running on test files
|
||||
exclude-rules:
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- errcheck
|
||||
- gosec
|
||||
|
||||
run:
|
||||
timeout: 5m
|
||||
|
||||
@@ -18,12 +18,12 @@ type Input struct {
|
||||
|
||||
// Observation represents an observation from the API.
|
||||
type Observation struct {
|
||||
ID int64 `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Subtitle string `json:"subtitle"`
|
||||
Narrative string `json:"narrative"`
|
||||
Facts []string `json:"facts"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -43,21 +43,21 @@ type StatusInput struct {
|
||||
|
||||
// WorkerStats is the response from the worker's /api/stats endpoint.
|
||||
type WorkerStats struct {
|
||||
Uptime string `json:"uptime"`
|
||||
ActiveSessions int `json:"activeSessions"`
|
||||
QueueDepth int `json:"queueDepth"`
|
||||
IsProcessing bool `json:"isProcessing"`
|
||||
ConnectedClients int `json:"connectedClients"`
|
||||
SessionsToday int `json:"sessionsToday"`
|
||||
Ready bool `json:"ready"`
|
||||
Project string `json:"project,omitempty"`
|
||||
ProjectObservations int `json:"projectObservations,omitempty"`
|
||||
Retrieval struct {
|
||||
Uptime string `json:"uptime"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Retrieval struct {
|
||||
TotalRequests int64 `json:"TotalRequests"`
|
||||
ObservationsServed int64 `json:"ObservationsServed"`
|
||||
SearchRequests int64 `json:"SearchRequests"`
|
||||
ContextInjections int64 `json:"ContextInjections"`
|
||||
} `json:"retrieval"`
|
||||
ActiveSessions int `json:"activeSessions"`
|
||||
QueueDepth int `json:"queueDepth"`
|
||||
ConnectedClients int `json:"connectedClients"`
|
||||
SessionsToday int `json:"sessionsToday"`
|
||||
ProjectObservations int `json:"projectObservations,omitempty"`
|
||||
IsProcessing bool `json:"isProcessing"`
|
||||
Ready bool `json:"ready"`
|
||||
}
|
||||
|
||||
// ANSI color codes
|
||||
|
||||
@@ -14,17 +14,17 @@ import (
|
||||
// Input is the hook input from Claude Code.
|
||||
type Input struct {
|
||||
hooks.BaseInput
|
||||
StopHookActive bool `json:"stop_hook_active"`
|
||||
TranscriptPath string `json:"transcript_path"`
|
||||
StopHookActive bool `json:"stop_hook_active"`
|
||||
}
|
||||
|
||||
// TranscriptMessage represents a message in the transcript JSONL file.
|
||||
type TranscriptMessage struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"` // Can be string or array
|
||||
} `json:"message"`
|
||||
Type string `json:"type"` // Can be string or array
|
||||
}
|
||||
|
||||
// extractTextContent extracts text content from message content (handles both string and array formats).
|
||||
|
||||
+23
-6
@@ -40,7 +40,7 @@
|
||||
class="w-full h-auto"
|
||||
/>
|
||||
</div>
|
||||
<p class="text-center text-slate-500 text-sm mt-4">The dashboard at localhost:37777 - browse, search, and manage your memories</p>
|
||||
<p class="text-center text-slate-500 text-sm mt-4">The dashboard at localhost:37777 - browse, search, and manage your memories. View graph stats, vector metrics, storage savings, and performance analytics.</p>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
@@ -304,7 +304,7 @@
|
||||
<section class="py-20 lg:py-28 px-4 sm:px-6">
|
||||
<div class="max-w-6xl mx-auto">
|
||||
<SectionHeader title="Under the hood" subtitle="Built with simplicity and performance in mind" />
|
||||
<div class="grid sm:grid-cols-2 lg:grid-cols-4 gap-4 sm:gap-6 text-center">
|
||||
<div class="grid sm:grid-cols-2 lg:grid-cols-3 gap-4 sm:gap-6 text-center">
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">Go</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Single binary. Fast startup, low memory. Zero runtime dependencies.</p>
|
||||
@@ -315,12 +315,20 @@
|
||||
</div>
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">sqlite-vec</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Embedded vector database. No external services required.</p>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Hybrid vector storage with LEANN-inspired selective embeddings. 60-80% storage reduction.</p>
|
||||
</div>
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">BGE</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Two-stage retrieval: bi-encoder embeddings + cross-encoder reranking for high accuracy.</p>
|
||||
</div>
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">Tree-sitter</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">AST-aware code chunking respects function boundaries for Go, Python, and TypeScript.</p>
|
||||
</div>
|
||||
<div class="glass rounded-2xl p-6 sm:p-8 hover:border-amber-500/30 transition-colors">
|
||||
<div class="text-3xl sm:text-4xl font-bold text-amber-500 mb-2">CSR Graph</div>
|
||||
<p class="text-slate-400 text-xs sm:text-sm">Memory-efficient observation relationship graph with edge detection and hub identification.</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
@@ -417,9 +425,12 @@ const activeTab = ref('macos')
|
||||
const features = [
|
||||
{ icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
|
||||
{ icon: 'fas fa-search', title: 'Two-stage retrieval', description: 'Cross-encoder reranking delivers highly relevant results. Finds what you need even with vague queries like "that auth thing".' },
|
||||
{ icon: 'fas fa-project-diagram', title: 'Knowledge graph', description: 'Automatically discovers relationships between memories. See how concepts connect in the visual graph dashboard.' },
|
||||
{ icon: 'fas fa-project-diagram', title: 'Graph-based search', description: 'LEANN Phase 2: Graph relationships between observations (file overlap, semantic similarity, temporal proximity) for smarter context retrieval.' },
|
||||
{ icon: 'fas fa-microchip', title: 'AST-aware chunking', description: 'Intelligent code splitting respects function boundaries. Go, Python, and TypeScript code is chunked at semantic boundaries, not arbitrary line counts.' },
|
||||
{ icon: 'fas fa-database', title: 'Hybrid vector storage', description: 'LEANN-inspired selective storage: frequently-accessed "hub" observations store embeddings, others recompute on-demand. 60-80% storage savings with <50ms latency.' },
|
||||
{ icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
|
||||
{ icon: 'fas fa-chart-line', title: 'Smart scoring', description: 'Importance decay, pattern detection, and conflict resolution ensure the most valuable memories surface first.' },
|
||||
{ icon: 'fas fa-gauge-high', title: 'Auto-tuning', description: 'Dynamic hub threshold adjustment based on query performance. Automatically balances storage efficiency with search latency for your workload.' },
|
||||
{ icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
|
||||
]
|
||||
|
||||
@@ -447,6 +458,10 @@ const configOptions = [
|
||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
|
||||
{ name: 'CLAUDE_MNEMONIC_RERANKING_ENABLED', description: 'Enable cross-encoder reranking for improved search relevance (default: true)', icon: 'fas fa-sort-amount-down' },
|
||||
{ name: 'CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD', description: 'Minimum similarity score for inclusion, 0.0-1.0 (default: 0.3)', icon: 'fas fa-filter' },
|
||||
{ name: 'CLAUDE_MNEMONIC_VECTOR_STORAGE_STRATEGY', description: 'Storage strategy: "hub" (default), "always", or "on_demand"', icon: 'fas fa-database' },
|
||||
{ name: 'CLAUDE_MNEMONIC_GRAPH_ENABLED', description: 'Enable graph-based search with observation relationships (default: true)', icon: 'fas fa-project-diagram' },
|
||||
{ name: 'CLAUDE_MNEMONIC_GRAPH_MAX_HOPS', description: 'Maximum graph traversal depth for search expansion (default: 2)', icon: 'fas fa-route' },
|
||||
{ name: 'CLAUDE_MNEMONIC_GRAPH_REBUILD_INTERVAL_MIN', description: 'How often to rebuild the observation graph in minutes (default: 60)', icon: 'fas fa-clock' },
|
||||
]
|
||||
|
||||
const requiredDeps = [
|
||||
@@ -457,9 +472,11 @@ const requiredDeps = [
|
||||
const faqs = [
|
||||
{ question: 'Will it confuse Claude with wrong context?', answer: 'No. Mnemonic uses project isolation and semantic relevance scoring. Only memories from the current project (or global best practices) are injected, and only when they\'re actually relevant to your prompt.' },
|
||||
{ question: 'What exactly gets saved?', answer: 'Bug fixes with context ("Fixed race condition by adding mutex"), architecture decisions ("Using repository pattern for data access"), conventions ("All API routes prefixed with /api/v1"), and learnings you want to preserve.' },
|
||||
{ question: 'Can I delete or edit memories?', answer: 'Yes. The web dashboard at localhost:37777 lets you browse, search, edit, and delete any memory. You\'re always in control.' },
|
||||
{ question: 'How does hybrid vector storage work?', answer: 'LEANN-inspired selective storage: frequently-accessed "hub" observations (identified by access patterns and graph centrality) store embeddings. Infrequently-accessed observations recompute embeddings on-demand during search. This reduces storage by 60-80% with minimal latency impact (<50ms).' },
|
||||
{ question: 'Can I delete or edit memories?', answer: 'Yes. The web dashboard at localhost:37777 lets you browse, search, edit, and delete any memory. You can also view graph relationships, storage metrics, and performance analytics. You\'re always in control.' },
|
||||
{ question: 'Does it work with my existing Claude Code setup?', answer: 'Yes. Mnemonic installs as a Claude Code plugin with hooks. Your existing workflows, settings, and shortcuts remain unchanged.' },
|
||||
{ question: 'What if I switch between projects frequently?', answer: 'That\'s the point. Each project has isolated memories. Switch from your Python ML project to your TypeScript app - context switches automatically.' },
|
||||
{ question: 'Is there a performance impact?', answer: 'Minimal. The Go worker is lightweight (typically under 30MB RAM). Context injection at session start takes milliseconds for most projects.' },
|
||||
{ question: 'Is there a performance impact?', answer: 'Minimal. The Go worker is lightweight (typically under 30MB RAM). Hybrid storage and auto-tuning optimize for your workload. Context injection at session start takes milliseconds for most projects.' },
|
||||
{ question: 'What is AST-aware chunking?', answer: 'When processing code observations, Mnemonic uses Tree-sitter parsers to respect function and class boundaries instead of arbitrary line limits. Go, Python, and TypeScript code is chunked at semantic boundaries for better search accuracy.' },
|
||||
]
|
||||
</script>
|
||||
|
||||
@@ -12,6 +12,7 @@ require (
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/sugarme/tokenizer v0.3.0
|
||||
github.com/yalue/onnxruntime_go v1.25.0
|
||||
|
||||
@@ -44,6 +44,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
|
||||
github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
|
||||
@@ -0,0 +1,285 @@
|
||||
// Package golang provides AST-aware chunking for Go source files.
|
||||
package golang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
|
||||
)
|
||||
|
||||
// Chunker implements AST-aware chunking for Go files.
|
||||
type Chunker struct {
|
||||
options chunking.ChunkOptions
|
||||
}
|
||||
|
||||
// NewChunker creates a new Go chunker.
|
||||
func NewChunker(options chunking.ChunkOptions) *Chunker {
|
||||
return &Chunker{options: options}
|
||||
}
|
||||
|
||||
// Language returns the language this chunker supports.
|
||||
func (c *Chunker) Language() chunking.Language {
|
||||
return chunking.LanguageGo
|
||||
}
|
||||
|
||||
// SupportedExtensions returns the file extensions this chunker handles.
|
||||
func (c *Chunker) SupportedExtensions() []string {
|
||||
return []string{".go"}
|
||||
}
|
||||
|
||||
// Chunk parses a Go source file and returns semantic code chunks.
|
||||
func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
|
||||
// Read file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
|
||||
// Parse the Go file
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, filePath, content, parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse Go file: %w", err)
|
||||
}
|
||||
|
||||
chunks := make([]chunking.Chunk, 0)
|
||||
sourceLines := strings.Split(string(content), "\n")
|
||||
|
||||
// Extract chunks from declarations
|
||||
for _, decl := range file.Decls {
|
||||
switch d := decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
chunk := c.extractFunction(fset, d, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
chunks = append(chunks, *chunk)
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
extracted := c.extractGenDecl(fset, d, sourceLines, filePath)
|
||||
chunks = append(chunks, extracted...)
|
||||
}
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// extractFunction extracts a function or method declaration as a chunk.
|
||||
func (c *Chunker) extractFunction(fset *token.FileSet, fn *ast.FuncDecl, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
// Skip unexported if configured
|
||||
if !c.options.IncludePrivate && !fn.Name.IsExported() {
|
||||
return nil
|
||||
}
|
||||
|
||||
startPos := fset.Position(fn.Pos())
|
||||
endPos := fset.Position(fn.End())
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageGo,
|
||||
Name: fn.Name.Name,
|
||||
StartLine: startPos.Line,
|
||||
EndLine: endPos.Line,
|
||||
}
|
||||
|
||||
// Determine if this is a method or a function
|
||||
if fn.Recv != nil && len(fn.Recv.List) > 0 {
|
||||
chunk.Type = chunking.ChunkTypeMethod
|
||||
chunk.ParentName = c.extractReceiverType(fn.Recv)
|
||||
} else {
|
||||
chunk.Type = chunking.ChunkTypeFunction
|
||||
}
|
||||
|
||||
// Extract content
|
||||
chunk.Content = c.extractLines(sourceLines, startPos.Line, endPos.Line)
|
||||
|
||||
// Extract signature (function declaration without body)
|
||||
chunk.Signature = c.extractFunctionSignature(fn, fset, sourceLines)
|
||||
|
||||
// Extract doc comment
|
||||
if c.options.IncludeDocComments && fn.Doc != nil {
|
||||
chunk.DocComment = strings.TrimSpace(fn.Doc.Text())
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractGenDecl extracts general declarations (type, const, var).
|
||||
func (c *Chunker) extractGenDecl(fset *token.FileSet, gd *ast.GenDecl, sourceLines []string, filePath string) []chunking.Chunk {
|
||||
var chunks []chunking.Chunk
|
||||
|
||||
for _, spec := range gd.Specs {
|
||||
switch s := spec.(type) {
|
||||
case *ast.TypeSpec:
|
||||
chunk := c.extractTypeSpec(fset, gd, s, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
chunks = append(chunks, *chunk)
|
||||
}
|
||||
case *ast.ValueSpec:
|
||||
// Handle const and var declarations
|
||||
chunk := c.extractValueSpec(fset, gd, s, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
chunks = append(chunks, *chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// extractTypeSpec extracts a type declaration (struct, interface, type alias).
|
||||
func (c *Chunker) extractTypeSpec(fset *token.FileSet, gd *ast.GenDecl, ts *ast.TypeSpec, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
// Skip unexported if configured
|
||||
if !c.options.IncludePrivate && !ts.Name.IsExported() {
|
||||
return nil
|
||||
}
|
||||
|
||||
startPos := fset.Position(gd.Pos())
|
||||
endPos := fset.Position(gd.End())
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageGo,
|
||||
Name: ts.Name.Name,
|
||||
StartLine: startPos.Line,
|
||||
EndLine: endPos.Line,
|
||||
Content: c.extractLines(sourceLines, startPos.Line, endPos.Line),
|
||||
}
|
||||
|
||||
// Determine chunk type based on type expression
|
||||
switch ts.Type.(type) {
|
||||
case *ast.StructType:
|
||||
chunk.Type = chunking.ChunkTypeClass // Treat struct as class
|
||||
case *ast.InterfaceType:
|
||||
chunk.Type = chunking.ChunkTypeInterface
|
||||
default:
|
||||
chunk.Type = chunking.ChunkTypeType
|
||||
}
|
||||
|
||||
// Extract doc comment
|
||||
if c.options.IncludeDocComments && gd.Doc != nil {
|
||||
chunk.DocComment = strings.TrimSpace(gd.Doc.Text())
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractValueSpec extracts const or var declarations.
|
||||
func (c *Chunker) extractValueSpec(fset *token.FileSet, gd *ast.GenDecl, vs *ast.ValueSpec, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
// Skip if all names are unexported and we're excluding private
|
||||
if !c.options.IncludePrivate {
|
||||
allUnexported := true
|
||||
for _, name := range vs.Names {
|
||||
if name.IsExported() {
|
||||
allUnexported = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allUnexported {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
startPos := fset.Position(gd.Pos())
|
||||
endPos := fset.Position(gd.End())
|
||||
|
||||
// Use first name as the chunk name, join multiple if present
|
||||
names := make([]string, len(vs.Names))
|
||||
for i, name := range vs.Names {
|
||||
names[i] = name.Name
|
||||
}
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageGo,
|
||||
Name: strings.Join(names, ", "),
|
||||
StartLine: startPos.Line,
|
||||
EndLine: endPos.Line,
|
||||
Content: c.extractLines(sourceLines, startPos.Line, endPos.Line),
|
||||
}
|
||||
|
||||
// Set type based on token
|
||||
if gd.Tok == token.CONST {
|
||||
chunk.Type = chunking.ChunkTypeConst
|
||||
} else {
|
||||
chunk.Type = chunking.ChunkTypeVar
|
||||
}
|
||||
|
||||
// Extract doc comment
|
||||
if c.options.IncludeDocComments && gd.Doc != nil {
|
||||
chunk.DocComment = strings.TrimSpace(gd.Doc.Text())
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractReceiverType extracts the receiver type name from a method.
|
||||
func (c *Chunker) extractReceiverType(recv *ast.FieldList) string {
|
||||
if len(recv.List) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
field := recv.List[0]
|
||||
switch t := field.Type.(type) {
|
||||
case *ast.Ident:
|
||||
return t.Name
|
||||
case *ast.StarExpr:
|
||||
if ident, ok := t.X.(*ast.Ident); ok {
|
||||
return ident.Name
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractFunctionSignature extracts the function signature without the body.
|
||||
func (c *Chunker) extractFunctionSignature(fn *ast.FuncDecl, fset *token.FileSet, sourceLines []string) string {
|
||||
if fn.Body == nil {
|
||||
// No body, return entire declaration
|
||||
startPos := fset.Position(fn.Pos())
|
||||
endPos := fset.Position(fn.End())
|
||||
return c.extractLines(sourceLines, startPos.Line, endPos.Line)
|
||||
}
|
||||
|
||||
// Extract from start of function to just before body
|
||||
startPos := fset.Position(fn.Pos())
|
||||
bodyPos := fset.Position(fn.Body.Pos())
|
||||
|
||||
// If body is on the same line, extract just that line up to the opening brace
|
||||
if startPos.Line == bodyPos.Line {
|
||||
line := sourceLines[startPos.Line-1]
|
||||
// Find the opening brace position
|
||||
if idx := strings.Index(line[startPos.Column-1:], "{"); idx >= 0 {
|
||||
return strings.TrimSpace(line[startPos.Column-1 : startPos.Column-1+idx])
|
||||
}
|
||||
return strings.TrimSpace(line[startPos.Column-1:])
|
||||
}
|
||||
|
||||
// Get lines from start to the line containing the opening brace
|
||||
sig := c.extractLines(sourceLines, startPos.Line, bodyPos.Line)
|
||||
// Remove the opening brace and anything after it
|
||||
if idx := strings.Index(sig, "{"); idx >= 0 {
|
||||
sig = sig[:idx]
|
||||
}
|
||||
return strings.TrimSpace(sig)
|
||||
}
|
||||
|
||||
// extractLines extracts a range of lines from source (1-indexed, inclusive).
|
||||
func (c *Chunker) extractLines(lines []string, start, end int) string {
|
||||
if start < 1 || end < start || start > len(lines) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Adjust for 0-indexed array (start and end are 1-indexed)
|
||||
startIdx := start - 1
|
||||
endIdx := end
|
||||
if endIdx > len(lines) {
|
||||
endIdx = len(lines)
|
||||
}
|
||||
|
||||
return strings.Join(lines[startIdx:endIdx], "\n")
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
|
||||
)
|
||||
|
||||
func TestGoChunker_BasicFunctions(t *testing.T) {
|
||||
// Create temp test file
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
testCode := `package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Greet prints a greeting message
|
||||
func Greet(name string) {
|
||||
fmt.Printf("Hello, %s!\n", name)
|
||||
}
|
||||
|
||||
// Add adds two numbers
|
||||
func Add(a, b int) int {
|
||||
return a + b
|
||||
}
|
||||
|
||||
// unexported function should be included by default
|
||||
func helper() string {
|
||||
return "helper"
|
||||
}
|
||||
`
|
||||
|
||||
if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
// Create chunker with default options
|
||||
chunker := NewChunker(chunking.DefaultChunkOptions())
|
||||
|
||||
// Chunk the file
|
||||
chunks, err := chunker.Chunk(context.Background(), testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Chunk() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify we got all functions
|
||||
if len(chunks) != 3 {
|
||||
t.Errorf("Expected 3 chunks (Greet, Add, helper), got %d", len(chunks))
|
||||
}
|
||||
|
||||
// Verify chunk details
|
||||
expectedNames := map[string]bool{
|
||||
"Greet": false,
|
||||
"Add": false,
|
||||
"helper": false,
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Type != chunking.ChunkTypeFunction {
|
||||
t.Errorf("Expected chunk type 'function', got '%s'", chunk.Type)
|
||||
}
|
||||
|
||||
if chunk.Language != chunking.LanguageGo {
|
||||
t.Errorf("Expected language 'go', got '%s'", chunk.Language)
|
||||
}
|
||||
|
||||
if _, ok := expectedNames[chunk.Name]; !ok {
|
||||
t.Errorf("Unexpected function name: %s", chunk.Name)
|
||||
} else {
|
||||
expectedNames[chunk.Name] = true
|
||||
}
|
||||
|
||||
// Verify content is non-empty
|
||||
if chunk.Content == "" {
|
||||
t.Errorf("Chunk %s has empty content", chunk.Name)
|
||||
}
|
||||
|
||||
// Verify signature is present for functions
|
||||
if chunk.Signature == "" {
|
||||
t.Errorf("Chunk %s has empty signature", chunk.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all expected functions were found
|
||||
for name, found := range expectedNames {
|
||||
if !found {
|
||||
t.Errorf("Expected function %s not found", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoChunker_StructsAndMethods(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
testCode := `package main
|
||||
|
||||
// User represents a user
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
// GetName returns the user's name
|
||||
func (u *User) GetName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
// SetName sets the user's name
|
||||
func (u *User) SetName(name string) {
|
||||
u.Name = name
|
||||
}
|
||||
`
|
||||
|
||||
if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
chunker := NewChunker(chunking.DefaultChunkOptions())
|
||||
chunks, err := chunker.Chunk(context.Background(), testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Chunk() failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have 1 struct + 2 methods = 3 chunks
|
||||
if len(chunks) != 3 {
|
||||
t.Errorf("Expected 3 chunks (User struct, GetName, SetName), got %d", len(chunks))
|
||||
}
|
||||
|
||||
// Find the struct and methods
|
||||
var structChunk, getNameChunk, setNameChunk *chunking.Chunk
|
||||
for i := range chunks {
|
||||
switch chunks[i].Name {
|
||||
case "User":
|
||||
structChunk = &chunks[i]
|
||||
case "GetName":
|
||||
getNameChunk = &chunks[i]
|
||||
case "SetName":
|
||||
setNameChunk = &chunks[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Verify struct
|
||||
if structChunk == nil {
|
||||
t.Fatal("User struct not found")
|
||||
}
|
||||
if structChunk.Type != chunking.ChunkTypeClass {
|
||||
t.Errorf("Expected User to be ChunkTypeClass, got %s", structChunk.Type)
|
||||
}
|
||||
|
||||
// Verify methods
|
||||
if getNameChunk == nil {
|
||||
t.Fatal("GetName method not found")
|
||||
}
|
||||
if getNameChunk.Type != chunking.ChunkTypeMethod {
|
||||
t.Errorf("Expected GetName to be ChunkTypeMethod, got %s", getNameChunk.Type)
|
||||
}
|
||||
if getNameChunk.ParentName != "User" {
|
||||
t.Errorf("Expected GetName parent to be 'User', got '%s'", getNameChunk.ParentName)
|
||||
}
|
||||
|
||||
if setNameChunk == nil {
|
||||
t.Fatal("SetName method not found")
|
||||
}
|
||||
if setNameChunk.Type != chunking.ChunkTypeMethod {
|
||||
t.Errorf("Expected SetName to be ChunkTypeMethod, got %s", setNameChunk.Type)
|
||||
}
|
||||
if setNameChunk.ParentName != "User" {
|
||||
t.Errorf("Expected SetName parent to be 'User', got '%s'", setNameChunk.ParentName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoChunker_DocComments(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
testCode := `package main
|
||||
|
||||
// Calculate performs a calculation.
|
||||
// It takes two integers and returns their sum.
|
||||
func Calculate(a, b int) int {
|
||||
return a + b
|
||||
}
|
||||
`
|
||||
|
||||
if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
chunker := NewChunker(chunking.DefaultChunkOptions())
|
||||
chunks, err := chunker.Chunk(context.Background(), testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Chunk() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("Expected 1 chunk, got %d", len(chunks))
|
||||
}
|
||||
|
||||
chunk := chunks[0]
|
||||
if chunk.DocComment == "" {
|
||||
t.Error("Expected doc comment to be present")
|
||||
}
|
||||
|
||||
// Doc comment should contain the comment text
|
||||
expectedComment := "Calculate performs a calculation.\nIt takes two integers and returns their sum."
|
||||
if chunk.DocComment != expectedComment {
|
||||
t.Errorf("Expected doc comment '%s', got '%s'", expectedComment, chunk.DocComment)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Manager dispatches files to appropriate language-specific chunkers.
|
||||
type Manager struct {
|
||||
chunkers map[string]Chunker // extension -> chunker
|
||||
options ChunkOptions
|
||||
}
|
||||
|
||||
// NewManager creates a new chunking manager with the given chunkers.
|
||||
func NewManager(chunkers []Chunker, options ChunkOptions) *Manager {
|
||||
m := &Manager{
|
||||
chunkers: make(map[string]Chunker),
|
||||
options: options,
|
||||
}
|
||||
|
||||
// Register chunkers by their supported extensions
|
||||
for _, chunker := range chunkers {
|
||||
for _, ext := range chunker.SupportedExtensions() {
|
||||
m.chunkers[ext] = chunker
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// ChunkFile chunks a single file using the appropriate language chunker.
|
||||
// Returns an error if no chunker is found for the file extension.
|
||||
func (m *Manager) ChunkFile(ctx context.Context, filePath string) ([]Chunk, error) {
|
||||
ext := strings.ToLower(filepath.Ext(filePath))
|
||||
chunker, ok := m.chunkers[ext]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no chunker for extension %s", ext)
|
||||
}
|
||||
|
||||
chunks, err := chunker.Chunk(ctx, filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chunk %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Apply options-based filtering
|
||||
filtered := make([]Chunk, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
// Filter by minimum lines
|
||||
if m.options.MinLines > 0 {
|
||||
lineCount := chunk.EndLine - chunk.StartLine + 1
|
||||
if lineCount < m.options.MinLines {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Filter by maximum chunk size
|
||||
if m.options.MaxChunkSize > 0 && len(chunk.Content) > m.options.MaxChunkSize {
|
||||
// TODO: Consider splitting large chunks intelligently
|
||||
// For now, skip chunks that are too large
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, chunk)
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// ChunkFiles chunks multiple files in parallel.
|
||||
// Returns a map of file path to chunks, and any errors encountered.
|
||||
// Errors for individual files do not stop processing of other files.
|
||||
func (m *Manager) ChunkFiles(ctx context.Context, filePaths []string) (map[string][]Chunk, []error) {
|
||||
results := make(map[string][]Chunk)
|
||||
var errors []error
|
||||
|
||||
for _, filePath := range filePaths {
|
||||
chunks, err := m.ChunkFile(ctx, filePath)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("%s: %w", filePath, err))
|
||||
continue
|
||||
}
|
||||
if len(chunks) > 0 {
|
||||
results[filePath] = chunks
|
||||
}
|
||||
}
|
||||
|
||||
return results, errors
|
||||
}
|
||||
|
||||
// SupportsFile checks if the manager can chunk the given file based on extension.
|
||||
func (m *Manager) SupportsFile(filePath string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(filePath))
|
||||
_, ok := m.chunkers[ext]
|
||||
return ok
|
||||
}
|
||||
|
||||
// SupportedExtensions returns all file extensions supported by registered chunkers.
|
||||
func (m *Manager) SupportedExtensions() []string {
|
||||
exts := make([]string, 0, len(m.chunkers))
|
||||
for ext := range m.chunkers {
|
||||
exts = append(exts, ext)
|
||||
}
|
||||
return exts
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockChunker is a test chunker that returns dummy chunks
|
||||
type mockChunker struct{}
|
||||
|
||||
func (m *mockChunker) Chunk(ctx context.Context, filePath string) ([]Chunk, error) {
|
||||
// Just return an empty chunk for testing
|
||||
return []Chunk{
|
||||
{
|
||||
FilePath: filePath,
|
||||
Language: LanguageGo,
|
||||
Type: ChunkTypeFunction,
|
||||
Name: "TestFunc",
|
||||
StartLine: 1,
|
||||
EndLine: 1,
|
||||
Content: "test",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockChunker) Language() Language {
|
||||
return LanguageGo
|
||||
}
|
||||
|
||||
func (m *mockChunker) SupportedExtensions() []string {
|
||||
return []string{".go", ".py", ".ts"}
|
||||
}
|
||||
|
||||
func TestManager_ChunkMultipleFiles(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a Go file
|
||||
goFile := filepath.Join(tmpDir, "test.go")
|
||||
goCode := `package main
|
||||
|
||||
func Hello() string {
|
||||
return "hello"
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(goFile, []byte(goCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create Go file: %v", err)
|
||||
}
|
||||
|
||||
// Create a Python file
|
||||
pyFile := filepath.Join(tmpDir, "test.py")
|
||||
pyCode := `def greet(name):
|
||||
return f"Hello, {name}!"
|
||||
|
||||
class User:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
`
|
||||
if err := os.WriteFile(pyFile, []byte(pyCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create Python file: %v", err)
|
||||
}
|
||||
|
||||
// Create a TypeScript file
|
||||
tsFile := filepath.Join(tmpDir, "test.ts")
|
||||
tsCode := `function add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
class Calculator {
|
||||
multiply(a: number, b: number): number {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tsFile, []byte(tsCode), 0600); err != nil {
|
||||
t.Fatalf("Failed to create TypeScript file: %v", err)
|
||||
}
|
||||
|
||||
// Create manager
|
||||
manager := NewManager([]Chunker{&mockChunker{}}, DefaultChunkOptions())
|
||||
|
||||
// Test SupportsFile
|
||||
if !manager.SupportsFile(goFile) {
|
||||
t.Error("Manager should support .go files")
|
||||
}
|
||||
if !manager.SupportsFile(pyFile) {
|
||||
t.Error("Manager should support .py files")
|
||||
}
|
||||
if !manager.SupportsFile(tsFile) {
|
||||
t.Error("Manager should support .ts files")
|
||||
}
|
||||
|
||||
unsupportedFile := filepath.Join(tmpDir, "test.txt")
|
||||
if manager.SupportsFile(unsupportedFile) {
|
||||
t.Error("Manager should not support .txt files")
|
||||
}
|
||||
|
||||
// Test ChunkFiles
|
||||
results, errs := manager.ChunkFiles(context.Background(), []string{goFile, pyFile, tsFile})
|
||||
if len(errs) > 0 {
|
||||
t.Errorf("ChunkFiles returned errors: %v", errs)
|
||||
}
|
||||
|
||||
if len(results) != 3 {
|
||||
t.Errorf("Expected results for 3 files, got %d", len(results))
|
||||
}
|
||||
|
||||
// Verify each file has chunks
|
||||
for _, file := range []string{goFile, pyFile, tsFile} {
|
||||
if chunks, ok := results[file]; !ok || len(chunks) == 0 {
|
||||
t.Errorf("No chunks found for file %s", file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mockChunkerWithExts is a test chunker with configurable extensions
|
||||
type mockChunkerWithExts struct {
|
||||
exts []string
|
||||
}
|
||||
|
||||
func (m *mockChunkerWithExts) Chunk(ctx context.Context, filePath string) ([]Chunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockChunkerWithExts) Language() Language {
|
||||
return LanguageGo
|
||||
}
|
||||
|
||||
func (m *mockChunkerWithExts) SupportedExtensions() []string {
|
||||
return m.exts
|
||||
}
|
||||
|
||||
func TestManager_SupportedExtensions(t *testing.T) {
|
||||
|
||||
// Create manager with mock chunkers
|
||||
manager := NewManager([]Chunker{
|
||||
&mockChunkerWithExts{exts: []string{".go"}},
|
||||
&mockChunkerWithExts{exts: []string{".py", ".pyw"}},
|
||||
}, DefaultChunkOptions())
|
||||
|
||||
exts := manager.SupportedExtensions()
|
||||
expectedExts := map[string]bool{
|
||||
".go": false,
|
||||
".py": false,
|
||||
".pyw": false,
|
||||
}
|
||||
|
||||
for _, ext := range exts {
|
||||
if _, ok := expectedExts[ext]; ok {
|
||||
expectedExts[ext] = true
|
||||
} else {
|
||||
t.Errorf("Unexpected extension: %s", ext)
|
||||
}
|
||||
}
|
||||
|
||||
for ext, found := range expectedExts {
|
||||
if !found {
|
||||
t.Errorf("Expected extension %s not found", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
// Package python provides AST-aware chunking for Python source files using tree-sitter.
|
||||
package python
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
"github.com/smacker/go-tree-sitter/python"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
|
||||
)
|
||||
|
||||
// Chunker implements AST-aware chunking for Python files.
|
||||
type Chunker struct {
|
||||
parser *sitter.Parser
|
||||
options chunking.ChunkOptions
|
||||
}
|
||||
|
||||
// NewChunker creates a new Python chunker.
|
||||
func NewChunker(options chunking.ChunkOptions) *Chunker {
|
||||
parser := sitter.NewParser()
|
||||
parser.SetLanguage(python.GetLanguage())
|
||||
|
||||
return &Chunker{
|
||||
options: options,
|
||||
parser: parser,
|
||||
}
|
||||
}
|
||||
|
||||
// Language returns the language this chunker supports.
|
||||
func (c *Chunker) Language() chunking.Language {
|
||||
return chunking.LanguagePython
|
||||
}
|
||||
|
||||
// SupportedExtensions returns the file extensions this chunker handles.
|
||||
func (c *Chunker) SupportedExtensions() []string {
|
||||
return []string{".py"}
|
||||
}
|
||||
|
||||
// Chunk parses a Python source file and returns semantic code chunks.
|
||||
func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
|
||||
// Read file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
|
||||
// Parse the Python file
|
||||
tree, err := c.parser.ParseCtx(ctx, nil, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse Python file: %w", err)
|
||||
}
|
||||
defer tree.Close()
|
||||
|
||||
sourceLines := strings.Split(string(content), "\n")
|
||||
chunks := make([]chunking.Chunk, 0)
|
||||
|
||||
// Walk the AST and extract chunks
|
||||
c.walkNode(tree.RootNode(), content, sourceLines, filePath, "", &chunks)
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// walkNode recursively walks the tree-sitter AST and extracts chunks.
|
||||
func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) {
|
||||
nodeType := node.Type()
|
||||
|
||||
switch nodeType {
|
||||
case "function_definition":
|
||||
chunk := c.extractFunction(node, source, sourceLines, filePath, parentName)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
|
||||
case "class_definition":
|
||||
chunk := c.extractClass(node, source, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
|
||||
// Walk class body to find methods
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "block" {
|
||||
c.walkNode(child, source, sourceLines, filePath, chunk.Name, chunks)
|
||||
}
|
||||
}
|
||||
}
|
||||
return // Don't walk children again
|
||||
|
||||
case "block":
|
||||
// Walk statements in block
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Walk all children
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks)
|
||||
}
|
||||
}
|
||||
|
||||
// extractFunction extracts a function definition chunk.
|
||||
func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
|
||||
// Find function name
|
||||
var nameNode *sitter.Node
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "identifier" {
|
||||
nameNode = child
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := nameNode.Content(source)
|
||||
|
||||
// Skip private functions if configured
|
||||
if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguagePython,
|
||||
Name: name,
|
||||
ParentName: parentName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
}
|
||||
|
||||
// Determine if this is a method or function
|
||||
if parentName != "" {
|
||||
chunk.Type = chunking.ChunkTypeMethod
|
||||
} else {
|
||||
chunk.Type = chunking.ChunkTypeFunction
|
||||
}
|
||||
|
||||
// Extract signature (def line)
|
||||
chunk.Signature = c.extractFunctionSignature(node, source, sourceLines)
|
||||
|
||||
// Extract docstring as doc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractDocstring(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractClass extracts a class definition chunk.
|
||||
func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
// Find class name
|
||||
var nameNode *sitter.Node
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "identifier" {
|
||||
nameNode = child
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := nameNode.Content(source)
|
||||
|
||||
// Skip private classes if configured
|
||||
if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguagePython,
|
||||
Type: chunking.ChunkTypeClass,
|
||||
Name: name,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
}
|
||||
|
||||
// Extract class signature (class line)
|
||||
chunk.Signature = c.extractClassSignature(node, source, sourceLines)
|
||||
|
||||
// Extract docstring as doc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractDocstring(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractFunctionSignature extracts the function definition line.
|
||||
func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string {
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
|
||||
// Find the colon that ends the signature
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == ":" {
|
||||
endLine := int(child.EndPoint().Row) + 1
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: just return first line
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
|
||||
}
|
||||
|
||||
// extractClassSignature extracts the class definition line.
|
||||
func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string {
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
|
||||
// Find the colon that ends the signature
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == ":" {
|
||||
endLine := int(child.EndPoint().Row) + 1
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: just return first line
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
|
||||
}
|
||||
|
||||
// extractDocstring extracts the docstring from a function or class.
|
||||
func (c *Chunker) extractDocstring(node *sitter.Node, source []byte) string {
|
||||
// Find the block
|
||||
var blockNode *sitter.Node
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "block" {
|
||||
blockNode = child
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if blockNode == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if first statement in block is a string (docstring)
|
||||
for i := 0; i < int(blockNode.ChildCount()); i++ {
|
||||
child := blockNode.Child(i)
|
||||
if child.Type() == "expression_statement" {
|
||||
// Check if it contains a string
|
||||
for j := 0; j < int(child.ChildCount()); j++ {
|
||||
grandchild := child.Child(j)
|
||||
if grandchild.Type() == "string" {
|
||||
docstring := grandchild.Content(source)
|
||||
// Remove quotes
|
||||
docstring = strings.Trim(docstring, `"'`)
|
||||
return strings.TrimSpace(docstring)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractLines extracts a range of lines from source (1-indexed, inclusive).
|
||||
func (c *Chunker) extractLines(lines []string, start, end int) string {
|
||||
if start < 1 || end < start || start > len(lines) {
|
||||
return ""
|
||||
}
|
||||
|
||||
startIdx := start - 1
|
||||
endIdx := end
|
||||
if endIdx > len(lines) {
|
||||
endIdx = len(lines)
|
||||
}
|
||||
|
||||
return strings.Join(lines[startIdx:endIdx], "\n")
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
// Package chunking provides AST-aware code chunking for semantic code search.
|
||||
// Chunks code files into logical units (functions, classes, methods) that preserve
|
||||
// semantic boundaries for better vector embedding and retrieval.
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChunkType represents the type of code chunk.
|
||||
type ChunkType string
|
||||
|
||||
const (
|
||||
// ChunkTypeFunction represents a standalone function.
|
||||
ChunkTypeFunction ChunkType = "function"
|
||||
// ChunkTypeMethod represents a method on a class/struct/type.
|
||||
ChunkTypeMethod ChunkType = "method"
|
||||
// ChunkTypeClass represents a class or struct definition.
|
||||
ChunkTypeClass ChunkType = "class"
|
||||
// ChunkTypeInterface represents an interface definition.
|
||||
ChunkTypeInterface ChunkType = "interface"
|
||||
// ChunkTypeType represents a type alias or type definition.
|
||||
ChunkTypeType ChunkType = "type"
|
||||
// ChunkTypeConst represents constant declarations.
|
||||
ChunkTypeConst ChunkType = "const"
|
||||
// ChunkTypeVar represents variable declarations.
|
||||
ChunkTypeVar ChunkType = "var"
|
||||
)
|
||||
|
||||
// Language represents a programming language.
|
||||
type Language string
|
||||
|
||||
const (
|
||||
// LanguageGo represents the Go programming language.
|
||||
LanguageGo Language = "go"
|
||||
// LanguagePython represents the Python programming language.
|
||||
LanguagePython Language = "python"
|
||||
// LanguageTypeScript represents the TypeScript programming language.
|
||||
LanguageTypeScript Language = "typescript"
|
||||
// LanguageJavaScript represents the JavaScript programming language.
|
||||
LanguageJavaScript Language = "javascript"
|
||||
)
|
||||
|
||||
// Chunk represents a semantic code chunk with AST-derived boundaries.
|
||||
type Chunk struct {
|
||||
Metadata map[string]interface{}
|
||||
FilePath string
|
||||
Language Language
|
||||
Type ChunkType
|
||||
Name string
|
||||
ParentName string
|
||||
Content string
|
||||
Signature string
|
||||
DocComment string
|
||||
StartLine int
|
||||
EndLine int
|
||||
}
|
||||
|
||||
// Identifier returns a human-readable identifier for this chunk.
|
||||
// Format: "ParentName.Name" for methods, "Name" for top-level.
|
||||
func (c *Chunk) Identifier() string {
|
||||
if c.ParentName != "" {
|
||||
return fmt.Sprintf("%s.%s", c.ParentName, c.Name)
|
||||
}
|
||||
return c.Name
|
||||
}
|
||||
|
||||
// LineRange returns a human-readable line range.
|
||||
// Format: "L123-L456"
|
||||
func (c *Chunk) LineRange() string {
|
||||
return fmt.Sprintf("L%d-L%d", c.StartLine, c.EndLine)
|
||||
}
|
||||
|
||||
// SearchableContent returns content optimized for semantic search.
|
||||
// Combines signature, doc comment, and content in a structured format.
|
||||
func (c *Chunk) SearchableContent() string {
|
||||
var parts []string
|
||||
|
||||
// Include signature for functions/methods
|
||||
if c.Signature != "" {
|
||||
parts = append(parts, c.Signature)
|
||||
}
|
||||
|
||||
// Include doc comment
|
||||
if c.DocComment != "" {
|
||||
parts = append(parts, c.DocComment)
|
||||
}
|
||||
|
||||
// Include actual content
|
||||
if c.Content != "" {
|
||||
parts = append(parts, c.Content)
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// Chunker is the interface for language-specific code chunkers.
|
||||
type Chunker interface {
|
||||
// Chunk parses a source file and returns semantic code chunks.
|
||||
// Returns an error if the file cannot be parsed or read.
|
||||
Chunk(ctx context.Context, filePath string) ([]Chunk, error)
|
||||
|
||||
// Language returns the language this chunker supports.
|
||||
Language() Language
|
||||
|
||||
// SupportedExtensions returns file extensions this chunker handles.
|
||||
// Example: []string{".go"} for Go chunker
|
||||
SupportedExtensions() []string
|
||||
}
|
||||
|
||||
// ChunkOptions provides options for chunking behavior.
|
||||
type ChunkOptions struct {
|
||||
// MaxChunkSize is the maximum size of a chunk in bytes.
|
||||
// Chunks larger than this will be split (respecting boundaries where possible).
|
||||
// 0 means no limit.
|
||||
MaxChunkSize int
|
||||
|
||||
// IncludeDocComments controls whether to include documentation comments.
|
||||
IncludeDocComments bool
|
||||
|
||||
// IncludePrivate controls whether to include private/unexported symbols.
|
||||
IncludePrivate bool
|
||||
|
||||
// MinLines is the minimum number of lines for a chunk to be included.
|
||||
// Chunks smaller than this will be skipped.
|
||||
// 0 means no minimum.
|
||||
MinLines int
|
||||
}
|
||||
|
||||
// DefaultChunkOptions returns sensible default options.
|
||||
func DefaultChunkOptions() ChunkOptions {
|
||||
return ChunkOptions{
|
||||
MaxChunkSize: 8192, // ~8KB per chunk (well under token limit)
|
||||
IncludeDocComments: true,
|
||||
IncludePrivate: true, // Include all symbols for comprehensive search
|
||||
MinLines: 0, // No minimum - include even single-line functions
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
// Package typescript provides AST-aware chunking for TypeScript and JavaScript source files using tree-sitter.
|
||||
package typescript
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
"github.com/smacker/go-tree-sitter/typescript/typescript"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
|
||||
)
|
||||
|
||||
// Chunker implements AST-aware chunking for TypeScript/JavaScript files.
|
||||
type Chunker struct {
|
||||
parser *sitter.Parser
|
||||
options chunking.ChunkOptions
|
||||
}
|
||||
|
||||
// NewChunker creates a new TypeScript chunker.
|
||||
func NewChunker(options chunking.ChunkOptions) *Chunker {
|
||||
parser := sitter.NewParser()
|
||||
parser.SetLanguage(typescript.GetLanguage())
|
||||
|
||||
return &Chunker{
|
||||
options: options,
|
||||
parser: parser,
|
||||
}
|
||||
}
|
||||
|
||||
// Language returns the language this chunker supports.
|
||||
func (c *Chunker) Language() chunking.Language {
|
||||
return chunking.LanguageTypeScript
|
||||
}
|
||||
|
||||
// SupportedExtensions returns the file extensions this chunker handles.
|
||||
func (c *Chunker) SupportedExtensions() []string {
|
||||
return []string{".ts", ".tsx", ".js", ".jsx"}
|
||||
}
|
||||
|
||||
// Chunk parses a TypeScript/JavaScript source file and returns semantic code chunks.
|
||||
func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
|
||||
// Read file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
|
||||
// Parse the file
|
||||
tree, err := c.parser.ParseCtx(ctx, nil, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse TypeScript file: %w", err)
|
||||
}
|
||||
defer tree.Close()
|
||||
|
||||
sourceLines := strings.Split(string(content), "\n")
|
||||
chunks := make([]chunking.Chunk, 0)
|
||||
|
||||
// Walk the AST and extract chunks
|
||||
c.walkNode(tree.RootNode(), content, sourceLines, filePath, "", &chunks)
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// walkNode recursively walks the tree-sitter AST and extracts chunks.
|
||||
func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) {
|
||||
nodeType := node.Type()
|
||||
|
||||
switch nodeType {
|
||||
case "function_declaration":
|
||||
chunk := c.extractFunction(node, source, sourceLines, filePath, parentName)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
|
||||
case "method_definition":
|
||||
chunk := c.extractMethod(node, source, sourceLines, filePath, parentName)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
|
||||
case "arrow_function", "function_expression":
|
||||
// Handle arrow functions and function expressions assigned to variables
|
||||
chunk := c.extractFunctionExpression(node, source, sourceLines, filePath, parentName)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
|
||||
case "class_declaration":
|
||||
chunk := c.extractClass(node, source, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
|
||||
// Walk class body to find methods
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "class_body" {
|
||||
c.walkNode(child, source, sourceLines, filePath, chunk.Name, chunks)
|
||||
}
|
||||
}
|
||||
}
|
||||
return // Don't walk children again
|
||||
|
||||
case "interface_declaration":
|
||||
chunk := c.extractInterface(node, source, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
|
||||
case "type_alias_declaration":
|
||||
chunk := c.extractTypeAlias(node, source, sourceLines, filePath)
|
||||
if chunk != nil {
|
||||
*chunks = append(*chunks, *chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// Walk all children
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks)
|
||||
}
|
||||
}
|
||||
|
||||
// extractFunction extracts a function declaration.
|
||||
func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
|
||||
name := c.findChildContent(node, "identifier", source)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeFunction,
|
||||
Name: name,
|
||||
ParentName: parentName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
Signature: c.extractFunctionSignature(node, source, sourceLines),
|
||||
}
|
||||
|
||||
// Extract JSDoc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractComment(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractMethod extracts a method definition from a class.
|
||||
func (c *Chunker) extractMethod(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
|
||||
name := c.findChildContent(node, "property_identifier", source)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip private methods if configured
|
||||
if !c.options.IncludePrivate && strings.HasPrefix(name, "_") {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeMethod,
|
||||
Name: name,
|
||||
ParentName: parentName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
Signature: c.extractMethodSignature(node, source, sourceLines),
|
||||
}
|
||||
|
||||
// Extract JSDoc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractComment(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractFunctionExpression extracts arrow functions and function expressions.
|
||||
func (c *Chunker) extractFunctionExpression(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
|
||||
// Try to find the variable name from parent
|
||||
parent := node.Parent()
|
||||
if parent == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var name string
|
||||
if parent.Type() == "variable_declarator" {
|
||||
name = c.findChildContent(parent, "identifier", source)
|
||||
} else if parent.Type() == "assignment_expression" {
|
||||
// Handle const foo = () => {}
|
||||
for i := 0; i < int(parent.ChildCount()); i++ {
|
||||
child := parent.Child(i)
|
||||
if child.Type() == "identifier" || child.Type() == "member_expression" {
|
||||
name = child.Content(source)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return nil // Anonymous function, skip
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeFunction,
|
||||
Name: name,
|
||||
ParentName: parentName,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractClass extracts a class declaration.
|
||||
func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
name := c.findChildContent(node, "type_identifier", source)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeClass,
|
||||
Name: name,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
Signature: c.extractClassSignature(node, source, sourceLines),
|
||||
}
|
||||
|
||||
// Extract JSDoc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractComment(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractInterface extracts an interface declaration.
|
||||
func (c *Chunker) extractInterface(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
name := c.findChildContent(node, "type_identifier", source)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeInterface,
|
||||
Name: name,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
}
|
||||
|
||||
// Extract JSDoc comment
|
||||
if c.options.IncludeDocComments {
|
||||
chunk.DocComment = c.extractComment(node, source)
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// extractTypeAlias extracts a type alias declaration.
|
||||
func (c *Chunker) extractTypeAlias(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
|
||||
name := c.findChildContent(node, "type_identifier", source)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
endLine := int(node.EndPoint().Row) + 1
|
||||
|
||||
chunk := &chunking.Chunk{
|
||||
FilePath: filePath,
|
||||
Language: chunking.LanguageTypeScript,
|
||||
Type: chunking.ChunkTypeType,
|
||||
Name: name,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
Content: c.extractLines(sourceLines, startLine, endLine),
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
// findChildContent finds the first child of the given type and returns its content.
|
||||
func (c *Chunker) findChildContent(node *sitter.Node, childType string, source []byte) string {
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == childType {
|
||||
return child.Content(source)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractFunctionSignature extracts the function signature.
|
||||
func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string {
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
|
||||
// Find the opening brace of the body
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "statement_block" {
|
||||
endLine := int(child.StartPoint().Row) + 1
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: just return first line
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
|
||||
}
|
||||
|
||||
// extractMethodSignature extracts the method signature.
|
||||
func (c *Chunker) extractMethodSignature(node *sitter.Node, source []byte, sourceLines []string) string {
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
|
||||
// Find the opening brace of the body
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "statement_block" {
|
||||
endLine := int(child.StartPoint().Row) + 1
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
|
||||
}
|
||||
|
||||
// extractClassSignature extracts the class declaration line.
|
||||
func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string {
|
||||
startLine := int(node.StartPoint().Row) + 1
|
||||
|
||||
// Find the opening brace of the class body
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child.Type() == "class_body" {
|
||||
endLine := int(child.StartPoint().Row) + 1
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
|
||||
}
|
||||
|
||||
// extractComment extracts JSDoc or other comments from a node.
|
||||
func (c *Chunker) extractComment(node *sitter.Node, source []byte) string {
|
||||
// Check previous sibling for comment
|
||||
prevSibling := node.PrevSibling()
|
||||
if prevSibling != nil && prevSibling.Type() == "comment" {
|
||||
comment := prevSibling.Content(source)
|
||||
// Remove comment markers
|
||||
comment = strings.TrimPrefix(comment, "/**")
|
||||
comment = strings.TrimPrefix(comment, "/*")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
return strings.TrimSpace(comment)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractLines extracts a range of lines from source (1-indexed, inclusive).
|
||||
func (c *Chunker) extractLines(lines []string, start, end int) string {
|
||||
if start < 1 || end < start || start > len(lines) {
|
||||
return ""
|
||||
}
|
||||
|
||||
startIdx := start - 1
|
||||
endIdx := end
|
||||
if endIdx > len(lines) {
|
||||
endIdx = len(lines)
|
||||
}
|
||||
|
||||
return strings.Join(lines[startIdx:endIdx], "\n")
|
||||
}
|
||||
+62
-35
@@ -36,41 +36,38 @@ var CriticalConcepts = []string{
|
||||
}
|
||||
|
||||
// Config holds the application configuration.
|
||||
// Field order optimized for memory alignment (fieldalignment).
|
||||
type Config struct {
|
||||
// Worker settings
|
||||
WorkerPort int `json:"worker_port"`
|
||||
|
||||
// Database settings
|
||||
DBPath string `json:"db_path"`
|
||||
MaxConns int `json:"max_conns"`
|
||||
|
||||
// SDK Agent settings
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
|
||||
// Embedding settings
|
||||
EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
|
||||
|
||||
// Reranking settings (cross-encoder)
|
||||
RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking
|
||||
RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100)
|
||||
RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10)
|
||||
RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7)
|
||||
RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank)
|
||||
RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false)
|
||||
|
||||
// Context injection settings
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
DBPath string `json:"db_path"`
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
EmbeddingModel string `json:"embedding_model"`
|
||||
VectorStorageStrategy string `json:"vector_storage_strategy"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"`
|
||||
RerankingResults int `json:"reranking_results"`
|
||||
GraphEdgeWeight float64 `json:"graph_edge_weight"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"`
|
||||
RerankingCandidates int `json:"reranking_candidates"`
|
||||
WorkerPort int `json:"worker_port"`
|
||||
RerankingMinImprovement float64 `json:"reranking_min_improvement"`
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
ContextSessionCount int `json:"context_session_count"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextFullField string `json:"context_full_field"`
|
||||
MaxConns int `json:"max_conns"`
|
||||
RerankingAlpha float64 `json:"reranking_alpha"`
|
||||
GraphMaxHops int `json:"graph_max_hops"`
|
||||
GraphBranchFactor int `json:"graph_branch_factor"`
|
||||
GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"`
|
||||
HubThreshold int `json:"hub_threshold"`
|
||||
ContextShowLastSummary bool `json:"context_show_last_summary"`
|
||||
ContextObsTypes []string `json:"context_obs_types"`
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
|
||||
RerankingEnabled bool `json:"reranking_enabled"`
|
||||
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
|
||||
ContextShowReadTokens bool `json:"context_show_read_tokens"`
|
||||
RerankingPureMode bool `json:"reranking_pure_mode"`
|
||||
GraphEnabled bool `json:"graph_enabled"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -143,11 +140,18 @@ func Default() *Config {
|
||||
MaxConns: 4,
|
||||
Model: DefaultModel,
|
||||
EmbeddingModel: DefaultEmbeddingModel,
|
||||
RerankingEnabled: true, // Enable by default for improved relevance
|
||||
RerankingCandidates: 100, // Retrieve top 100 candidates
|
||||
RerankingResults: 10, // Return top 10 after reranking
|
||||
RerankingAlpha: 0.7, // Favor cross-encoder score
|
||||
RerankingMinImprovement: 0, // Always apply reranking
|
||||
RerankingEnabled: true, // Enable by default for improved relevance
|
||||
RerankingCandidates: 100, // Retrieve top 100 candidates
|
||||
RerankingResults: 10, // Return top 10 after reranking
|
||||
RerankingAlpha: 0.7, // Favor cross-encoder score
|
||||
RerankingMinImprovement: 0, // Always apply reranking
|
||||
GraphEnabled: true, // Enable graph-aware search by default
|
||||
GraphMaxHops: 2, // Two-hop traversal
|
||||
GraphBranchFactor: 5, // Expand top 5 neighbors per node
|
||||
GraphEdgeWeight: 0.3, // Minimum edge weight to follow
|
||||
GraphRebuildIntervalMin: 60, // Rebuild graph every 60 minutes
|
||||
VectorStorageStrategy: "hub", // Hub storage strategy (LEANN-inspired)
|
||||
HubThreshold: 5, // Require 5+ accesses to store embedding
|
||||
ContextObservations: 100,
|
||||
ContextFullCount: 25,
|
||||
ContextSessionCount: 10,
|
||||
@@ -233,6 +237,29 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 {
|
||||
cfg.ContextMaxPromptResults = int(v)
|
||||
}
|
||||
// Graph settings
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_ENABLED"].(bool); ok {
|
||||
cfg.GraphEnabled = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_MAX_HOPS"].(float64); ok && v > 0 {
|
||||
cfg.GraphMaxHops = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_BRANCH_FACTOR"].(float64); ok && v > 0 {
|
||||
cfg.GraphBranchFactor = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_EDGE_WEIGHT"].(float64); ok && v >= 0 && v <= 1 {
|
||||
cfg.GraphEdgeWeight = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_REBUILD_INTERVAL_MIN"].(float64); ok && v > 0 {
|
||||
cfg.GraphRebuildIntervalMin = int(v)
|
||||
}
|
||||
// Vector storage settings (LEANN Phase 2)
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_VECTOR_STORAGE_STRATEGY"].(string); ok && v != "" {
|
||||
cfg.VectorStorageStrategy = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_HUB_THRESHOLD"].(float64); ok && v > 0 {
|
||||
cfg.HubThreshold = int(v)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -121,8 +121,8 @@ func (s *ConfigSuite) TestLoad_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
settingsJSON string
|
||||
expectedPort int
|
||||
expectedModel string
|
||||
expectedPort int
|
||||
expectedObsObs int
|
||||
}{
|
||||
{
|
||||
@@ -183,12 +183,12 @@ func (s *ConfigSuite) TestLoad_TableDriven() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
if tt.settingsJSON != "" {
|
||||
err := os.WriteFile(
|
||||
writeErr := os.WriteFile(
|
||||
filepath.Join(tempDir, ".claude-mnemonic", "settings.json"),
|
||||
[]byte(tt.settingsJSON),
|
||||
0600,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NoError(writeErr)
|
||||
}
|
||||
|
||||
cfg, err := Load()
|
||||
|
||||
@@ -214,9 +214,9 @@ func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, proje
|
||||
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
|
||||
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
|
||||
var results []struct {
|
||||
ObservationConflict
|
||||
NewerTitle sql.NullString `gorm:"column:newer_title"`
|
||||
OlderTitle sql.NullString `gorm:"column:older_title"`
|
||||
ObservationConflict
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
|
||||
+60
-70
@@ -17,18 +17,18 @@ import (
|
||||
|
||||
// SDKSession represents a Claude Code session.
|
||||
type SDKSession struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ClaudeSessionID string `gorm:"uniqueIndex;not null"`
|
||||
SDKSessionID sql.NullString `gorm:"uniqueIndex"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Status string `gorm:"type:text;check:status IN ('active', 'completed', 'failed');default:'active';index"`
|
||||
StartedAt string `gorm:"not null"`
|
||||
SDKSessionID sql.NullString `gorm:"uniqueIndex"`
|
||||
UserPrompt sql.NullString
|
||||
WorkerPort sql.NullInt64
|
||||
PromptCounter int `gorm:"default:0"`
|
||||
Status string `gorm:"type:text;check:status IN ('active', 'completed', 'failed');default:'active';index"`
|
||||
StartedAt string `gorm:"not null"`
|
||||
StartedAtEpoch int64 `gorm:"index:idx_sessions_started,sort:desc;not null"`
|
||||
CompletedAt sql.NullString
|
||||
WorkerPort sql.NullInt64
|
||||
CompletedAtEpoch sql.NullInt64
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
PromptCounter int `gorm:"default:0"`
|
||||
StartedAtEpoch int64 `gorm:"index:idx_sessions_started,sort:desc;not null"`
|
||||
}
|
||||
|
||||
func (SDKSession) TableName() string { return "sdk_sessions" }
|
||||
@@ -46,34 +46,28 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// Observation represents a stored observation (learning).
|
||||
type Observation struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
|
||||
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
|
||||
|
||||
// Content fields
|
||||
Title sql.NullString `gorm:"type:text"`
|
||||
Subtitle sql.NullString `gorm:"type:text"`
|
||||
Facts models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
Narrative sql.NullString `gorm:"type:text"`
|
||||
Concepts models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FilesRead models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FilesModified models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
FileMtimes models.JSONInt64Map `gorm:"type:text"` // JSON object
|
||||
|
||||
// Metadata
|
||||
FileMtimes models.JSONInt64Map `gorm:"type:text"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
|
||||
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
Title sql.NullString `gorm:"type:text"`
|
||||
Narrative sql.NullString `gorm:"type:text"`
|
||||
Concepts models.JSONStringArray `gorm:"type:text"`
|
||||
FilesRead models.JSONStringArray `gorm:"type:text"`
|
||||
FilesModified models.JSONStringArray `gorm:"type:text"`
|
||||
Subtitle sql.NullString `gorm:"type:text"`
|
||||
Facts models.JSONStringArray `gorm:"type:text"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
PromptNumber sql.NullInt64
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
|
||||
|
||||
// Importance scoring fields
|
||||
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
|
||||
UserFeedback int `gorm:"default:0"`
|
||||
RetrievalCount int `gorm:"default:0"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
|
||||
}
|
||||
|
||||
@@ -95,23 +89,19 @@ func (o *Observation) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// SessionSummary represents a session summary.
|
||||
type SessionSummary struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
|
||||
// Summary fields (nullable TEXT)
|
||||
Request sql.NullString
|
||||
Investigated sql.NullString
|
||||
Learned sql.NullString
|
||||
Completed sql.NullString
|
||||
NextSteps sql.NullString `gorm:"column:next_steps"`
|
||||
Notes sql.NullString
|
||||
|
||||
// Metadata
|
||||
PromptNumber sql.NullInt64
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_summaries_created,sort:desc;not null"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Completed sql.NullString
|
||||
Investigated sql.NullString
|
||||
Learned sql.NullString
|
||||
NextSteps sql.NullString `gorm:"column:next_steps"`
|
||||
Notes sql.NullString
|
||||
Request sql.NullString
|
||||
PromptNumber sql.NullInt64
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_summaries_created,sort:desc;not null"`
|
||||
}
|
||||
|
||||
func (SessionSummary) TableName() string { return "session_summaries" }
|
||||
@@ -129,12 +119,12 @@ func (s *SessionSummary) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// UserPrompt represents a user prompt.
|
||||
type UserPrompt struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
ClaudeSessionID string `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:1"`
|
||||
PromptNumber int `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:2"`
|
||||
PromptText string `gorm:"type:text;not null"`
|
||||
MatchedObservations int `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
PromptNumber int `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:2"`
|
||||
MatchedObservations int `gorm:"default:0"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_prompts_created,sort:desc;not null"`
|
||||
}
|
||||
|
||||
@@ -153,16 +143,16 @@ func (p *UserPrompt) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// ObservationConflict tracks conflicts between observations.
|
||||
type ObservationConflict struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
NewerObsID int64 `gorm:"index:idx_conflicts_newer;not null"`
|
||||
OlderObsID int64 `gorm:"index:idx_conflicts_older;not null"`
|
||||
ConflictType models.ConflictType `gorm:"type:text;check:conflict_type IN ('superseded', 'contradicts', 'outdated_pattern');not null"`
|
||||
Resolution models.ConflictResolution `gorm:"type:text;check:resolution IN ('prefer_newer', 'prefer_older', 'manual');not null"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
DetectedAt string `gorm:"not null"`
|
||||
DetectedAtEpoch int64 `gorm:"index:idx_conflicts_unresolved,priority:2,sort:desc;not null"`
|
||||
Resolved int `gorm:"default:0;index:idx_conflicts_unresolved,priority:1"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
ResolvedAt sql.NullString
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
NewerObsID int64 `gorm:"index:idx_conflicts_newer;not null"`
|
||||
OlderObsID int64 `gorm:"index:idx_conflicts_older;not null"`
|
||||
DetectedAtEpoch int64 `gorm:"index:idx_conflicts_unresolved,priority:2,sort:desc;not null"`
|
||||
Resolved int `gorm:"default:0;index:idx_conflicts_unresolved,priority:1"`
|
||||
}
|
||||
|
||||
func (ObservationConflict) TableName() string { return "observation_conflicts" }
|
||||
@@ -180,14 +170,14 @@ func (c *ObservationConflict) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// ObservationRelation tracks relationships between observations.
|
||||
type ObservationRelation struct {
|
||||
RelationType models.RelationType `gorm:"type:text;check:relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from');index:idx_relations_type;uniqueIndex:idx_relations_unique,priority:3;not null"`
|
||||
DetectionSource models.RelationDetectionSource `gorm:"type:text;check:detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression');not null"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SourceID int64 `gorm:"index:idx_relations_source;index:idx_relations_both,priority:1;uniqueIndex:idx_relations_unique,priority:1;not null"`
|
||||
TargetID int64 `gorm:"index:idx_relations_target;index:idx_relations_both,priority:2;uniqueIndex:idx_relations_unique,priority:2;not null"`
|
||||
RelationType models.RelationType `gorm:"type:text;check:relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from');index:idx_relations_type;uniqueIndex:idx_relations_unique,priority:3;not null"`
|
||||
Confidence float64 `gorm:"type:real;default:0.5;index:idx_relations_confidence,sort:desc;not null"`
|
||||
DetectionSource models.RelationDetectionSource `gorm:"type:text;check:detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression');not null"`
|
||||
Reason sql.NullString `gorm:"type:text"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"not null"`
|
||||
}
|
||||
|
||||
@@ -209,21 +199,21 @@ func (r *ObservationRelation) BeforeCreate(tx *gorm.DB) error {
|
||||
|
||||
// Pattern represents a detected recurring pattern.
|
||||
type Pattern struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
Status models.PatternStatus `gorm:"type:text;default:'active';check:status IN ('active', 'deprecated', 'merged');index"`
|
||||
Name string `gorm:"type:text;not null"`
|
||||
Type models.PatternType `gorm:"type:text;check:type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice');index;not null"`
|
||||
Description sql.NullString `gorm:"type:text"`
|
||||
Signature models.JSONStringArray `gorm:"type:text"` // JSON array of keywords
|
||||
CreatedAt string `gorm:"not null"`
|
||||
LastSeenAt string `gorm:"not null"`
|
||||
Signature models.JSONStringArray `gorm:"type:text"`
|
||||
Projects models.JSONStringArray `gorm:"type:text"`
|
||||
ObservationIDs models.JSONInt64Array `gorm:"type:text"`
|
||||
Recommendation sql.NullString `gorm:"type:text"`
|
||||
Frequency int `gorm:"default:1;index:idx_patterns_frequency,sort:desc"`
|
||||
Projects models.JSONStringArray `gorm:"type:text"` // JSON array
|
||||
ObservationIDs models.JSONInt64Array `gorm:"type:text"` // JSON array
|
||||
Status models.PatternStatus `gorm:"type:text;default:'active';check:status IN ('active', 'deprecated', 'merged');index"`
|
||||
Description sql.NullString `gorm:"type:text"`
|
||||
MergedIntoID sql.NullInt64
|
||||
Frequency int `gorm:"default:1;index:idx_patterns_frequency,sort:desc"`
|
||||
Confidence float64 `gorm:"type:real;default:0.5;index:idx_patterns_confidence,sort:desc"`
|
||||
LastSeenAt string `gorm:"not null"`
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
LastSeenAtEpoch int64 `gorm:"index:idx_patterns_last_seen,sort:desc;not null"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"not null"`
|
||||
}
|
||||
|
||||
@@ -256,8 +246,8 @@ func (p *Pattern) BeforeCreate(tx *gorm.DB) error {
|
||||
// ConceptWeight stores configurable weights for importance scoring.
|
||||
type ConceptWeight struct {
|
||||
Concept string `gorm:"primaryKey;type:text"`
|
||||
Weight float64 `gorm:"type:real;not null;default:0.1"`
|
||||
UpdatedAt string `gorm:"not null"`
|
||||
Weight float64 `gorm:"type:real;not null;default:0.1"`
|
||||
}
|
||||
|
||||
func (ConceptWeight) TableName() string { return "concept_weights" }
|
||||
|
||||
@@ -145,9 +145,9 @@ func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy
|
||||
}
|
||||
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
UserPrompt
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
@@ -184,9 +184,9 @@ func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy
|
||||
// GetAllRecentUserPrompts retrieves recent user prompts across all projects.
|
||||
func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
UserPrompt
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
@@ -211,9 +211,9 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
|
||||
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
UserPrompt
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
@@ -256,9 +256,9 @@ func (s *PromptStore) FindRecentPromptByText(ctx context.Context, claudeSessionI
|
||||
// GetRecentUserPromptsByProject retrieves recent user prompts for a specific project.
|
||||
func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error) {
|
||||
var results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
UserPrompt
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
@@ -283,9 +283,9 @@ func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project
|
||||
|
||||
// toModelUserPromptsWithSession converts query results to pkg/models.UserPromptWithSession.
|
||||
func toModelUserPromptsWithSession(results []struct {
|
||||
UserPrompt
|
||||
Project sql.NullString `gorm:"column:project"`
|
||||
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
|
||||
UserPrompt
|
||||
}) []*models.UserPromptWithSession {
|
||||
prompts := make([]*models.UserPromptWithSession, len(results))
|
||||
for i, r := range results {
|
||||
|
||||
@@ -171,11 +171,11 @@ func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType mod
|
||||
// GetRelationsWithDetails retrieves relations with observation titles for display.
|
||||
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
|
||||
var results []struct {
|
||||
ObservationRelation
|
||||
SourceTitle sql.NullString `gorm:"column:source_title"`
|
||||
TargetTitle sql.NullString `gorm:"column:target_title"`
|
||||
SourceType string `gorm:"column:source_type"`
|
||||
TargetType string `gorm:"column:target_type"`
|
||||
SourceTitle sql.NullString `gorm:"column:source_title"`
|
||||
TargetTitle sql.NullString `gorm:"column:target_title"`
|
||||
ObservationRelation
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
|
||||
@@ -21,15 +21,10 @@ const (
|
||||
// ONNXConfig describes ONNX-specific model configuration.
|
||||
// This allows different models to specify their tensor names and pooling needs.
|
||||
type ONNXConfig struct {
|
||||
// InputNames are the ONNX input tensor names in order.
|
||||
InputNames []string
|
||||
// OutputNames are the ONNX output tensor names.
|
||||
Pooling PoolingStrategy
|
||||
InputNames []string
|
||||
OutputNames []string
|
||||
// Pooling specifies how to convert token embeddings to sentence embeddings.
|
||||
// If PoolingNone, the model outputs sentence embeddings directly.
|
||||
Pooling PoolingStrategy
|
||||
// HiddenSize is the embedding dimension (used for pooling calculations).
|
||||
HiddenSize int
|
||||
HiddenSize int
|
||||
}
|
||||
|
||||
// EmbeddingModel represents a text embedding model.
|
||||
@@ -62,11 +57,11 @@ type ONNXConfigurer interface {
|
||||
|
||||
// ModelMetadata describes an embedding model for UI/config.
|
||||
type ModelMetadata struct {
|
||||
Name string `json:"name"` // Human-readable name
|
||||
Version string `json:"version"` // Short ID for DB storage
|
||||
Dimensions int `json:"dimensions"` // Vector size
|
||||
Description string `json:"description"` // Brief description
|
||||
Default bool `json:"default"` // Is this the default model?
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Description string `json:"description"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
Default bool `json:"default"`
|
||||
}
|
||||
|
||||
// ModelFactory creates a new instance of an embedding model.
|
||||
@@ -74,10 +69,10 @@ type ModelFactory func() (EmbeddingModel, error)
|
||||
|
||||
// ModelRegistry provides model lookup by version.
|
||||
type ModelRegistry struct {
|
||||
mu sync.RWMutex
|
||||
models map[string]ModelFactory
|
||||
metadata map[string]ModelMetadata
|
||||
defaultModel string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewModelRegistry creates a new model registry.
|
||||
|
||||
@@ -46,9 +46,9 @@ var bgeONNXConfig = ONNXConfig{
|
||||
type bgeModel struct {
|
||||
tk *tokenizer.Tokenizer
|
||||
session *ort.DynamicAdvancedSession
|
||||
libDir string
|
||||
config ONNXConfig
|
||||
mu sync.Mutex
|
||||
libDir string // temp directory containing extracted libraries
|
||||
config ONNXConfig // ONNX configuration for this model
|
||||
}
|
||||
|
||||
// Compile-time check that bgeModel implements EmbeddingModel
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// SemanticSimilarityThreshold for creating semantic edges
|
||||
SemanticSimilarityThreshold = 0.85
|
||||
|
||||
// MinFileOverlapForEdge minimum file overlap ratio to create edge
|
||||
MinFileOverlapForEdge = 0.3
|
||||
|
||||
// MaxEdgesPerNode prevents creating too many edges
|
||||
MaxEdgesPerNode = 20
|
||||
)
|
||||
|
||||
// DetectEdges identifies relationships between observations
|
||||
func DetectEdges(ctx context.Context, observations []*models.Observation) ([]Edge, error) {
|
||||
if len(observations) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
edges := make([]Edge, 0)
|
||||
|
||||
// Build lookup maps for efficient detection
|
||||
sessionMap := buildSessionMap(observations)
|
||||
conceptMap := buildConceptMap(observations)
|
||||
fileMap := buildFileMap(observations)
|
||||
|
||||
log.Info().
|
||||
Int("observations", len(observations)).
|
||||
Int("sessions", len(sessionMap)).
|
||||
Int("concepts", len(conceptMap)).
|
||||
Msg("Starting edge detection")
|
||||
|
||||
// Detect temporal edges (same session)
|
||||
temporalEdges := detectTemporalEdges(sessionMap)
|
||||
edges = append(edges, temporalEdges...)
|
||||
|
||||
// Detect concept edges (shared tags)
|
||||
conceptEdges := detectConceptEdges(conceptMap)
|
||||
edges = append(edges, conceptEdges...)
|
||||
|
||||
// Detect file overlap edges
|
||||
fileEdges := detectFileOverlapEdges(fileMap, observations)
|
||||
edges = append(edges, fileEdges...)
|
||||
|
||||
// Prune excessive edges per node
|
||||
edges = pruneEdges(edges, MaxEdgesPerNode)
|
||||
|
||||
log.Info().
|
||||
Int("temporal_edges", len(temporalEdges)).
|
||||
Int("concept_edges", len(conceptEdges)).
|
||||
Int("file_edges", len(fileEdges)).
|
||||
Int("total_edges", len(edges)).
|
||||
Msg("Edge detection complete")
|
||||
|
||||
return edges, nil
|
||||
}
|
||||
|
||||
// buildSessionMap groups observations by SDK session
|
||||
func buildSessionMap(observations []*models.Observation) map[string][]int64 {
|
||||
sessionMap := make(map[string][]int64)
|
||||
|
||||
for _, obs := range observations {
|
||||
if obs.SDKSessionID != "" {
|
||||
sessionMap[obs.SDKSessionID] = append(sessionMap[obs.SDKSessionID], obs.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return sessionMap
|
||||
}
|
||||
|
||||
// buildConceptMap groups observations by concept tags
|
||||
func buildConceptMap(observations []*models.Observation) map[string][]int64 {
|
||||
conceptMap := make(map[string][]int64)
|
||||
|
||||
for _, obs := range observations {
|
||||
for _, concept := range obs.Concepts {
|
||||
conceptMap[concept] = append(conceptMap[concept], obs.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return conceptMap
|
||||
}
|
||||
|
||||
// buildFileMap maps files to observations (from both FilesRead and FilesModified)
|
||||
func buildFileMap(observations []*models.Observation) map[string][]int64 {
|
||||
fileMap := make(map[string][]int64)
|
||||
|
||||
for _, obs := range observations {
|
||||
// Add files from FilesRead
|
||||
for _, file := range obs.FilesRead {
|
||||
fileMap[file] = append(fileMap[file], obs.ID)
|
||||
}
|
||||
// Add files from FilesModified
|
||||
for _, file := range obs.FilesModified {
|
||||
fileMap[file] = append(fileMap[file], obs.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return fileMap
|
||||
}
|
||||
|
||||
// detectTemporalEdges creates edges between observations in the same session
|
||||
func detectTemporalEdges(sessionMap map[string][]int64) []Edge {
|
||||
edges := make([]Edge, 0)
|
||||
|
||||
for _, obsIDs := range sessionMap {
|
||||
if len(obsIDs) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create edges between consecutive observations in session
|
||||
for i := 0; i < len(obsIDs)-1; i++ {
|
||||
edges = append(edges, Edge{
|
||||
FromID: obsIDs[i],
|
||||
ToID: obsIDs[i+1],
|
||||
Relation: RelationTemporal,
|
||||
Weight: 0.8, // High weight for temporal proximity
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// detectConceptEdges creates edges between observations sharing concepts
|
||||
func detectConceptEdges(conceptMap map[string][]int64) []Edge {
|
||||
edges := make([]Edge, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for concept, obsIDs := range conceptMap {
|
||||
if len(obsIDs) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create edges between all observations sharing this concept
|
||||
for i := 0; i < len(obsIDs); i++ {
|
||||
for j := i + 1; j < len(obsIDs); j++ {
|
||||
// Use sorted pair as key to avoid duplicates
|
||||
pairKey := edgeKey(obsIDs[i], obsIDs[j])
|
||||
if seen[pairKey] {
|
||||
continue
|
||||
}
|
||||
seen[pairKey] = true
|
||||
|
||||
// Weight based on concept specificity (longer = more specific)
|
||||
weight := float32(0.5 + 0.3*math.Min(1.0, float64(len(concept))/20.0))
|
||||
|
||||
edges = append(edges, Edge{
|
||||
FromID: obsIDs[i],
|
||||
ToID: obsIDs[j],
|
||||
Relation: RelationConcept,
|
||||
Weight: weight,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// detectFileOverlapEdges creates edges based on file references
|
||||
func detectFileOverlapEdges(fileMap map[string][]int64, observations []*models.Observation) []Edge {
|
||||
edges := make([]Edge, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// Build observation ID to observation map for quick lookup
|
||||
obsMap := make(map[int64]*models.Observation)
|
||||
for _, obs := range observations {
|
||||
obsMap[obs.ID] = obs
|
||||
}
|
||||
|
||||
for _, obsIDs := range fileMap {
|
||||
if len(obsIDs) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create edges between observations referencing same files
|
||||
for i := 0; i < len(obsIDs); i++ {
|
||||
for j := i + 1; j < len(obsIDs); j++ {
|
||||
pairKey := edgeKey(obsIDs[i], obsIDs[j])
|
||||
if seen[pairKey] {
|
||||
continue
|
||||
}
|
||||
seen[pairKey] = true
|
||||
|
||||
// Calculate file overlap ratio
|
||||
obs1, ok1 := obsMap[obsIDs[i]]
|
||||
obs2, ok2 := obsMap[obsIDs[j]]
|
||||
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Merge FilesRead and FilesModified for both observations
|
||||
files1 := append([]string{}, obs1.FilesRead...)
|
||||
files1 = append(files1, obs1.FilesModified...)
|
||||
files2 := append([]string{}, obs2.FilesRead...)
|
||||
files2 = append(files2, obs2.FilesModified...)
|
||||
|
||||
overlap := calculateFileOverlap(files1, files2)
|
||||
if overlap < MinFileOverlapForEdge {
|
||||
continue
|
||||
}
|
||||
|
||||
edges = append(edges, Edge{
|
||||
FromID: obsIDs[i],
|
||||
ToID: obsIDs[j],
|
||||
Relation: RelationFileOverlap,
|
||||
Weight: overlap,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// calculateFileOverlap computes Jaccard similarity of file sets
|
||||
func calculateFileOverlap(files1, files2 []string) float32 {
|
||||
if len(files1) == 0 || len(files2) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Convert to sets
|
||||
set1 := make(map[string]bool)
|
||||
for _, f := range files1 {
|
||||
set1[f] = true
|
||||
}
|
||||
|
||||
set2 := make(map[string]bool)
|
||||
for _, f := range files2 {
|
||||
set2[f] = true
|
||||
}
|
||||
|
||||
// Count intersection
|
||||
intersection := 0
|
||||
for f := range set1 {
|
||||
if set2[f] {
|
||||
intersection++
|
||||
}
|
||||
}
|
||||
|
||||
// Jaccard similarity = intersection / union
|
||||
union := len(set1) + len(set2) - intersection
|
||||
if union == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return float32(intersection) / float32(union)
|
||||
}
|
||||
|
||||
// pruneEdges limits edges per node to prevent graph explosion
|
||||
func pruneEdges(edges []Edge, maxPerNode int) []Edge {
|
||||
if maxPerNode <= 0 {
|
||||
return edges
|
||||
}
|
||||
|
||||
// Count edges per node
|
||||
outEdges := make(map[int64][]Edge)
|
||||
inEdges := make(map[int64][]Edge)
|
||||
|
||||
for _, edge := range edges {
|
||||
outEdges[edge.FromID] = append(outEdges[edge.FromID], edge)
|
||||
inEdges[edge.ToID] = append(inEdges[edge.ToID], edge)
|
||||
}
|
||||
|
||||
// Prune low-weight edges if node has too many
|
||||
pruned := make([]Edge, 0, len(edges))
|
||||
processed := make(map[string]bool)
|
||||
|
||||
for _, edge := range edges {
|
||||
pairKey := edgeKey(edge.FromID, edge.ToID)
|
||||
if processed[pairKey] {
|
||||
continue
|
||||
}
|
||||
processed[pairKey] = true
|
||||
|
||||
// Check if either node has too many edges
|
||||
fromCount := len(outEdges[edge.FromID])
|
||||
toCount := len(inEdges[edge.ToID])
|
||||
|
||||
if fromCount <= maxPerNode && toCount <= maxPerNode {
|
||||
pruned = append(pruned, edge)
|
||||
continue
|
||||
}
|
||||
|
||||
// Keep edge if it's high-weight (top edges for this node)
|
||||
if shouldKeepEdge(edge, outEdges[edge.FromID], maxPerNode) {
|
||||
pruned = append(pruned, edge)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pruned) < len(edges) {
|
||||
log.Debug().
|
||||
Int("original", len(edges)).
|
||||
Int("pruned", len(pruned)).
|
||||
Int("removed", len(edges)-len(pruned)).
|
||||
Msg("Pruned excessive edges")
|
||||
}
|
||||
|
||||
return pruned
|
||||
}
|
||||
|
||||
// shouldKeepEdge determines if edge should be kept during pruning
|
||||
func shouldKeepEdge(edge Edge, nodeEdges []Edge, maxPerNode int) bool {
|
||||
// Sort node's edges by weight descending
|
||||
sortedEdges := make([]Edge, len(nodeEdges))
|
||||
copy(sortedEdges, nodeEdges)
|
||||
|
||||
sortEdgesByWeight(sortedEdges)
|
||||
|
||||
// Keep edge if it's in top maxPerNode
|
||||
for i := 0; i < maxPerNode && i < len(sortedEdges); i++ {
|
||||
if sortedEdges[i].FromID == edge.FromID && sortedEdges[i].ToID == edge.ToID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sortEdgesByWeight sorts edges by weight descending
|
||||
func sortEdgesByWeight(edges []Edge) {
|
||||
// Simple bubble sort (edges are typically small per node)
|
||||
n := len(edges)
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := 0; j < n-i-1; j++ {
|
||||
if edges[j].Weight < edges[j+1].Weight {
|
||||
edges[j], edges[j+1] = edges[j+1], edges[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// edgeKey creates a unique key for an edge pair (sorted)
|
||||
func edgeKey(id1, id2 int64) string {
|
||||
if id1 < id2 {
|
||||
return fmt.Sprintf("%d-%d", id1, id2)
|
||||
}
|
||||
return fmt.Sprintf("%d-%d", id2, id1)
|
||||
}
|
||||
|
||||
// DetectSemanticEdges creates edges based on semantic similarity
|
||||
// This requires embeddings and is called separately when available
|
||||
func DetectSemanticEdges(ctx context.Context, observations []*models.Observation, embeddings map[int64][]float32) []Edge {
|
||||
edges := make([]Edge, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// Compare all pairs (expensive, but necessary for semantic similarity)
|
||||
for i := 0; i < len(observations); i++ {
|
||||
emb1, ok1 := embeddings[observations[i].ID]
|
||||
if !ok1 {
|
||||
continue
|
||||
}
|
||||
|
||||
for j := i + 1; j < len(observations); j++ {
|
||||
emb2, ok2 := embeddings[observations[j].ID]
|
||||
if !ok2 {
|
||||
continue
|
||||
}
|
||||
|
||||
similarity := cosineSimilarity(emb1, emb2)
|
||||
if similarity < SemanticSimilarityThreshold {
|
||||
continue
|
||||
}
|
||||
|
||||
pairKey := edgeKey(observations[i].ID, observations[j].ID)
|
||||
if seen[pairKey] {
|
||||
continue
|
||||
}
|
||||
seen[pairKey] = true
|
||||
|
||||
edges = append(edges, Edge{
|
||||
FromID: observations[i].ID,
|
||||
ToID: observations[j].ID,
|
||||
Relation: RelationSemantic,
|
||||
Weight: similarity,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("semantic_edges", len(edges)).
|
||||
Float32("threshold", SemanticSimilarityThreshold).
|
||||
Msg("Detected semantic edges")
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// cosineSimilarity computes cosine similarity between two vectors
|
||||
func cosineSimilarity(a, b []float32) float32 {
|
||||
if len(a) != len(b) {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
var dotProduct, normA, normB float32
|
||||
for i := range a {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return dotProduct / float32(math.Sqrt(float64(normA))*math.Sqrt(float64(normB)))
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
// Package graph provides observation relationship graphs for LEANN Phase 2.
|
||||
//
|
||||
// This package implements graph-based selective recomputation where observation
|
||||
// relationships (file overlap, semantic similarity, temporal proximity) form a
|
||||
// graph structure. Hub nodes (high-degree observations) store embeddings, while
|
||||
// leaf nodes recompute on-demand.
|
||||
package graph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// RelationType defines the type of relationship between observations
|
||||
type RelationType int
|
||||
|
||||
const (
|
||||
// RelationFileOverlap indicates observations reference overlapping files
|
||||
RelationFileOverlap RelationType = iota
|
||||
// RelationSemantic indicates high semantic similarity (cosine > 0.85)
|
||||
RelationSemantic
|
||||
// RelationTemporal indicates observations from same session
|
||||
RelationTemporal
|
||||
// RelationConcept indicates shared concept tags
|
||||
RelationConcept
|
||||
)
|
||||
|
||||
// Edge represents a relationship between two observations
|
||||
type Edge struct {
|
||||
FromID int64
|
||||
ToID int64
|
||||
Relation RelationType
|
||||
Weight float32 // 0.0-1.0, higher = stronger relationship
|
||||
}
|
||||
|
||||
// Node represents an observation in the graph
|
||||
type Node struct {
|
||||
Metadata NodeMetadata
|
||||
LastAccess time.Time
|
||||
StoredEmb []float32 // Nil if recomputed on-demand
|
||||
ID int64
|
||||
Degree int // Number of edges (hub detection)
|
||||
AccessCount int
|
||||
}
|
||||
|
||||
// NodeMetadata contains observation metadata
|
||||
type NodeMetadata struct {
|
||||
CreatedAt time.Time
|
||||
Project string
|
||||
Type string
|
||||
Title string
|
||||
IsSuperseded bool
|
||||
}
|
||||
|
||||
// CSRGraph represents a graph in Compressed Sparse Row format for memory efficiency
|
||||
type CSRGraph struct {
|
||||
RowPtr []int32 // Node adjacency list pointers
|
||||
ColIdx []int32 // Edge destination IDs
|
||||
Weights []float32 // Edge weights
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// ObservationGraph manages the observation relationship graph
|
||||
type ObservationGraph struct {
|
||||
nodes map[int64]*Node
|
||||
csr *CSRGraph
|
||||
edges []Edge
|
||||
nodesMu sync.RWMutex
|
||||
edgesMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewObservationGraph creates a new empty observation graph
|
||||
func NewObservationGraph() *ObservationGraph {
|
||||
return &ObservationGraph{
|
||||
nodes: make(map[int64]*Node),
|
||||
edges: make([]Edge, 0),
|
||||
csr: &CSRGraph{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddNode adds or updates a node in the graph
|
||||
func (g *ObservationGraph) AddNode(node *Node) {
|
||||
g.nodesMu.Lock()
|
||||
defer g.nodesMu.Unlock()
|
||||
|
||||
g.nodes[node.ID] = node
|
||||
}
|
||||
|
||||
// AddEdge adds an edge to the graph
|
||||
func (g *ObservationGraph) AddEdge(edge Edge) {
|
||||
g.edgesMu.Lock()
|
||||
defer g.edgesMu.Unlock()
|
||||
|
||||
g.edges = append(g.edges, edge)
|
||||
|
||||
// Update degree counts
|
||||
g.nodesMu.Lock()
|
||||
if fromNode, ok := g.nodes[edge.FromID]; ok {
|
||||
fromNode.Degree++
|
||||
}
|
||||
if toNode, ok := g.nodes[edge.ToID]; ok {
|
||||
toNode.Degree++
|
||||
}
|
||||
g.nodesMu.Unlock()
|
||||
}
|
||||
|
||||
// BuildCSR converts edge list to CSR format for efficient traversal
|
||||
func (g *ObservationGraph) BuildCSR() error {
|
||||
g.edgesMu.RLock()
|
||||
g.nodesMu.RLock()
|
||||
defer g.edgesMu.RUnlock()
|
||||
defer g.nodesMu.RUnlock()
|
||||
|
||||
if len(g.nodes) == 0 {
|
||||
return fmt.Errorf("no nodes in graph")
|
||||
}
|
||||
|
||||
// Create node ID to index mapping
|
||||
nodeIDs := make([]int64, 0, len(g.nodes))
|
||||
for id := range g.nodes {
|
||||
nodeIDs = append(nodeIDs, id)
|
||||
}
|
||||
sort.Slice(nodeIDs, func(i, j int) bool {
|
||||
return nodeIDs[i] < nodeIDs[j]
|
||||
})
|
||||
|
||||
idToIdx := make(map[int64]int32)
|
||||
for idx, id := range nodeIDs {
|
||||
// #nosec G115 - observation count will never exceed int32 max (2.1B) in practice
|
||||
idToIdx[id] = int32(idx)
|
||||
}
|
||||
|
||||
// Count edges per node
|
||||
edgeCounts := make([]int, len(nodeIDs))
|
||||
for _, edge := range g.edges {
|
||||
if fromIdx, ok := idToIdx[edge.FromID]; ok {
|
||||
edgeCounts[fromIdx]++
|
||||
}
|
||||
}
|
||||
|
||||
// Build row pointers
|
||||
rowPtr := make([]int32, len(nodeIDs)+1)
|
||||
rowPtr[0] = 0
|
||||
for i := 0; i < len(nodeIDs); i++ {
|
||||
// #nosec G115 - edge counts per node will not exceed int32 max
|
||||
rowPtr[i+1] = rowPtr[i] + int32(edgeCounts[i])
|
||||
}
|
||||
|
||||
// Build column indices and weights
|
||||
totalEdges := rowPtr[len(nodeIDs)]
|
||||
colIdx := make([]int32, totalEdges)
|
||||
weights := make([]float32, totalEdges)
|
||||
|
||||
// Temporary counter for filling CSR
|
||||
currentPos := make([]int32, len(nodeIDs))
|
||||
copy(currentPos, rowPtr[:len(nodeIDs)])
|
||||
|
||||
for _, edge := range g.edges {
|
||||
fromIdx, fromOk := idToIdx[edge.FromID]
|
||||
toIdx, toOk := idToIdx[edge.ToID]
|
||||
|
||||
if fromOk && toOk {
|
||||
pos := currentPos[fromIdx]
|
||||
colIdx[pos] = toIdx
|
||||
weights[pos] = edge.Weight
|
||||
currentPos[fromIdx]++
|
||||
}
|
||||
}
|
||||
|
||||
g.csr.mu.Lock()
|
||||
g.csr.RowPtr = rowPtr
|
||||
g.csr.ColIdx = colIdx
|
||||
g.csr.Weights = weights
|
||||
g.csr.mu.Unlock()
|
||||
|
||||
log.Info().
|
||||
Int("nodes", len(nodeIDs)).
|
||||
Int("edges", int(totalEdges)).
|
||||
Msg("Built CSR graph representation")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNeighbors returns neighboring nodes and their edge weights
|
||||
func (g *ObservationGraph) GetNeighbors(nodeID int64) ([]int64, []float32, error) {
|
||||
g.csr.mu.RLock()
|
||||
defer g.csr.mu.RUnlock()
|
||||
|
||||
// Find node index in CSR
|
||||
g.nodesMu.RLock()
|
||||
nodeIDs := make([]int64, 0, len(g.nodes))
|
||||
for id := range g.nodes {
|
||||
nodeIDs = append(nodeIDs, id)
|
||||
}
|
||||
g.nodesMu.RUnlock()
|
||||
|
||||
sort.Slice(nodeIDs, func(i, j int) bool {
|
||||
return nodeIDs[i] < nodeIDs[j]
|
||||
})
|
||||
|
||||
nodeIdx := sort.Search(len(nodeIDs), func(i int) bool {
|
||||
return nodeIDs[i] >= nodeID
|
||||
})
|
||||
|
||||
if nodeIdx >= len(nodeIDs) || nodeIDs[nodeIdx] != nodeID {
|
||||
return nil, nil, fmt.Errorf("node %d not found", nodeID)
|
||||
}
|
||||
|
||||
// Extract neighbors from CSR
|
||||
startIdx := g.csr.RowPtr[nodeIdx]
|
||||
endIdx := g.csr.RowPtr[nodeIdx+1]
|
||||
|
||||
neighborCount := endIdx - startIdx
|
||||
neighbors := make([]int64, neighborCount)
|
||||
weights := make([]float32, neighborCount)
|
||||
|
||||
for i := int32(0); i < neighborCount; i++ {
|
||||
neighborIdx := g.csr.ColIdx[startIdx+i]
|
||||
neighbors[i] = nodeIDs[neighborIdx]
|
||||
weights[i] = g.csr.Weights[startIdx+i]
|
||||
}
|
||||
|
||||
return neighbors, weights, nil
|
||||
}
|
||||
|
||||
// GetNode retrieves a node by ID
|
||||
func (g *ObservationGraph) GetNode(nodeID int64) (*Node, error) {
|
||||
g.nodesMu.RLock()
|
||||
defer g.nodesMu.RUnlock()
|
||||
|
||||
node, ok := g.nodes[nodeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node %d not found", nodeID)
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// FindHubs identifies hub nodes (high degree) in the graph
|
||||
func (g *ObservationGraph) FindHubs(percentile float64) []int64 {
|
||||
g.nodesMu.RLock()
|
||||
defer g.nodesMu.RUnlock()
|
||||
|
||||
if len(g.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collect all degrees
|
||||
degrees := make([]int, 0, len(g.nodes))
|
||||
nodeIDs := make([]int64, 0, len(g.nodes))
|
||||
|
||||
for id, node := range g.nodes {
|
||||
degrees = append(degrees, node.Degree)
|
||||
nodeIDs = append(nodeIDs, id)
|
||||
}
|
||||
|
||||
// Sort by degree
|
||||
type nodeDegree struct {
|
||||
ID int64
|
||||
Degree int
|
||||
}
|
||||
|
||||
nodeDegrees := make([]nodeDegree, len(nodeIDs))
|
||||
for i := range nodeIDs {
|
||||
nodeDegrees[i] = nodeDegree{
|
||||
ID: nodeIDs[i],
|
||||
Degree: degrees[i],
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(nodeDegrees, func(i, j int) bool {
|
||||
return nodeDegrees[i].Degree > nodeDegrees[j].Degree
|
||||
})
|
||||
|
||||
// Return top percentile
|
||||
cutoff := int(math.Ceil(float64(len(nodeDegrees)) * (1.0 - percentile)))
|
||||
if cutoff > len(nodeDegrees) {
|
||||
cutoff = len(nodeDegrees)
|
||||
}
|
||||
|
||||
hubs := make([]int64, cutoff)
|
||||
for i := 0; i < cutoff; i++ {
|
||||
hubs[i] = nodeDegrees[i].ID
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("total_nodes", len(g.nodes)).
|
||||
Int("hubs", len(hubs)).
|
||||
Float64("percentile", percentile).
|
||||
Msg("Identified hub nodes")
|
||||
|
||||
return hubs
|
||||
}
|
||||
|
||||
// Stats returns graph statistics
|
||||
func (g *ObservationGraph) Stats() GraphStats {
|
||||
g.nodesMu.RLock()
|
||||
g.edgesMu.RLock()
|
||||
defer g.nodesMu.RUnlock()
|
||||
defer g.edgesMu.RUnlock()
|
||||
|
||||
stats := GraphStats{
|
||||
NodeCount: len(g.nodes),
|
||||
EdgeCount: len(g.edges),
|
||||
}
|
||||
|
||||
if len(g.nodes) > 0 {
|
||||
degrees := make([]int, 0, len(g.nodes))
|
||||
for _, node := range g.nodes {
|
||||
degrees = append(degrees, node.Degree)
|
||||
}
|
||||
|
||||
sort.Ints(degrees)
|
||||
stats.AvgDegree = float64(sum(degrees)) / float64(len(degrees))
|
||||
stats.MaxDegree = degrees[len(degrees)-1]
|
||||
stats.MinDegree = degrees[0]
|
||||
|
||||
// Median
|
||||
mid := len(degrees) / 2
|
||||
if len(degrees)%2 == 0 {
|
||||
stats.MedianDegree = float64(degrees[mid-1]+degrees[mid]) / 2.0
|
||||
} else {
|
||||
stats.MedianDegree = float64(degrees[mid])
|
||||
}
|
||||
}
|
||||
|
||||
// Count edge types
|
||||
stats.EdgeTypes = make(map[RelationType]int)
|
||||
for _, edge := range g.edges {
|
||||
stats.EdgeTypes[edge.Relation]++
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GraphStats contains graph statistics
|
||||
type GraphStats struct {
|
||||
EdgeTypes map[RelationType]int
|
||||
AvgDegree float64
|
||||
MedianDegree float64
|
||||
NodeCount int
|
||||
EdgeCount int
|
||||
MaxDegree int
|
||||
MinDegree int
|
||||
}
|
||||
|
||||
// BuildFromObservations constructs a graph from a list of observations
|
||||
func BuildFromObservations(ctx context.Context, observations []*models.Observation) (*ObservationGraph, error) {
|
||||
graph := NewObservationGraph()
|
||||
|
||||
// Add nodes
|
||||
for _, obs := range observations {
|
||||
// Extract title from sql.NullString
|
||||
title := ""
|
||||
if obs.Title.Valid {
|
||||
title = obs.Title.String
|
||||
}
|
||||
|
||||
node := &Node{
|
||||
ID: obs.ID,
|
||||
Degree: 0,
|
||||
Metadata: NodeMetadata{
|
||||
Project: obs.Project,
|
||||
Type: string(obs.Type),
|
||||
Title: title,
|
||||
CreatedAt: time.UnixMilli(obs.CreatedAtEpoch),
|
||||
IsSuperseded: obs.IsSuperseded,
|
||||
},
|
||||
LastAccess: time.Now(),
|
||||
AccessCount: 0,
|
||||
}
|
||||
graph.AddNode(node)
|
||||
}
|
||||
|
||||
// Detect edges (will be implemented in edge_detector.go)
|
||||
edges, err := DetectEdges(ctx, observations)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect edges: %w", err)
|
||||
}
|
||||
|
||||
for _, edge := range edges {
|
||||
graph.AddEdge(edge)
|
||||
}
|
||||
|
||||
// Build CSR representation
|
||||
if err := graph.BuildCSR(); err != nil {
|
||||
return nil, fmt.Errorf("build CSR: %w", err)
|
||||
}
|
||||
|
||||
return graph, nil
|
||||
}
|
||||
|
||||
// Helper function to sum integers
|
||||
func sum(values []int) int {
|
||||
total := 0
|
||||
for _, v := range values {
|
||||
total += v
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of RelationType
|
||||
func (r RelationType) String() string {
|
||||
switch r {
|
||||
case RelationFileOverlap:
|
||||
return "file_overlap"
|
||||
case RelationSemantic:
|
||||
return "semantic"
|
||||
case RelationTemporal:
|
||||
return "temporal"
|
||||
case RelationConcept:
|
||||
return "concept"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
+12
-14
@@ -19,12 +19,9 @@ import (
|
||||
|
||||
// Server is the MCP server that exposes search tools.
|
||||
type Server struct {
|
||||
searchMgr *search.Manager
|
||||
version string
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
|
||||
// Store dependencies for enhanced tools
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
searchMgr *search.Manager
|
||||
observationStore *gorm.ObservationStore
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
@@ -32,6 +29,7 @@ type Server struct {
|
||||
vectorClient *sqlitevec.Client
|
||||
scoreCalculator *scoring.Calculator
|
||||
recalculator *scoring.Recalculator
|
||||
version string
|
||||
}
|
||||
|
||||
// NewServer creates a new MCP server.
|
||||
@@ -71,17 +69,17 @@ type Request struct {
|
||||
|
||||
// Response represents a JSON-RPC response.
|
||||
type Response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID any `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
}
|
||||
|
||||
// Error represents a JSON-RPC error.
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// ToolCallParams represents parameters for tools/call method.
|
||||
@@ -92,9 +90,9 @@ type ToolCallParams struct {
|
||||
|
||||
// Tool represents an MCP tool definition.
|
||||
type Tool struct {
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// Run starts the MCP server loop.
|
||||
@@ -489,17 +487,17 @@ func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage
|
||||
|
||||
// TimelineParams represents parameters for timeline operations.
|
||||
type TimelineParams struct {
|
||||
AnchorID int64 `json:"anchor_id"`
|
||||
Query string `json:"query"`
|
||||
Before int `json:"before"`
|
||||
After int `json:"after"`
|
||||
Project string `json:"project"`
|
||||
ObsType string `json:"obs_type"`
|
||||
Concepts string `json:"concepts"`
|
||||
Files string `json:"files"`
|
||||
Format string `json:"format"`
|
||||
AnchorID int64 `json:"anchor_id"`
|
||||
Before int `json:"before"`
|
||||
After int `json:"after"`
|
||||
DateStart int64 `json:"dateStart"`
|
||||
DateEnd int64 `json:"dateEnd"`
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
// handleTimeline handles timeline requests.
|
||||
|
||||
+10
-10
@@ -34,8 +34,8 @@ func (s *ServerSuite) TestNewServer() {
|
||||
func TestRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req Request
|
||||
expected string
|
||||
req Request
|
||||
}{
|
||||
{
|
||||
name: "initialize request",
|
||||
@@ -138,9 +138,9 @@ func TestResponse(t *testing.T) {
|
||||
// TestError tests Error struct.
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected string
|
||||
name string
|
||||
err Error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
@@ -365,11 +365,11 @@ func TestHandleRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *Request
|
||||
expectError bool
|
||||
errorCode int
|
||||
name string
|
||||
errorMessage string
|
||||
errorCode int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "initialize method",
|
||||
@@ -753,13 +753,13 @@ func TestServerStdinStdoutConfig(t *testing.T) {
|
||||
// TestResponseIDTypes tests that response IDs can be various types.
|
||||
func TestResponseIDTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
name string
|
||||
}{
|
||||
{"integer id", 1},
|
||||
{"string id", "abc-123"},
|
||||
{"float id", 1.5},
|
||||
{"null id", nil},
|
||||
{name: "integer id", id: 1},
|
||||
{name: "string id", id: "abc-123"},
|
||||
{name: "float id", id: 1.5},
|
||||
{name: "null id", id: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -38,21 +38,15 @@ type PatternSyncFunc func(pattern *models.Pattern)
|
||||
|
||||
// Detector detects and tracks recurring patterns across observations.
|
||||
type Detector struct {
|
||||
config DetectorConfig
|
||||
ctx context.Context
|
||||
patternStore *gorm.PatternStore
|
||||
observationStore *gorm.ObservationStore
|
||||
|
||||
// Vector sync callback
|
||||
syncFunc PatternSyncFunc
|
||||
|
||||
// Candidate tracking (patterns not yet confirmed)
|
||||
candidates map[string]*candidatePattern
|
||||
candidatesMu sync.RWMutex
|
||||
|
||||
// Background analysis
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
syncFunc PatternSyncFunc
|
||||
candidates map[string]*candidatePattern
|
||||
cancel context.CancelFunc
|
||||
config DetectorConfig
|
||||
wg sync.WaitGroup
|
||||
candidatesMu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetSyncFunc sets the callback for syncing patterns to vector store.
|
||||
@@ -62,11 +56,11 @@ func (d *Detector) SetSyncFunc(fn PatternSyncFunc) {
|
||||
|
||||
// candidatePattern tracks a potential pattern before it reaches frequency threshold.
|
||||
type candidatePattern struct {
|
||||
patternType models.PatternType
|
||||
title string
|
||||
signature []string
|
||||
observationIDs []int64
|
||||
projects []string
|
||||
patternType models.PatternType
|
||||
title string
|
||||
lastSeenEpoch int64
|
||||
}
|
||||
|
||||
|
||||
@@ -331,16 +331,16 @@ func TestDefaultConfig(t *testing.T) {
|
||||
func TestGeneratePatternName(t *testing.T) {
|
||||
tests := []struct {
|
||||
patternType models.PatternType
|
||||
signature []string
|
||||
title string
|
||||
wantPrefix string
|
||||
signature []string
|
||||
}{
|
||||
{models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"},
|
||||
{models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"},
|
||||
{models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"},
|
||||
{models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"},
|
||||
{models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"},
|
||||
{models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly
|
||||
{patternType: models.PatternTypeBug, title: "", wantPrefix: "Bug Pattern:", signature: []string{"nil", "error"}},
|
||||
{patternType: models.PatternTypeRefactor, title: "", wantPrefix: "Refactor Pattern:", signature: []string{"extract"}},
|
||||
{patternType: models.PatternTypeArchitecture, title: "", wantPrefix: "Architecture Pattern:", signature: []string{"service"}},
|
||||
{patternType: models.PatternTypeAntiPattern, title: "", wantPrefix: "Anti-Pattern:", signature: []string{"god-class"}},
|
||||
{patternType: models.PatternTypeBestPractice, title: "", wantPrefix: "Best Practice:", signature: []string{"testing"}},
|
||||
{patternType: models.PatternTypeBug, title: "Short Title", wantPrefix: "Short Title", signature: []string{}}, // Use title directly
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -30,24 +30,24 @@ const (
|
||||
|
||||
// Candidate represents a search result candidate for reranking.
|
||||
type Candidate struct {
|
||||
ID string // Document ID
|
||||
Content string // Document text content for scoring
|
||||
Score float64 // Original bi-encoder similarity score
|
||||
Metadata map[string]any // Preserved metadata
|
||||
RerankInfo map[string]float64 // Reranking debug info (optional)
|
||||
Metadata map[string]any
|
||||
RerankInfo map[string]float64
|
||||
ID string
|
||||
Content string
|
||||
Score float64
|
||||
}
|
||||
|
||||
// RerankResult represents a reranked search result.
|
||||
type RerankResult struct {
|
||||
ID string // Document ID
|
||||
Content string // Document text content
|
||||
OriginalScore float64 // Original bi-encoder score
|
||||
RerankScore float64 // Cross-encoder relevance score
|
||||
CombinedScore float64 // Weighted combination of scores
|
||||
Metadata map[string]any // Preserved metadata
|
||||
OriginalRank int // Position before reranking (1-indexed)
|
||||
RerankRank int // Position after reranking (1-indexed)
|
||||
RankImprovement int // How much the rank improved (positive = moved up)
|
||||
Metadata map[string]any
|
||||
ID string
|
||||
Content string
|
||||
OriginalScore float64
|
||||
RerankScore float64
|
||||
CombinedScore float64
|
||||
OriginalRank int
|
||||
RerankRank int
|
||||
RankImprovement int
|
||||
}
|
||||
|
||||
// Service provides cross-encoder reranking functionality.
|
||||
|
||||
@@ -21,13 +21,13 @@ type ObservationStore interface {
|
||||
|
||||
// Recalculator periodically recalculates importance scores for observations.
|
||||
type Recalculator struct {
|
||||
log zerolog.Logger
|
||||
store ObservationStore
|
||||
calculator *Calculator
|
||||
log zerolog.Logger
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
@@ -16,14 +16,14 @@ import (
|
||||
|
||||
// MockObservationStore is a mock implementation of ObservationStore for testing.
|
||||
type MockObservationStore struct {
|
||||
mu sync.Mutex
|
||||
observations []*models.Observation
|
||||
scores map[int64]float64
|
||||
conceptWeights map[string]float64
|
||||
updateErr error
|
||||
getErr error
|
||||
getConceptsErr error
|
||||
scores map[int64]float64
|
||||
conceptWeights map[string]float64
|
||||
observations []*models.Observation
|
||||
updateScoresCalls int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewMockObservationStore() *MockObservationStore {
|
||||
|
||||
@@ -30,25 +30,25 @@ const (
|
||||
// ExpandedQuery represents a query variant with metadata.
|
||||
type ExpandedQuery struct {
|
||||
Query string `json:"query"`
|
||||
Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0)
|
||||
Source string `json:"source"` // Where this expansion came from
|
||||
Intent QueryIntent `json:"intent"` // Detected intent
|
||||
Source string `json:"source"`
|
||||
Intent QueryIntent `json:"intent"`
|
||||
Weight float64 `json:"weight"`
|
||||
}
|
||||
|
||||
// Expander provides context-aware query expansion.
|
||||
type Expander struct {
|
||||
embedSvc *embedding.Service
|
||||
vocabulary []VocabEntry // Known vocabulary from observations
|
||||
vocabVectors [][]float32 // Embeddings for vocabulary entries
|
||||
vocabMu sync.RWMutex // Protects vocabulary
|
||||
intentPatterns map[QueryIntent][]*regexp.Regexp
|
||||
vocabulary []VocabEntry
|
||||
vocabVectors [][]float32
|
||||
vocabMu sync.RWMutex
|
||||
}
|
||||
|
||||
// VocabEntry represents a vocabulary term from observations.
|
||||
type VocabEntry struct {
|
||||
Term string // The term itself
|
||||
Weight float64 // How common/important this term is (0.0-1.0)
|
||||
Source string // Where it came from (title, concept, narrative)
|
||||
Term string
|
||||
Source string
|
||||
Weight float64
|
||||
}
|
||||
|
||||
// Config holds expander configuration.
|
||||
|
||||
@@ -88,16 +88,16 @@ func (s *ExpanderSuite) TestExpand() {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedIntent QueryIntent
|
||||
minExpansions int
|
||||
hasOriginal bool
|
||||
expectedIntent QueryIntent
|
||||
}{
|
||||
{"question", "how do I implement auth", 1, true, IntentQuestion},
|
||||
{"error", "fix the bug in login", 1, true, IntentError},
|
||||
{"implementation", "implement user handler", 1, true, IntentImplementation},
|
||||
{"architecture", "architecture design", 1, true, IntentArchitecture},
|
||||
{"general", "database connection", 1, true, IntentGeneral},
|
||||
{"empty", "", 0, false, IntentGeneral},
|
||||
{name: "question", query: "how do I implement auth", expectedIntent: IntentQuestion, minExpansions: 1, hasOriginal: true},
|
||||
{name: "error", query: "fix the bug in login", expectedIntent: IntentError, minExpansions: 1, hasOriginal: true},
|
||||
{name: "implementation", query: "implement user handler", expectedIntent: IntentImplementation, minExpansions: 1, hasOriginal: true},
|
||||
{name: "architecture", query: "architecture design", expectedIntent: IntentArchitecture, minExpansions: 1, hasOriginal: true},
|
||||
{name: "general", query: "database connection", expectedIntent: IntentGeneral, minExpansions: 1, hasOriginal: true},
|
||||
{name: "empty", query: "", expectedIntent: IntentGeneral, minExpansions: 0, hasOriginal: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -392,13 +392,13 @@ func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{"short", "hello", 10, "hello"},
|
||||
{"exact", "hello", 5, "hello"},
|
||||
{"long", "hello world", 5, "hello..."},
|
||||
{"empty", "", 10, ""},
|
||||
{name: "short", input: "hello", expected: "hello", maxLen: 10},
|
||||
{name: "exact", input: "hello", expected: "hello", maxLen: 5},
|
||||
{name: "long", input: "hello world", expected: "hello...", maxLen: 5},
|
||||
{name: "empty", input: "", expected: "", maxLen: 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -516,16 +516,16 @@ func TestTruncate_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{"short_string", "hello", 10, "hello"},
|
||||
{"exact_length", "hello", 5, "hello"},
|
||||
{"long_string", "hello world", 5, "hello..."},
|
||||
{"empty_string", "", 10, ""},
|
||||
{"whitespace_only", " ", 10, ""},
|
||||
{"with_leading_space", " hello ", 10, "hello"},
|
||||
{"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."},
|
||||
{name: "short_string", input: "hello", expected: "hello", maxLen: 10},
|
||||
{name: "exact_length", input: "hello", expected: "hello", maxLen: 5},
|
||||
{name: "long_string", input: "hello world", expected: "hello...", maxLen: 5},
|
||||
{name: "empty_string", input: "", expected: "", maxLen: 10},
|
||||
{name: "whitespace_only", input: " ", expected: "", maxLen: 10},
|
||||
{name: "with_leading_space", input: " hello ", expected: "hello", maxLen: 10},
|
||||
{name: "very_long", input: "this is a very long string that should be truncated", expected: "this is a very long ...", maxLen: 20},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
+15
-15
@@ -35,41 +35,41 @@ func NewManager(
|
||||
|
||||
// SearchParams contains parameters for unified search.
|
||||
type SearchParams struct {
|
||||
Query string
|
||||
Type string // "observations", "sessions", "prompts", or empty for all
|
||||
Format string
|
||||
Type string
|
||||
Project string
|
||||
ObsType string // Observation type filter
|
||||
ObsType string
|
||||
Concepts string
|
||||
Files string
|
||||
Query string
|
||||
Scope string
|
||||
OrderBy string
|
||||
DateStart int64
|
||||
DateEnd int64
|
||||
OrderBy string // "relevance", "date_desc", "date_asc"
|
||||
Limit int
|
||||
Offset int
|
||||
Format string // "index" or "full"
|
||||
Scope string // "project", "global", or empty for project+global
|
||||
IncludeGlobal bool // If true, include global observations along with project-scoped
|
||||
ExcludeSuperseded bool // If true, exclude observations that have been superseded
|
||||
Limit int
|
||||
DateEnd int64
|
||||
IncludeGlobal bool
|
||||
ExcludeSuperseded bool
|
||||
}
|
||||
|
||||
// SearchResult represents a unified search result.
|
||||
type SearchResult struct {
|
||||
Type string `json:"type"` // "observation", "session", "prompt"
|
||||
ID int64 `json:"id"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Project string `json:"project"`
|
||||
Scope string `json:"scope,omitempty"` // "project" or "global"
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt int64 `json:"created_at_epoch"`
|
||||
Score float64 `json:"score,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// UnifiedSearchResult contains the combined search results.
|
||||
type UnifiedSearchResult struct {
|
||||
Query string `json:"query,omitempty"`
|
||||
Results []SearchResult `json:"results"`
|
||||
TotalCount int `json:"total_count"`
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// UnifiedSearch performs a unified search across all document types.
|
||||
|
||||
@@ -94,8 +94,8 @@ func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "short string no truncation",
|
||||
@@ -148,8 +148,8 @@ func TestObservationToResult(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
obs *models.Observation
|
||||
name string
|
||||
format string
|
||||
expected SearchResult
|
||||
}{
|
||||
@@ -240,8 +240,8 @@ func TestSummaryToResult(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
summary *models.SessionSummary
|
||||
name string
|
||||
format string
|
||||
expected SearchResult
|
||||
}{
|
||||
@@ -322,8 +322,8 @@ func TestPromptToResult(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prompt *models.UserPromptWithSession
|
||||
name string
|
||||
format string
|
||||
expected SearchResult
|
||||
}{
|
||||
@@ -406,9 +406,9 @@ func TestPromptToResult(t *testing.T) {
|
||||
func TestSearchParamsValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedOrder string
|
||||
params SearchParams
|
||||
expectedLimit int
|
||||
expectedOrder string
|
||||
}{
|
||||
{
|
||||
name: "default limit applied",
|
||||
@@ -731,16 +731,16 @@ func TestPromptToResultFormats(t *testing.T) {
|
||||
func TestSearchParamsDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialLimit int
|
||||
initialOrder string
|
||||
expectedLimit int
|
||||
expectedOrder string
|
||||
initialLimit int
|
||||
expectedLimit int
|
||||
}{
|
||||
{"zero_limit", 0, "", 20, "date_desc"},
|
||||
{"negative_limit", -5, "", 20, "date_desc"},
|
||||
{"over_100_limit", 150, "", 100, "date_desc"},
|
||||
{"valid_limit_50", 50, "relevance", 50, "relevance"},
|
||||
{"custom_order", 30, "date_asc", 30, "date_asc"},
|
||||
{name: "zero_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: 0, expectedLimit: 20},
|
||||
{name: "negative_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: -5, expectedLimit: 20},
|
||||
{name: "over_100_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: 150, expectedLimit: 100},
|
||||
{name: "valid_limit_50", initialOrder: "relevance", expectedOrder: "relevance", initialLimit: 50, expectedLimit: 50},
|
||||
{name: "custom_order", initialOrder: "date_asc", expectedOrder: "date_asc", initialLimit: 30, expectedLimit: 30},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -774,18 +774,18 @@ func TestTruncateEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
// Unicode strings - uses byte length so ensure maxLen accommodates full string
|
||||
{"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"},
|
||||
{"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"},
|
||||
{name: "unicode_string_no_truncate", input: "日本語テスト", expected: "日本語テスト", maxLen: 20},
|
||||
{name: "mixed_unicode_no_truncate", input: "Hello世界", expected: "Hello世界", maxLen: 15},
|
||||
// ASCII truncation
|
||||
{"ascii_truncate", "Hello World", 5, "Hello..."},
|
||||
{"only_whitespace", " ", 10, ""},
|
||||
{"tabs_and_newlines", "\t\n \t", 10, ""},
|
||||
{"newlines_with_content", "\n\nhello\n\n", 10, "hello"},
|
||||
{"zero_max_len", "hello", 0, "..."},
|
||||
{name: "ascii_truncate", input: "Hello World", expected: "Hello...", maxLen: 5},
|
||||
{name: "only_whitespace", input: " ", expected: "", maxLen: 10},
|
||||
{name: "tabs_and_newlines", input: "\t\n \t", expected: "", maxLen: 10},
|
||||
{name: "newlines_with_content", input: "\n\nhello\n\n", expected: "hello", maxLen: 10},
|
||||
{name: "zero_max_len", input: "hello", expected: "...", maxLen: 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
+13
-14
@@ -33,11 +33,11 @@ const (
|
||||
|
||||
// Release represents a GitHub release.
|
||||
type Release struct {
|
||||
PublishedAt time.Time `json:"published_at"`
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
PublishedAt time.Time `json:"published_at"`
|
||||
Assets []Asset `json:"assets"`
|
||||
Body string `json:"body"`
|
||||
Assets []Asset `json:"assets"`
|
||||
}
|
||||
|
||||
// Asset represents a release asset.
|
||||
@@ -49,15 +49,15 @@ type Asset struct {
|
||||
|
||||
// UpdateInfo contains information about an available update.
|
||||
type UpdateInfo struct {
|
||||
Available bool `json:"available"`
|
||||
PublishedAt time.Time `json:"published_at,omitempty"`
|
||||
CurrentVersion string `json:"current_version"`
|
||||
LatestVersion string `json:"latest_version"`
|
||||
ReleaseNotes string `json:"release_notes,omitempty"`
|
||||
PublishedAt time.Time `json:"published_at,omitempty"`
|
||||
DownloadURL string `json:"download_url,omitempty"`
|
||||
ChecksumsURL string `json:"checksums_url,omitempty"`
|
||||
BundleURL string `json:"bundle_url,omitempty"` // Sigstore bundle (.sigstore.json)
|
||||
BundleURL string `json:"bundle_url,omitempty"`
|
||||
ManualUpdateCommand string `json:"manual_update_command,omitempty"`
|
||||
Available bool `json:"available"`
|
||||
}
|
||||
|
||||
// InstallScriptURL is the URL to the remote installation script.
|
||||
@@ -74,23 +74,22 @@ func GetManualUpdateCommand(version string) string {
|
||||
|
||||
// UpdateStatus represents the current update status.
|
||||
type UpdateStatus struct {
|
||||
State string `json:"state"` // "idle", "checking", "downloading", "verifying", "applying", "done", "error"
|
||||
Progress float64 `json:"progress"`
|
||||
State string `json:"state"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ManualUpdateCommand string `json:"manual_update_command,omitempty"` // Shown when update fails
|
||||
ManualUpdateCommand string `json:"manual_update_command,omitempty"`
|
||||
Progress float64 `json:"progress"`
|
||||
}
|
||||
|
||||
// Updater handles self-updates.
|
||||
type Updater struct {
|
||||
lastCheck time.Time
|
||||
httpClient *http.Client
|
||||
cachedUpdate *UpdateInfo
|
||||
currentVersion string
|
||||
installDir string
|
||||
httpClient *http.Client
|
||||
|
||||
mu sync.RWMutex
|
||||
status UpdateStatus
|
||||
lastCheck time.Time
|
||||
cachedUpdate *UpdateInfo
|
||||
status UpdateStatus
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new Updater.
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AutoTuner dynamically adjusts hub threshold based on query performance
|
||||
type AutoTuner struct {
|
||||
ctx context.Context
|
||||
client *Client
|
||||
cancel context.CancelFunc
|
||||
latencies []time.Duration
|
||||
wg sync.WaitGroup
|
||||
queries int64
|
||||
targetLatency time.Duration
|
||||
adjustPeriod time.Duration
|
||||
minThreshold int
|
||||
maxThreshold int
|
||||
adjustments int
|
||||
latenciesMu sync.Mutex
|
||||
}
|
||||
|
||||
// AutoTunerConfig configures the auto-tuner
|
||||
type AutoTunerConfig struct {
|
||||
TargetLatency time.Duration // Target p95 latency (default: 50ms)
|
||||
MinThreshold int // Min hub threshold (default: 2)
|
||||
MaxThreshold int // Max hub threshold (default: 20)
|
||||
AdjustPeriod time.Duration // Adjustment frequency (default: 5min)
|
||||
}
|
||||
|
||||
// DefaultAutoTunerConfig returns sensible defaults
|
||||
func DefaultAutoTunerConfig() AutoTunerConfig {
|
||||
return AutoTunerConfig{
|
||||
TargetLatency: 50 * time.Millisecond,
|
||||
MinThreshold: 2,
|
||||
MaxThreshold: 20,
|
||||
AdjustPeriod: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAutoTuner creates a new auto-tuner for the hybrid client
|
||||
func NewAutoTuner(client *Client, cfg AutoTunerConfig) *AutoTuner {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tuner := &AutoTuner{
|
||||
client: client,
|
||||
targetLatency: cfg.TargetLatency,
|
||||
minThreshold: cfg.MinThreshold,
|
||||
maxThreshold: cfg.MaxThreshold,
|
||||
adjustPeriod: cfg.AdjustPeriod,
|
||||
latencies: make([]time.Duration, 0, 1000),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
return tuner
|
||||
}
|
||||
|
||||
// Start begins auto-tuning in the background
|
||||
func (a *AutoTuner) Start() {
|
||||
a.wg.Add(1)
|
||||
go a.tuningLoop()
|
||||
|
||||
log.Info().
|
||||
Dur("target_latency", a.targetLatency).
|
||||
Int("min_threshold", a.minThreshold).
|
||||
Int("max_threshold", a.maxThreshold).
|
||||
Dur("adjust_period", a.adjustPeriod).
|
||||
Msg("Auto-tuner started")
|
||||
}
|
||||
|
||||
// Stop stops the auto-tuner
|
||||
func (a *AutoTuner) Stop() {
|
||||
a.cancel()
|
||||
a.wg.Wait()
|
||||
log.Info().Msg("Auto-tuner stopped")
|
||||
}
|
||||
|
||||
// RecordQuery records a query latency for analysis
|
||||
func (a *AutoTuner) RecordQuery(latency time.Duration) {
|
||||
a.latenciesMu.Lock()
|
||||
defer a.latenciesMu.Unlock()
|
||||
|
||||
a.queries++
|
||||
a.latencies = append(a.latencies, latency)
|
||||
|
||||
// Keep only recent queries (last 1000)
|
||||
if len(a.latencies) > 1000 {
|
||||
a.latencies = a.latencies[len(a.latencies)-1000:]
|
||||
}
|
||||
}
|
||||
|
||||
// tuningLoop periodically adjusts hub threshold
|
||||
func (a *AutoTuner) tuningLoop() {
|
||||
defer a.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(a.adjustPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
a.adjustThreshold()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// adjustThreshold analyzes recent queries and adjusts hub threshold
|
||||
func (a *AutoTuner) adjustThreshold() {
|
||||
a.latenciesMu.Lock()
|
||||
defer a.latenciesMu.Unlock()
|
||||
|
||||
if len(a.latencies) < 10 {
|
||||
// Not enough data yet
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate p95 latency
|
||||
p95 := calculateP95(a.latencies)
|
||||
|
||||
currentThreshold := a.client.hubThreshold
|
||||
|
||||
log.Debug().
|
||||
Dur("p95_latency", p95).
|
||||
Dur("target_latency", a.targetLatency).
|
||||
Int("current_threshold", currentThreshold).
|
||||
Int("queries", len(a.latencies)).
|
||||
Msg("Auto-tuner evaluating performance")
|
||||
|
||||
// Determine adjustment direction
|
||||
var newThreshold int
|
||||
|
||||
if p95 > a.targetLatency {
|
||||
// Too slow - lower threshold (more hubs = faster queries)
|
||||
adjustment := calculateAdjustment(p95, a.targetLatency)
|
||||
newThreshold = currentThreshold - adjustment
|
||||
|
||||
if newThreshold < a.minThreshold {
|
||||
newThreshold = a.minThreshold
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Dur("p95", p95).
|
||||
Int("old_threshold", currentThreshold).
|
||||
Int("new_threshold", newThreshold).
|
||||
Msg("Auto-tuner: Lowering hub threshold (too slow)")
|
||||
|
||||
} else if p95 < a.targetLatency*8/10 {
|
||||
// Too fast - raise threshold (fewer hubs = more savings)
|
||||
// Only adjust if significantly faster (20% margin)
|
||||
adjustment := calculateAdjustment(a.targetLatency, p95)
|
||||
newThreshold = currentThreshold + adjustment
|
||||
|
||||
if newThreshold > a.maxThreshold {
|
||||
newThreshold = a.maxThreshold
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Dur("p95", p95).
|
||||
Int("old_threshold", currentThreshold).
|
||||
Int("new_threshold", newThreshold).
|
||||
Msg("Auto-tuner: Raising hub threshold (room for savings)")
|
||||
|
||||
} else {
|
||||
// Within acceptable range, no adjustment needed
|
||||
log.Debug().
|
||||
Dur("p95", p95).
|
||||
Int("threshold", currentThreshold).
|
||||
Msg("Auto-tuner: Performance acceptable, no adjustment")
|
||||
return
|
||||
}
|
||||
|
||||
// Apply adjustment
|
||||
if newThreshold != currentThreshold {
|
||||
a.client.hubThreshold = newThreshold
|
||||
a.adjustments++
|
||||
|
||||
// Clear latency history after adjustment
|
||||
a.latencies = make([]time.Duration, 0, 1000)
|
||||
|
||||
log.Info().
|
||||
Int("threshold", newThreshold).
|
||||
Int("total_adjustments", a.adjustments).
|
||||
Msg("Hub threshold adjusted by auto-tuner")
|
||||
}
|
||||
}
|
||||
|
||||
// calculateP95 computes the 95th percentile latency
|
||||
func calculateP95(latencies []time.Duration) time.Duration {
|
||||
if len(latencies) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sort latencies
|
||||
sorted := make([]time.Duration, len(latencies))
|
||||
copy(sorted, latencies)
|
||||
|
||||
// Simple bubble sort (small dataset)
|
||||
n := len(sorted)
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := 0; j < n-i-1; j++ {
|
||||
if sorted[j] > sorted[j+1] {
|
||||
sorted[j], sorted[j+1] = sorted[j+1], sorted[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return 95th percentile
|
||||
idx := int(float64(len(sorted)) * 0.95)
|
||||
if idx >= len(sorted) {
|
||||
idx = len(sorted) - 1
|
||||
}
|
||||
|
||||
return sorted[idx]
|
||||
}
|
||||
|
||||
// calculateAdjustment determines how much to adjust threshold
|
||||
func calculateAdjustment(actual, target time.Duration) int {
|
||||
// Calculate percentage difference
|
||||
diff := float64(actual-target) / float64(target)
|
||||
|
||||
// Adjust more aggressively for larger differences
|
||||
if diff > 0.5 || diff < -0.5 {
|
||||
return 3 // Large adjustment
|
||||
} else if diff > 0.2 || diff < -0.2 {
|
||||
return 2 // Medium adjustment
|
||||
}
|
||||
|
||||
return 1 // Small adjustment
|
||||
}
|
||||
|
||||
// GetStats returns auto-tuner statistics
|
||||
func (a *AutoTuner) GetStats() AutoTunerStats {
|
||||
a.latenciesMu.Lock()
|
||||
defer a.latenciesMu.Unlock()
|
||||
|
||||
stats := AutoTunerStats{
|
||||
CurrentThreshold: a.client.hubThreshold,
|
||||
TargetLatency: a.targetLatency,
|
||||
TotalQueries: a.queries,
|
||||
TotalAdjustments: a.adjustments,
|
||||
RecentQueries: len(a.latencies),
|
||||
}
|
||||
|
||||
if len(a.latencies) > 0 {
|
||||
stats.P95Latency = calculateP95(a.latencies)
|
||||
|
||||
// Calculate average
|
||||
var total time.Duration
|
||||
for _, lat := range a.latencies {
|
||||
total += lat
|
||||
}
|
||||
stats.AvgLatency = total / time.Duration(len(a.latencies))
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// AutoTunerStats contains auto-tuner statistics
|
||||
type AutoTunerStats struct {
|
||||
CurrentThreshold int
|
||||
TargetLatency time.Duration
|
||||
P95Latency time.Duration
|
||||
AvgLatency time.Duration
|
||||
TotalQueries int64
|
||||
TotalAdjustments int
|
||||
RecentQueries int
|
||||
}
|
||||
|
||||
// AutoTunedClient wraps Client with automatic performance tuning
|
||||
type AutoTunedClient struct {
|
||||
*Client
|
||||
tuner *AutoTuner
|
||||
}
|
||||
|
||||
// Query wraps the underlying Query call with latency tracking
|
||||
func (a *AutoTunedClient) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
start := time.Now()
|
||||
results, err := a.Client.Query(ctx, query, limit, where)
|
||||
latency := time.Since(start)
|
||||
|
||||
a.tuner.RecordQuery(latency)
|
||||
|
||||
return results, err
|
||||
}
|
||||
|
||||
// WithAutoTuning wraps a hybrid client with auto-tuning enabled
|
||||
func WithAutoTuning(client *Client, cfg AutoTunerConfig) *AutoTunedClient {
|
||||
tuner := NewAutoTuner(client, cfg)
|
||||
tuner.Start()
|
||||
|
||||
return &AutoTunedClient{
|
||||
Client: client,
|
||||
tuner: tuner,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the auto-tuner
|
||||
func (a *AutoTunedClient) StopTuning() {
|
||||
a.tuner.Stop()
|
||||
}
|
||||
@@ -0,0 +1,515 @@
|
||||
// Package hybrid provides LEANN-inspired selective vector storage for claude-mnemonic.
|
||||
//
|
||||
// This package implements a hybrid storage strategy where frequently-accessed
|
||||
// observations ("hubs") have their embeddings stored, while infrequently-accessed
|
||||
// observations have their embeddings recomputed on-demand during search.
|
||||
//
|
||||
// This approach reduces storage by 60-80% with minimal impact on search latency (<50ms).
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// VectorStorageStrategy defines how embeddings are stored/computed
|
||||
type VectorStorageStrategy int
|
||||
|
||||
const (
|
||||
// StorageAlways stores all embeddings (current behavior, backwards compatible)
|
||||
StorageAlways VectorStorageStrategy = iota
|
||||
// StorageHub stores only frequently-accessed "hub" embeddings (recommended)
|
||||
StorageHub
|
||||
// StorageOnDemand recomputes all embeddings during search (maximum savings)
|
||||
StorageOnDemand
|
||||
)
|
||||
|
||||
// Client wraps sqlitevec.Client with selective storage logic
|
||||
type Client struct {
|
||||
base *sqlitevec.Client
|
||||
db *sql.DB
|
||||
embedSvc *embedding.Service
|
||||
accessCount map[string]int
|
||||
lastAccess map[string]time.Time
|
||||
contentCache map[string]string
|
||||
strategy VectorStorageStrategy
|
||||
hubThreshold int
|
||||
mu sync.RWMutex
|
||||
cacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config for hybrid client
|
||||
type Config struct {
|
||||
BaseClient *sqlitevec.Client
|
||||
DB *sql.DB
|
||||
EmbedSvc *embedding.Service
|
||||
Strategy VectorStorageStrategy
|
||||
HubThreshold int // Default: 5 accesses
|
||||
}
|
||||
|
||||
// NewClient creates a new hybrid vector client
|
||||
func NewClient(cfg Config) *Client {
|
||||
if cfg.HubThreshold <= 0 {
|
||||
cfg.HubThreshold = 5
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("strategy", strategyToString(cfg.Strategy)).
|
||||
Int("hub_threshold", cfg.HubThreshold).
|
||||
Msg("Initializing LEANN hybrid vector client")
|
||||
|
||||
return &Client{
|
||||
base: cfg.BaseClient,
|
||||
db: cfg.DB,
|
||||
embedSvc: cfg.EmbedSvc,
|
||||
strategy: cfg.Strategy,
|
||||
hubThreshold: cfg.HubThreshold,
|
||||
accessCount: make(map[string]int),
|
||||
lastAccess: make(map[string]time.Time),
|
||||
contentCache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AddDocuments implements selective storage based on strategy
|
||||
func (c *Client) AddDocuments(ctx context.Context, docs []sqlitevec.Document) error {
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.strategy {
|
||||
case StorageAlways:
|
||||
// Use existing implementation - store all embeddings
|
||||
return c.base.AddDocuments(ctx, docs)
|
||||
|
||||
case StorageHub:
|
||||
// Store only hub candidates
|
||||
return c.addDocumentsSelective(ctx, docs)
|
||||
|
||||
case StorageOnDemand:
|
||||
// Don't store embeddings, only cache content
|
||||
return c.cacheDocuments(ctx, docs)
|
||||
|
||||
default:
|
||||
return c.base.AddDocuments(ctx, docs)
|
||||
}
|
||||
}
|
||||
|
||||
// addDocumentsSelective stores embeddings only for hub-qualified documents
|
||||
func (c *Client) addDocumentsSelective(ctx context.Context, docs []sqlitevec.Document) error {
|
||||
// Always cache content for potential recomputation
|
||||
if err := c.cacheDocuments(ctx, docs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter to hub documents
|
||||
hubDocs := make([]sqlitevec.Document, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
if c.isHub(doc.ID) {
|
||||
hubDocs = append(hubDocs, doc)
|
||||
}
|
||||
}
|
||||
|
||||
// Store only hub embeddings
|
||||
if len(hubDocs) > 0 {
|
||||
log.Debug().
|
||||
Int("total", len(docs)).
|
||||
Int("hubs", len(hubDocs)).
|
||||
Msg("Storing selective embeddings")
|
||||
return c.base.AddDocuments(ctx, hubDocs)
|
||||
}
|
||||
|
||||
log.Debug().Int("total", len(docs)).Msg("All documents cached, no hubs to store")
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheDocuments stores content for later recomputation
|
||||
func (c *Client) cacheDocuments(ctx context.Context, docs []sqlitevec.Document) error {
|
||||
c.cacheMu.Lock()
|
||||
defer c.cacheMu.Unlock()
|
||||
|
||||
for _, doc := range docs {
|
||||
c.contentCache[doc.ID] = doc.Content
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDocuments removes documents by their IDs
|
||||
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
// Remove from base storage
|
||||
if err := c.base.DeleteDocuments(ctx, ids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clean up caches
|
||||
c.mu.Lock()
|
||||
for _, id := range ids {
|
||||
delete(c.accessCount, id)
|
||||
delete(c.lastAccess, id)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.cacheMu.Lock()
|
||||
for _, id := range ids {
|
||||
delete(c.contentCache, id)
|
||||
}
|
||||
c.cacheMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs search with dynamic recomputation
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
switch c.strategy {
|
||||
case StorageAlways:
|
||||
// Use existing implementation
|
||||
return c.queryAndTrack(ctx, query, limit, where)
|
||||
|
||||
case StorageHub:
|
||||
// Search hubs, then expand with recomputation
|
||||
return c.queryHybrid(ctx, query, limit, where)
|
||||
|
||||
case StorageOnDemand:
|
||||
// Fully dynamic search
|
||||
return c.queryDynamic(ctx, query, limit, where)
|
||||
|
||||
default:
|
||||
return c.queryAndTrack(ctx, query, limit, where)
|
||||
}
|
||||
}
|
||||
|
||||
// queryAndTrack wraps base Query with access tracking
|
||||
func (c *Client) queryAndTrack(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
results, err := c.base.Query(ctx, query, limit, where)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Track access for hub detection
|
||||
c.trackAccess(results)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// queryHybrid searches stored hubs and recomputes non-hubs
|
||||
func (c *Client) queryHybrid(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Query stored hub embeddings (limit * 2 for expansion)
|
||||
hubResults, err := c.base.Query(ctx, query, limit*2, where)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Track access
|
||||
c.trackAccess(hubResults)
|
||||
|
||||
// 3. Get candidate non-hub IDs (from content cache)
|
||||
candidates := c.getCandidateNonHubs(where, limit*2)
|
||||
|
||||
// 4. Recompute embeddings for candidates if we have any
|
||||
var recomputedResults []sqlitevec.QueryResult
|
||||
if len(candidates) > 0 {
|
||||
recomputedResults, err = c.recomputeAndScore(ctx, query, candidates)
|
||||
if err != nil {
|
||||
// Log but don't fail - use hub results only
|
||||
log.Warn().Err(err).Msg("Failed to recompute embeddings, using hub results only")
|
||||
recomputedResults = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Merge and rank
|
||||
allResults := append(hubResults, recomputedResults...)
|
||||
sortBySimilarity(allResults)
|
||||
|
||||
// 6. Return top K
|
||||
if len(allResults) > limit {
|
||||
allResults = allResults[:limit]
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
log.Debug().
|
||||
Dur("duration_ms", duration).
|
||||
Int("hubs", len(hubResults)).
|
||||
Int("recomputed", len(recomputedResults)).
|
||||
Int("results", len(allResults)).
|
||||
Msg("Hybrid search completed")
|
||||
|
||||
return allResults, nil
|
||||
}
|
||||
|
||||
// queryDynamic recomputes all embeddings on-the-fly
|
||||
func (c *Client) queryDynamic(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Get all candidate IDs from content cache
|
||||
candidates := c.getCandidateNonHubs(where, limit*5)
|
||||
|
||||
// Recompute and score all
|
||||
results, err := c.recomputeAndScore(ctx, query, candidates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Track access
|
||||
c.trackAccess(results)
|
||||
|
||||
// Return top K
|
||||
if len(results) > limit {
|
||||
results = results[:limit]
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
log.Debug().
|
||||
Dur("duration_ms", duration).
|
||||
Int("recomputed", len(candidates)).
|
||||
Int("results", len(results)).
|
||||
Msg("Dynamic search completed")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// recomputeAndScore generates embeddings and computes similarities
|
||||
func (c *Client) recomputeAndScore(ctx context.Context, query string, candidateIDs []string) ([]sqlitevec.QueryResult, error) {
|
||||
if len(candidateIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Generate query embedding
|
||||
queryEmb, err := c.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
// Get content for candidates
|
||||
c.cacheMu.RLock()
|
||||
texts := make([]string, 0, len(candidateIDs))
|
||||
validIDs := make([]string, 0, len(candidateIDs))
|
||||
for _, id := range candidateIDs {
|
||||
if content, ok := c.contentCache[id]; ok && content != "" {
|
||||
texts = append(texts, content)
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
c.cacheMu.RUnlock()
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Batch generate embeddings
|
||||
embeddings, err := c.embedSvc.EmbedBatch(texts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch embed: %w", err)
|
||||
}
|
||||
|
||||
// Compute similarities
|
||||
results := make([]sqlitevec.QueryResult, len(embeddings))
|
||||
for i, emb := range embeddings {
|
||||
similarity := cosineSimilarity(queryEmb, emb)
|
||||
distance := 1.0 - similarity // Convert to distance
|
||||
|
||||
results[i] = sqlitevec.QueryResult{
|
||||
ID: validIDs[i],
|
||||
Distance: float64(distance),
|
||||
Similarity: float64(similarity),
|
||||
Metadata: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// trackAccess records document access for hub detection
|
||||
func (c *Client) trackAccess(results []sqlitevec.QueryResult) {
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, r := range results {
|
||||
c.accessCount[r.ID]++
|
||||
c.lastAccess[r.ID] = now
|
||||
}
|
||||
}
|
||||
|
||||
// isHub checks if a document qualifies as a hub
|
||||
func (c *Client) isHub(docID string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
count := c.accessCount[docID]
|
||||
return count >= c.hubThreshold
|
||||
}
|
||||
|
||||
// getCandidateNonHubs returns IDs of non-hub documents matching filter
|
||||
func (c *Client) getCandidateNonHubs(where map[string]any, limit int) []string {
|
||||
c.cacheMu.RLock()
|
||||
defer c.cacheMu.RUnlock()
|
||||
|
||||
candidates := make([]string, 0, limit)
|
||||
for id := range c.contentCache {
|
||||
if !c.isHub(id) {
|
||||
candidates = append(candidates, id)
|
||||
if len(candidates) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// IsConnected always returns true (wraps base client)
|
||||
func (c *Client) IsConnected() bool {
|
||||
return c.base.IsConnected()
|
||||
}
|
||||
|
||||
// Close releases resources
|
||||
func (c *Client) Close() error {
|
||||
return c.base.Close()
|
||||
}
|
||||
|
||||
// Count returns the total number of vectors in the store
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
return c.base.Count(ctx)
|
||||
}
|
||||
|
||||
// ModelVersion returns the current embedding model version
|
||||
func (c *Client) ModelVersion() string {
|
||||
return c.base.ModelVersion()
|
||||
}
|
||||
|
||||
// NeedsRebuild checks if vectors need to be rebuilt due to model version change
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
return c.base.NeedsRebuild(ctx)
|
||||
}
|
||||
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]sqlitevec.StaleVectorInfo, error) {
|
||||
return c.base.GetStaleVectors(ctx)
|
||||
}
|
||||
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids
|
||||
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||
return c.base.DeleteVectorsByDocIDs(ctx, docIDs)
|
||||
}
|
||||
|
||||
// GetStorageStats returns storage efficiency metrics
|
||||
func (c *Client) GetStorageStats(ctx context.Context) (StorageStats, error) {
|
||||
c.mu.RLock()
|
||||
c.cacheMu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
defer c.cacheMu.RUnlock()
|
||||
|
||||
totalDocs := len(c.contentCache)
|
||||
hubCount := 0
|
||||
for id := range c.contentCache {
|
||||
if c.accessCount[id] >= c.hubThreshold {
|
||||
hubCount++
|
||||
}
|
||||
}
|
||||
|
||||
storedCount := hubCount
|
||||
if c.strategy == StorageAlways {
|
||||
// Get actual count from database
|
||||
if count, err := c.base.Count(ctx); err == nil {
|
||||
storedCount = int(count)
|
||||
}
|
||||
} else if c.strategy == StorageOnDemand {
|
||||
storedCount = 0
|
||||
}
|
||||
|
||||
embeddingSize := 384 * 4 // 384 dims × 4 bytes (float32)
|
||||
storedBytes := storedCount * embeddingSize
|
||||
potentialBytes := totalDocs * embeddingSize
|
||||
|
||||
savingsPercent := 0.0
|
||||
if potentialBytes > 0 {
|
||||
savingsPercent = (1.0 - float64(storedBytes)/float64(potentialBytes)) * 100
|
||||
}
|
||||
|
||||
return StorageStats{
|
||||
TotalDocuments: totalDocs,
|
||||
HubDocuments: hubCount,
|
||||
StoredEmbeddings: storedCount,
|
||||
StorageBytes: storedBytes,
|
||||
SavingsPercent: savingsPercent,
|
||||
Strategy: c.strategy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StorageStats contains storage efficiency metrics
|
||||
type StorageStats struct {
|
||||
TotalDocuments int
|
||||
HubDocuments int
|
||||
StoredEmbeddings int
|
||||
StorageBytes int
|
||||
SavingsPercent float64
|
||||
Strategy VectorStorageStrategy
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func cosineSimilarity(a, b []float32) float32 {
|
||||
var dotProduct, normA, normB float32
|
||||
for i := range a {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0
|
||||
}
|
||||
return dotProduct / float32(math.Sqrt(float64(normA))*math.Sqrt(float64(normB)))
|
||||
}
|
||||
|
||||
func sortBySimilarity(results []sqlitevec.QueryResult) {
|
||||
// Use a simple but efficient sorting algorithm
|
||||
n := len(results)
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := 0; j < n-i-1; j++ {
|
||||
if results[j].Similarity < results[j+1].Similarity {
|
||||
results[j], results[j+1] = results[j+1], results[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func strategyToString(s VectorStorageStrategy) string {
|
||||
switch s {
|
||||
case StorageAlways:
|
||||
return "always"
|
||||
case StorageHub:
|
||||
return "hub"
|
||||
case StorageOnDemand:
|
||||
return "on_demand"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ParseStrategy converts a string to VectorStorageStrategy
|
||||
func ParseStrategy(s string) VectorStorageStrategy {
|
||||
switch s {
|
||||
case "hub":
|
||||
return StorageHub
|
||||
case "on_demand":
|
||||
return StorageOnDemand
|
||||
case "always":
|
||||
return StorageAlways
|
||||
default:
|
||||
return StorageHub // Default to hub strategy
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseStrategy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected VectorStorageStrategy
|
||||
}{
|
||||
{"hub_strategy", "hub", StorageHub},
|
||||
{"on_demand_strategy", "on_demand", StorageOnDemand},
|
||||
{"always_strategy", "always", StorageAlways},
|
||||
{"invalid_defaults_to_hub", "invalid", StorageHub},
|
||||
{"empty_defaults_to_hub", "", StorageHub},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ParseStrategy(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrategyToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected string
|
||||
input VectorStorageStrategy
|
||||
}{
|
||||
{"hub_to_string", "hub", StorageHub},
|
||||
{"on_demand_to_string", "on_demand", StorageOnDemand},
|
||||
{"always_to_string", "always", StorageAlways},
|
||||
{"invalid_to_unknown", "unknown", VectorStorageStrategy(99)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := strategyToString(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCosineSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a []float32
|
||||
b []float32
|
||||
expected float32
|
||||
}{
|
||||
{
|
||||
name: "identical_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{1, 0, 0},
|
||||
expected: 1.0,
|
||||
},
|
||||
{
|
||||
name: "orthogonal_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{0, 1, 0},
|
||||
expected: 0.0,
|
||||
},
|
||||
{
|
||||
name: "opposite_vectors",
|
||||
a: []float32{1, 0, 0},
|
||||
b: []float32{-1, 0, 0},
|
||||
expected: -1.0,
|
||||
},
|
||||
{
|
||||
name: "zero_vector",
|
||||
a: []float32{0, 0, 0},
|
||||
b: []float32{1, 1, 1},
|
||||
expected: 0.0,
|
||||
},
|
||||
{
|
||||
name: "parallel_vectors",
|
||||
a: []float32{2, 0, 0},
|
||||
b: []float32{4, 0, 0},
|
||||
expected: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := cosineSimilarity(tt.a, tt.b)
|
||||
assert.InDelta(t, tt.expected, result, 0.001)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortBySimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []sqlitevec.QueryResult
|
||||
expected []string // Expected order of IDs
|
||||
}{
|
||||
{
|
||||
name: "already_sorted",
|
||||
input: []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.9},
|
||||
{ID: "doc2", Similarity: 0.7},
|
||||
{ID: "doc3", Similarity: 0.5},
|
||||
},
|
||||
expected: []string{"doc1", "doc2", "doc3"},
|
||||
},
|
||||
{
|
||||
name: "reverse_sorted",
|
||||
input: []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.3},
|
||||
{ID: "doc2", Similarity: 0.7},
|
||||
{ID: "doc3", Similarity: 0.9},
|
||||
},
|
||||
expected: []string{"doc3", "doc2", "doc1"},
|
||||
},
|
||||
{
|
||||
name: "random_order",
|
||||
input: []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.5},
|
||||
{ID: "doc2", Similarity: 0.9},
|
||||
{ID: "doc3", Similarity: 0.3},
|
||||
{ID: "doc4", Similarity: 0.7},
|
||||
},
|
||||
expected: []string{"doc2", "doc4", "doc1", "doc3"},
|
||||
},
|
||||
{
|
||||
name: "identical_similarities",
|
||||
input: []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.5},
|
||||
{ID: "doc2", Similarity: 0.5},
|
||||
{ID: "doc3", Similarity: 0.5},
|
||||
},
|
||||
expected: []string{"doc1", "doc2", "doc3"},
|
||||
},
|
||||
{
|
||||
name: "empty_list",
|
||||
input: []sqlitevec.QueryResult{},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "single_element",
|
||||
input: []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.5},
|
||||
},
|
||||
expected: []string{"doc1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sortBySimilarity(tt.input)
|
||||
|
||||
actual := make([]string, len(tt.input))
|
||||
for i, r := range tt.input {
|
||||
actual[i] = r.ID
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortBySimilarity_PreserveOtherFields(t *testing.T) {
|
||||
input := []sqlitevec.QueryResult{
|
||||
{ID: "doc1", Similarity: 0.3, Distance: 0.7, Metadata: map[string]any{"key": "val1"}},
|
||||
{ID: "doc2", Similarity: 0.9, Distance: 0.1, Metadata: map[string]any{"key": "val2"}},
|
||||
}
|
||||
|
||||
sortBySimilarity(input)
|
||||
|
||||
assert.Equal(t, "doc2", input[0].ID)
|
||||
assert.InDelta(t, 0.9, input[0].Similarity, 0.001)
|
||||
assert.InDelta(t, 0.1, input[0].Distance, 0.001)
|
||||
assert.Equal(t, "val2", input[0].Metadata["key"])
|
||||
|
||||
assert.Equal(t, "doc1", input[1].ID)
|
||||
assert.InDelta(t, 0.3, input[1].Similarity, 0.001)
|
||||
assert.InDelta(t, 0.7, input[1].Distance, 0.001)
|
||||
assert.Equal(t, "val1", input[1].Metadata["key"])
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GetStrategyFromEnv reads CLAUDE_MNEMONIC_VECTOR_STRATEGY from environment
|
||||
func GetStrategyFromEnv() VectorStorageStrategy {
|
||||
strategyStr := os.Getenv("CLAUDE_MNEMONIC_VECTOR_STRATEGY")
|
||||
if strategyStr == "" {
|
||||
// Default to hub strategy for optimal balance
|
||||
return StorageHub
|
||||
}
|
||||
|
||||
strategy := ParseStrategy(strategyStr)
|
||||
log.Info().
|
||||
Str("env_value", strategyStr).
|
||||
Str("strategy", strategyToString(strategy)).
|
||||
Msg("Vector storage strategy from environment")
|
||||
|
||||
return strategy
|
||||
}
|
||||
|
||||
// GetHubThresholdFromEnv reads CLAUDE_MNEMONIC_HUB_THRESHOLD from environment
|
||||
func GetHubThresholdFromEnv() int {
|
||||
thresholdStr := os.Getenv("CLAUDE_MNEMONIC_HUB_THRESHOLD")
|
||||
if thresholdStr == "" {
|
||||
return 5 // Default threshold
|
||||
}
|
||||
|
||||
threshold, err := strconv.Atoi(thresholdStr)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("env_value", thresholdStr).
|
||||
Msg("Invalid hub threshold in environment, using default")
|
||||
return 5
|
||||
}
|
||||
|
||||
if threshold < 1 {
|
||||
log.Warn().
|
||||
Int("env_value", threshold).
|
||||
Msg("Hub threshold too low, using minimum of 1")
|
||||
return 1
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("threshold", threshold).
|
||||
Msg("Hub threshold from environment")
|
||||
|
||||
return threshold
|
||||
}
|
||||
|
||||
// IsHybridEnabled checks if hybrid storage should be used
|
||||
// Returns false if CLAUDE_MNEMONIC_VECTOR_STRATEGY=always (backwards compat)
|
||||
func IsHybridEnabled() bool {
|
||||
strategy := GetStrategyFromEnv()
|
||||
return strategy != StorageAlways
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/graph"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GraphConfig configures graph-aware search
|
||||
type GraphConfig struct {
|
||||
Enabled bool
|
||||
MaxHops int // Maximum graph traversal depth (default: 2)
|
||||
BranchFactor int // Number of neighbors to expand per node (default: 5)
|
||||
EdgeWeight float64 // Minimum edge weight to follow (default: 0.3)
|
||||
}
|
||||
|
||||
// DefaultGraphConfig returns sensible defaults for graph search
|
||||
func DefaultGraphConfig() GraphConfig {
|
||||
return GraphConfig{
|
||||
Enabled: true,
|
||||
MaxHops: 2,
|
||||
BranchFactor: 5,
|
||||
EdgeWeight: 0.3,
|
||||
}
|
||||
}
|
||||
|
||||
// GraphSearchClient wraps hybrid.Client with graph-aware search
|
||||
type GraphSearchClient struct {
|
||||
*Client
|
||||
graph *graph.ObservationGraph
|
||||
graphConfig GraphConfig
|
||||
}
|
||||
|
||||
// NewGraphSearchClient creates a graph-enhanced hybrid client
|
||||
func NewGraphSearchClient(baseClient *Client, observationGraph *graph.ObservationGraph, cfg GraphConfig) *GraphSearchClient {
|
||||
return &GraphSearchClient{
|
||||
Client: baseClient,
|
||||
graph: observationGraph,
|
||||
graphConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Query performs graph-aware vector search with two-level traversal
|
||||
func (g *GraphSearchClient) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
|
||||
if !g.graphConfig.Enabled || g.graph == nil {
|
||||
// Fall back to standard hybrid search
|
||||
return g.Client.Query(ctx, query, limit, where)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. Generate query embedding
|
||||
queryEmb, err := g.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
// 2. Search hub nodes (stored embeddings)
|
||||
hubResults, err := g.base.Query(ctx, query, limit*2, where)
|
||||
if err != nil {
|
||||
// Fall back to standard search on error
|
||||
log.Warn().Err(err).Msg("Hub search failed, falling back to hybrid search")
|
||||
return g.Client.Query(ctx, query, limit, where)
|
||||
}
|
||||
|
||||
// 3. Track hub access
|
||||
g.trackAccess(hubResults)
|
||||
|
||||
// 4. Expand via graph traversal
|
||||
expandedIDs := g.expandFromHubs(hubResults, limit*4)
|
||||
|
||||
// 5. Filter to non-hubs that need recomputation
|
||||
nonHubIDs := make([]string, 0)
|
||||
for _, id := range expandedIDs {
|
||||
if !g.isHub(id) {
|
||||
nonHubIDs = append(nonHubIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Batch recompute non-hub embeddings
|
||||
recomputedResults, err := g.recomputeAndScore(ctx, query, nonHubIDs)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Recomputation failed, using hub results only")
|
||||
recomputedResults = nil
|
||||
}
|
||||
|
||||
// 7. Apply graph-based ranking boost
|
||||
allResults := g.mergeAndRankWithGraph(hubResults, recomputedResults, queryEmb)
|
||||
|
||||
// 8. Return top K
|
||||
if len(allResults) > limit {
|
||||
allResults = allResults[:limit]
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
log.Debug().
|
||||
Dur("duration_ms", duration).
|
||||
Int("hubs", len(hubResults)).
|
||||
Int("expanded", len(expandedIDs)).
|
||||
Int("recomputed", len(recomputedResults)).
|
||||
Int("results", len(allResults)).
|
||||
Msg("Graph search completed")
|
||||
|
||||
return allResults, nil
|
||||
}
|
||||
|
||||
// expandFromHubs traverses graph from hub nodes to find promising candidates
|
||||
func (g *GraphSearchClient) expandFromHubs(hubResults []sqlitevec.QueryResult, maxCandidates int) []string {
|
||||
if g.graph == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
expanded := make(map[string]float64) // doc_id -> relevance score
|
||||
visited := make(map[int64]bool)
|
||||
|
||||
// Start from top hub results
|
||||
for i, result := range hubResults {
|
||||
if i >= g.graphConfig.BranchFactor*2 {
|
||||
break // Limit starting points
|
||||
}
|
||||
|
||||
// Parse observation ID from doc_id
|
||||
obsID := parseObservationID(result.ID)
|
||||
if obsID == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Mark as visited with high relevance (direct match)
|
||||
visited[obsID] = true
|
||||
expanded[result.ID] = result.Similarity
|
||||
|
||||
// Traverse graph from this hub
|
||||
g.traverseGraph(obsID, result.Similarity, 0, expanded, visited)
|
||||
}
|
||||
|
||||
// Convert to sorted list
|
||||
type candidate struct {
|
||||
ID string
|
||||
Relevance float64
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0, len(expanded))
|
||||
for id, rel := range expanded {
|
||||
candidates = append(candidates, candidate{ID: id, Relevance: rel})
|
||||
}
|
||||
|
||||
// Sort by relevance descending
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].Relevance > candidates[j].Relevance
|
||||
})
|
||||
|
||||
// Return top candidates
|
||||
if len(candidates) > maxCandidates {
|
||||
candidates = candidates[:maxCandidates]
|
||||
}
|
||||
|
||||
result := make([]string, len(candidates))
|
||||
for i, c := range candidates {
|
||||
result[i] = c.ID
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// traverseGraph performs depth-limited graph traversal
|
||||
func (g *GraphSearchClient) traverseGraph(nodeID int64, baseRelevance float64, depth int, expanded map[string]float64, visited map[int64]bool) {
|
||||
if depth >= g.graphConfig.MaxHops {
|
||||
return // Max depth reached
|
||||
}
|
||||
|
||||
// Get neighbors from graph
|
||||
neighbors, weights, err := g.graph.GetNeighbors(nodeID)
|
||||
if err != nil {
|
||||
return // No neighbors or error
|
||||
}
|
||||
|
||||
// Traverse top neighbors by weight
|
||||
type neighborWeight struct {
|
||||
ID int64
|
||||
Weight float32
|
||||
}
|
||||
|
||||
neighborList := make([]neighborWeight, len(neighbors))
|
||||
for i := range neighbors {
|
||||
neighborList[i] = neighborWeight{
|
||||
ID: neighbors[i],
|
||||
Weight: weights[i],
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by weight descending
|
||||
sort.Slice(neighborList, func(i, j int) bool {
|
||||
return neighborList[i].Weight > neighborList[j].Weight
|
||||
})
|
||||
|
||||
// Expand top branch_factor neighbors
|
||||
expanded_count := 0
|
||||
for _, nw := range neighborList {
|
||||
if expanded_count >= g.graphConfig.BranchFactor {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip if edge weight too low
|
||||
if float64(nw.Weight) < g.graphConfig.EdgeWeight {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already visited
|
||||
if visited[nw.ID] {
|
||||
continue
|
||||
}
|
||||
visited[nw.ID] = true
|
||||
|
||||
// Calculate propagated relevance (decays with distance)
|
||||
decay := 0.7 // 30% decay per hop
|
||||
propagatedRelevance := baseRelevance * float64(nw.Weight) * decay
|
||||
|
||||
// Add to expanded set
|
||||
docID := formatObservationDocID(nw.ID)
|
||||
if existing, ok := expanded[docID]; !ok || propagatedRelevance > existing {
|
||||
expanded[docID] = propagatedRelevance
|
||||
}
|
||||
|
||||
// Recursively traverse
|
||||
g.traverseGraph(nw.ID, propagatedRelevance, depth+1, expanded, visited)
|
||||
expanded_count++
|
||||
}
|
||||
}
|
||||
|
||||
// mergeAndRankWithGraph combines hub and recomputed results with graph-based ranking
|
||||
func (g *GraphSearchClient) mergeAndRankWithGraph(hubResults, recomputedResults []sqlitevec.QueryResult, queryEmb []float32) []sqlitevec.QueryResult {
|
||||
// Merge results
|
||||
allResults := append(hubResults, recomputedResults...)
|
||||
|
||||
// Apply graph-based re-ranking
|
||||
if g.graph != nil {
|
||||
for i := range allResults {
|
||||
obsID := parseObservationID(allResults[i].ID)
|
||||
if obsID == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Boost score based on node degree (hubs are more important)
|
||||
node, err := g.graph.GetNode(obsID)
|
||||
if err == nil && node.Degree > 0 {
|
||||
// Degree boost: up to 10% increase for high-degree nodes
|
||||
degreeBoost := 1.0 + (0.1 * float64(node.Degree) / 20.0)
|
||||
if degreeBoost > 1.1 {
|
||||
degreeBoost = 1.1
|
||||
}
|
||||
allResults[i].Similarity *= degreeBoost
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by adjusted similarity
|
||||
sortBySimilarity(allResults)
|
||||
|
||||
return allResults
|
||||
}
|
||||
|
||||
// parseObservationID extracts observation ID from doc_id
|
||||
// Format: "obs-{id}-{field}"
|
||||
func parseObservationID(docID string) int64 {
|
||||
var obsID int64
|
||||
// Ignore error - returns 0 on parse failure, which callers handle
|
||||
_, _ = fmt.Sscanf(docID, "obs-%d-", &obsID)
|
||||
return obsID
|
||||
}
|
||||
|
||||
// formatObservationDocID creates a doc_id for an observation
|
||||
func formatObservationDocID(obsID int64) string {
|
||||
return fmt.Sprintf("obs-%d-combined", obsID)
|
||||
}
|
||||
|
||||
// GetGraphStats returns statistics about the observation graph
|
||||
func (g *GraphSearchClient) GetGraphStats() graph.GraphStats {
|
||||
if g.graph == nil {
|
||||
return graph.GraphStats{}
|
||||
}
|
||||
return g.graph.Stats()
|
||||
}
|
||||
|
||||
// RebuildGraph rebuilds the observation graph from current observations
|
||||
// This should be called periodically or when observations change significantly
|
||||
func (g *GraphSearchClient) RebuildGraph(ctx context.Context, observations []*models.Observation) error {
|
||||
log.Info().Int("observations", len(observations)).Msg("Rebuilding observation graph")
|
||||
|
||||
newGraph, err := graph.BuildFromObservations(ctx, observations)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build graph: %w", err)
|
||||
}
|
||||
|
||||
g.graph = newGraph
|
||||
|
||||
log.Info().
|
||||
Int("nodes", newGraph.Stats().NodeCount).
|
||||
Int("edges", newGraph.Stats().EdgeCount).
|
||||
Msg("Graph rebuilt successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector"
|
||||
)
|
||||
|
||||
// TestInterfaceImplementation verifies that hybrid clients implement vector.Client interface
|
||||
func TestInterfaceImplementation(t *testing.T) {
|
||||
// Compile-time check that Client implements vector.Client
|
||||
var _ vector.Client = (*Client)(nil)
|
||||
|
||||
// Compile-time check that GraphSearchClient implements vector.Client
|
||||
var _ vector.Client = (*GraphSearchClient)(nil)
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package hybrid
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Metrics tracks performance and usage statistics for hybrid vector storage
|
||||
type Metrics struct {
|
||||
startTime time.Time
|
||||
recentLatencies []time.Duration
|
||||
latenciesMu sync.Mutex
|
||||
totalQueries atomic.Int64
|
||||
hubOnlyQueries atomic.Int64
|
||||
hybridQueries atomic.Int64
|
||||
onDemandQueries atomic.Int64
|
||||
graphQueries atomic.Int64
|
||||
totalLatency atomic.Int64 // Sum in microseconds
|
||||
hubLatency atomic.Int64
|
||||
recomputeLatency atomic.Int64
|
||||
totalDocuments atomic.Int64
|
||||
hubDocuments atomic.Int64
|
||||
storedEmbeddings atomic.Int64
|
||||
recomputedCount atomic.Int64
|
||||
cacheHits atomic.Int64
|
||||
cacheMisses atomic.Int64
|
||||
graphTraversals atomic.Int64
|
||||
avgTraversalDepth atomic.Int64
|
||||
}
|
||||
|
||||
// NewMetrics creates a new metrics tracker
|
||||
func NewMetrics() *Metrics {
|
||||
return &Metrics{
|
||||
recentLatencies: make([]time.Duration, 0, 1000),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordQuery records a query execution
|
||||
func (m *Metrics) RecordQuery(queryType string, latency time.Duration, recomputed int) {
|
||||
m.totalQueries.Add(1)
|
||||
m.totalLatency.Add(latency.Microseconds())
|
||||
|
||||
switch queryType {
|
||||
case "hub_only":
|
||||
m.hubOnlyQueries.Add(1)
|
||||
case "hybrid":
|
||||
m.hybridQueries.Add(1)
|
||||
case "on_demand":
|
||||
m.onDemandQueries.Add(1)
|
||||
case "graph":
|
||||
m.graphQueries.Add(1)
|
||||
}
|
||||
|
||||
if recomputed > 0 {
|
||||
m.recomputedCount.Add(int64(recomputed))
|
||||
}
|
||||
|
||||
// Track recent latencies
|
||||
m.latenciesMu.Lock()
|
||||
m.recentLatencies = append(m.recentLatencies, latency)
|
||||
if len(m.recentLatencies) > 1000 {
|
||||
m.recentLatencies = m.recentLatencies[len(m.recentLatencies)-1000:]
|
||||
}
|
||||
m.latenciesMu.Unlock()
|
||||
}
|
||||
|
||||
// RecordHubLatency records time spent in hub search
|
||||
func (m *Metrics) RecordHubLatency(latency time.Duration) {
|
||||
m.hubLatency.Add(latency.Microseconds())
|
||||
}
|
||||
|
||||
// RecordRecomputeLatency records time spent recomputing embeddings
|
||||
func (m *Metrics) RecordRecomputeLatency(latency time.Duration) {
|
||||
m.recomputeLatency.Add(latency.Microseconds())
|
||||
}
|
||||
|
||||
// RecordCacheHit records a content cache hit
|
||||
func (m *Metrics) RecordCacheHit() {
|
||||
m.cacheHits.Add(1)
|
||||
}
|
||||
|
||||
// RecordCacheMiss records a content cache miss
|
||||
func (m *Metrics) RecordCacheMiss() {
|
||||
m.cacheMisses.Add(1)
|
||||
}
|
||||
|
||||
// RecordGraphTraversal records a graph traversal operation
|
||||
func (m *Metrics) RecordGraphTraversal(depth int) {
|
||||
m.graphTraversals.Add(1)
|
||||
m.avgTraversalDepth.Add(int64(depth))
|
||||
}
|
||||
|
||||
// UpdateStorageStats updates current storage statistics
|
||||
func (m *Metrics) UpdateStorageStats(total, hubs, stored int) {
|
||||
m.totalDocuments.Store(int64(total))
|
||||
m.hubDocuments.Store(int64(hubs))
|
||||
m.storedEmbeddings.Store(int64(stored))
|
||||
}
|
||||
|
||||
// GetSnapshot returns current metrics snapshot
|
||||
func (m *Metrics) GetSnapshot() MetricsSnapshot {
|
||||
m.latenciesMu.Lock()
|
||||
defer m.latenciesMu.Unlock()
|
||||
|
||||
totalQueries := m.totalQueries.Load()
|
||||
|
||||
snapshot := MetricsSnapshot{
|
||||
// Query counts
|
||||
TotalQueries: totalQueries,
|
||||
HubOnlyQueries: m.hubOnlyQueries.Load(),
|
||||
HybridQueries: m.hybridQueries.Load(),
|
||||
OnDemandQueries: m.onDemandQueries.Load(),
|
||||
GraphQueries: m.graphQueries.Load(),
|
||||
|
||||
// Storage
|
||||
TotalDocuments: int(m.totalDocuments.Load()),
|
||||
HubDocuments: int(m.hubDocuments.Load()),
|
||||
StoredEmbeddings: int(m.storedEmbeddings.Load()),
|
||||
RecomputedTotal: m.recomputedCount.Load(),
|
||||
|
||||
// Cache
|
||||
CacheHits: m.cacheHits.Load(),
|
||||
CacheMisses: m.cacheMisses.Load(),
|
||||
|
||||
// Graph
|
||||
GraphTraversals: m.graphTraversals.Load(),
|
||||
|
||||
// Runtime
|
||||
Uptime: time.Since(m.startTime),
|
||||
}
|
||||
|
||||
// Calculate latencies
|
||||
if totalQueries > 0 {
|
||||
snapshot.AvgLatency = time.Duration(m.totalLatency.Load()/totalQueries) * time.Microsecond
|
||||
snapshot.AvgHubLatency = time.Duration(m.hubLatency.Load()/totalQueries) * time.Microsecond
|
||||
}
|
||||
|
||||
if m.recomputedCount.Load() > 0 {
|
||||
snapshot.AvgRecomputeLatency = time.Duration(m.recomputeLatency.Load()/m.recomputedCount.Load()) * time.Microsecond
|
||||
}
|
||||
|
||||
// Calculate percentiles
|
||||
if len(m.recentLatencies) > 0 {
|
||||
sorted := make([]time.Duration, len(m.recentLatencies))
|
||||
copy(sorted, m.recentLatencies)
|
||||
sortDurations(sorted)
|
||||
|
||||
snapshot.P50Latency = percentile(sorted, 0.50)
|
||||
snapshot.P95Latency = percentile(sorted, 0.95)
|
||||
snapshot.P99Latency = percentile(sorted, 0.99)
|
||||
}
|
||||
|
||||
// Calculate cache hit rate
|
||||
totalCacheOps := snapshot.CacheHits + snapshot.CacheMisses
|
||||
if totalCacheOps > 0 {
|
||||
snapshot.CacheHitRate = float64(snapshot.CacheHits) / float64(totalCacheOps)
|
||||
}
|
||||
|
||||
// Calculate storage savings
|
||||
if snapshot.TotalDocuments > 0 {
|
||||
embeddingSize := 384 * 4 // 384 dims × 4 bytes
|
||||
fullStorage := snapshot.TotalDocuments * embeddingSize
|
||||
actualStorage := snapshot.StoredEmbeddings * embeddingSize
|
||||
|
||||
if fullStorage > 0 {
|
||||
snapshot.StorageSavingsPercent = (1.0 - float64(actualStorage)/float64(fullStorage)) * 100
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate avg traversal depth
|
||||
if snapshot.GraphTraversals > 0 {
|
||||
snapshot.AvgTraversalDepth = float64(m.avgTraversalDepth.Load()) / float64(snapshot.GraphTraversals)
|
||||
}
|
||||
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// MetricsSnapshot represents a point-in-time metrics snapshot
|
||||
type MetricsSnapshot struct {
|
||||
// Query metrics
|
||||
TotalQueries int64
|
||||
HubOnlyQueries int64
|
||||
HybridQueries int64
|
||||
OnDemandQueries int64
|
||||
GraphQueries int64
|
||||
|
||||
// Latency metrics
|
||||
AvgLatency time.Duration
|
||||
P50Latency time.Duration
|
||||
P95Latency time.Duration
|
||||
P99Latency time.Duration
|
||||
AvgHubLatency time.Duration
|
||||
AvgRecomputeLatency time.Duration
|
||||
|
||||
// Storage metrics
|
||||
TotalDocuments int
|
||||
HubDocuments int
|
||||
StoredEmbeddings int
|
||||
StorageSavingsPercent float64
|
||||
RecomputedTotal int64
|
||||
|
||||
// Cache metrics
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
CacheHitRate float64
|
||||
|
||||
// Graph metrics
|
||||
GraphTraversals int64
|
||||
AvgTraversalDepth float64
|
||||
|
||||
// Runtime
|
||||
Uptime time.Duration
|
||||
}
|
||||
|
||||
// sortDurations sorts a slice of durations in ascending order
|
||||
func sortDurations(durations []time.Duration) {
|
||||
n := len(durations)
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := 0; j < n-i-1; j++ {
|
||||
if durations[j] > durations[j+1] {
|
||||
durations[j], durations[j+1] = durations[j+1], durations[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// percentile calculates the Nth percentile from a sorted slice
|
||||
func percentile(sorted []time.Duration, p float64) time.Duration {
|
||||
if len(sorted) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
idx := int(float64(len(sorted)) * p)
|
||||
if idx >= len(sorted) {
|
||||
idx = len(sorted) - 1
|
||||
}
|
||||
|
||||
return sorted[idx]
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of metrics
|
||||
func (s MetricsSnapshot) String() string {
|
||||
return fmt.Sprintf(`Hybrid Vector Storage Metrics:
|
||||
Queries:
|
||||
Total: %d (Hub: %d, Hybrid: %d, OnDemand: %d, Graph: %d)
|
||||
Avg Latency: %v (p50: %v, p95: %v, p99: %v)
|
||||
Hub Latency: %v, Recompute Latency: %v
|
||||
Storage:
|
||||
Documents: %d (Hubs: %d, %.1f%%)
|
||||
Stored Embeddings: %d
|
||||
Savings: %.1f%%
|
||||
Total Recomputed: %d
|
||||
Cache:
|
||||
Hits: %d, Misses: %d (Hit Rate: %.1f%%)
|
||||
Graph:
|
||||
Traversals: %d (Avg Depth: %.2f)
|
||||
Runtime: %v`,
|
||||
s.TotalQueries, s.HubOnlyQueries, s.HybridQueries, s.OnDemandQueries, s.GraphQueries,
|
||||
s.AvgLatency, s.P50Latency, s.P95Latency, s.P99Latency,
|
||||
s.AvgHubLatency, s.AvgRecomputeLatency,
|
||||
s.TotalDocuments, s.HubDocuments, float64(s.HubDocuments)/float64(s.TotalDocuments)*100,
|
||||
s.StoredEmbeddings,
|
||||
s.StorageSavingsPercent,
|
||||
s.RecomputedTotal,
|
||||
s.CacheHits, s.CacheMisses, s.CacheHitRate*100,
|
||||
s.GraphTraversals, s.AvgTraversalDepth,
|
||||
s.Uptime,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Package vector provides common interfaces for vector storage implementations
|
||||
package vector
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
)
|
||||
|
||||
// Client defines the interface for vector storage operations.
|
||||
// Both sqlitevec.Client and hybrid.Client implement this interface.
|
||||
type Client interface {
|
||||
// AddDocuments adds documents with their embeddings to the vector store
|
||||
AddDocuments(ctx context.Context, docs []sqlitevec.Document) error
|
||||
|
||||
// DeleteDocuments removes documents by their IDs
|
||||
DeleteDocuments(ctx context.Context, ids []string) error
|
||||
|
||||
// Query performs a vector similarity search
|
||||
Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error)
|
||||
|
||||
// IsConnected checks if the vector store is available
|
||||
IsConnected() bool
|
||||
|
||||
// Close releases resources
|
||||
Close() error
|
||||
|
||||
// Count returns the total number of vectors in the store
|
||||
Count(ctx context.Context) (int64, error)
|
||||
|
||||
// ModelVersion returns the current embedding model version
|
||||
ModelVersion() string
|
||||
|
||||
// NeedsRebuild checks if vectors need to be rebuilt due to model version change
|
||||
NeedsRebuild(ctx context.Context) (bool, string)
|
||||
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions
|
||||
GetStaleVectors(ctx context.Context) ([]sqlitevec.StaleVectorInfo, error)
|
||||
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids
|
||||
DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error
|
||||
}
|
||||
@@ -319,11 +319,11 @@ func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
// StaleVectorInfo contains information about a vector that needs rebuilding.
|
||||
type StaleVectorInfo struct {
|
||||
DocID string
|
||||
SQLiteID int64
|
||||
DocType string
|
||||
FieldType string
|
||||
Project string
|
||||
Scope string
|
||||
SQLiteID int64
|
||||
}
|
||||
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
|
||||
@@ -12,17 +12,17 @@ const (
|
||||
|
||||
// Document represents a document to store with vector embedding.
|
||||
type Document struct {
|
||||
Metadata map[string]any
|
||||
ID string
|
||||
Content string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// QueryResult represents a search result from vector search.
|
||||
type QueryResult struct {
|
||||
Metadata map[string]any
|
||||
ID string
|
||||
Distance float64
|
||||
Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance)
|
||||
Metadata map[string]any
|
||||
Similarity float64
|
||||
}
|
||||
|
||||
// DistanceToSimilarity converts sqlite-vec cosine distance to similarity score.
|
||||
|
||||
@@ -42,10 +42,10 @@ func TestQueryResult_Fields(t *testing.T) {
|
||||
|
||||
func TestBuildWhereFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected map[string]interface{}
|
||||
name string
|
||||
docType DocType
|
||||
project string
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty_filters",
|
||||
@@ -474,9 +474,9 @@ func TestCopyMetadataMulti(t *testing.T) {
|
||||
func TestJoinStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
strs []string
|
||||
sep string
|
||||
expected string
|
||||
strs []string
|
||||
}{
|
||||
{
|
||||
name: "empty_slice",
|
||||
@@ -522,8 +522,8 @@ func TestTruncateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
@@ -577,10 +577,10 @@ func TestFilterByThreshold(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
results []QueryResult
|
||||
expectedIDs []string
|
||||
threshold float64
|
||||
maxResults int
|
||||
expectedLen int
|
||||
expectedIDs []string
|
||||
}{
|
||||
{
|
||||
name: "empty_results",
|
||||
|
||||
@@ -16,15 +16,15 @@ import (
|
||||
// Watcher monitors a file or directory for deletion and calls onDelete when removed.
|
||||
// It watches the parent directory since fsnotify cannot watch non-existent files.
|
||||
type Watcher struct {
|
||||
targetPath string // The file/directory to watch for deletion
|
||||
parentPath string // Parent directory (what we actually watch)
|
||||
onDelete func() // Callback when target is deleted
|
||||
watcher *fsnotify.Watcher
|
||||
ctx context.Context
|
||||
onDelete func()
|
||||
watcher *fsnotify.Watcher
|
||||
cancel context.CancelFunc
|
||||
targetPath string
|
||||
parentPath string
|
||||
debounce time.Duration
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
debounce time.Duration
|
||||
}
|
||||
|
||||
// New creates a new Watcher for the given target path.
|
||||
|
||||
@@ -158,10 +158,10 @@ type SessionInitRequest struct {
|
||||
|
||||
// SessionInitResponse is the response for session initialization.
|
||||
type SessionInitResponse struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
SessionDBID int64 `json:"sessionDbId"`
|
||||
PromptNumber int `json:"promptNumber"`
|
||||
Skipped bool `json:"skipped,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions.
|
||||
@@ -1312,3 +1312,85 @@ func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleGetGraphStats returns observation graph statistics.
|
||||
func (s *Service) handleGetGraphStats(w http.ResponseWriter, r *http.Request) {
|
||||
if s.graphSearchClient == nil {
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"enabled": false,
|
||||
"message": "Graph search not enabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
stats := s.graphSearchClient.GetGraphStats()
|
||||
|
||||
response := map[string]interface{}{
|
||||
"enabled": s.config.GraphEnabled,
|
||||
"nodeCount": stats.NodeCount,
|
||||
"edgeCount": stats.EdgeCount,
|
||||
"avgDegree": stats.AvgDegree,
|
||||
"maxDegree": stats.MaxDegree,
|
||||
"minDegree": stats.MinDegree,
|
||||
"medianDegree": stats.MedianDegree,
|
||||
"edgeTypes": stats.EdgeTypes,
|
||||
"config": map[string]interface{}{
|
||||
"maxHops": s.config.GraphMaxHops,
|
||||
"branchFactor": s.config.GraphBranchFactor,
|
||||
"edgeWeight": s.config.GraphEdgeWeight,
|
||||
"rebuildIntervalMin": s.config.GraphRebuildIntervalMin,
|
||||
},
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleGetVectorMetrics returns hybrid vector storage metrics.
|
||||
func (s *Service) handleGetVectorMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if s.hybridMetrics == nil {
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"enabled": false,
|
||||
"message": "Vector metrics not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
snapshot := s.hybridMetrics.GetSnapshot()
|
||||
|
||||
response := map[string]interface{}{
|
||||
"queries": map[string]interface{}{
|
||||
"total": snapshot.TotalQueries,
|
||||
"hubOnly": snapshot.HubOnlyQueries,
|
||||
"hybrid": snapshot.HybridQueries,
|
||||
"onDemand": snapshot.OnDemandQueries,
|
||||
"graph": snapshot.GraphQueries,
|
||||
},
|
||||
"latency": map[string]interface{}{
|
||||
"avg": snapshot.AvgLatency.String(),
|
||||
"p50": snapshot.P50Latency.String(),
|
||||
"p95": snapshot.P95Latency.String(),
|
||||
"p99": snapshot.P99Latency.String(),
|
||||
"avgHub": snapshot.AvgHubLatency.String(),
|
||||
"avgRecompute": snapshot.AvgRecomputeLatency.String(),
|
||||
},
|
||||
"storage": map[string]interface{}{
|
||||
"totalDocuments": snapshot.TotalDocuments,
|
||||
"hubDocuments": snapshot.HubDocuments,
|
||||
"storedEmbeddings": snapshot.StoredEmbeddings,
|
||||
"savingsPercent": snapshot.StorageSavingsPercent,
|
||||
"recomputedTotal": snapshot.RecomputedTotal,
|
||||
},
|
||||
"cache": map[string]interface{}{
|
||||
"hits": snapshot.CacheHits,
|
||||
"misses": snapshot.CacheMisses,
|
||||
"hitRate": snapshot.CacheHitRate,
|
||||
},
|
||||
"graph": map[string]interface{}{
|
||||
"traversals": snapshot.GraphTraversals,
|
||||
"avgDepth": snapshot.AvgTraversalDepth,
|
||||
},
|
||||
"uptime": snapshot.Uptime.String(),
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
@@ -77,10 +77,10 @@ func TestParseObservations_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedCount int
|
||||
expectedType models.ObservationType
|
||||
expectedTitle string
|
||||
checkConcepts []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "valid_bugfix_observation",
|
||||
@@ -300,9 +300,9 @@ func TestParseSummary_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedRequest string
|
||||
sessionID int64
|
||||
expectNil bool
|
||||
expectedRequest string
|
||||
}{
|
||||
{
|
||||
name: "empty_input",
|
||||
|
||||
@@ -31,15 +31,14 @@ type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
|
||||
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
|
||||
type Processor struct {
|
||||
claudePath string
|
||||
model string
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
broadcastFunc BroadcastFunc
|
||||
syncObservationFunc SyncObservationFunc
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
// Semaphore to limit concurrent Claude CLI calls (prevents API overload)
|
||||
sem chan struct{}
|
||||
sem chan struct{}
|
||||
claudePath string
|
||||
model string
|
||||
}
|
||||
|
||||
// SetBroadcastFunc sets the broadcast callback for SSE events.
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
|
||||
func TestIsSelfReferentialSummary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
summary *models.ParsedSummary
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -281,8 +281,8 @@ func TestTruncateForLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
@@ -719,8 +719,8 @@ func TestShouldSkipTrivialOperation_EdgeCases(t *testing.T) {
|
||||
// TestIsSelfReferentialSummary_MoreCases tests additional self-referential detection cases.
|
||||
func TestIsSelfReferentialSummary_MoreCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
summary *models.ParsedSummary
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -24,12 +24,12 @@ var ObservationConcepts = []string{
|
||||
|
||||
// ToolExecution represents a tool execution for observation.
|
||||
type ToolExecution struct {
|
||||
ID int64
|
||||
ToolName string
|
||||
ToolInput string
|
||||
ToolOutput string
|
||||
CreatedAtEpoch int64
|
||||
CWD string
|
||||
ID int64
|
||||
CreatedAtEpoch int64
|
||||
}
|
||||
|
||||
// BuildObservationPrompt builds a prompt for processing a tool observation.
|
||||
@@ -67,12 +67,12 @@ func BuildObservationPrompt(exec ToolExecution) string {
|
||||
|
||||
// SummaryRequest contains data for building a summary prompt.
|
||||
type SummaryRequest struct {
|
||||
SessionDBID int64
|
||||
SDKSessionID string
|
||||
Project string
|
||||
UserPrompt string
|
||||
LastUserMessage string
|
||||
LastAssistantMessage string
|
||||
SessionDBID int64
|
||||
}
|
||||
|
||||
// BuildSummaryPrompt builds a prompt requesting a session summary.
|
||||
|
||||
@@ -12,8 +12,8 @@ func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
@@ -60,8 +60,8 @@ func TestBuildObservationPrompt(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
exec ToolExecution
|
||||
contains []string
|
||||
exec ToolExecution
|
||||
}{
|
||||
{
|
||||
name: "basic_read_tool",
|
||||
|
||||
+188
-74
@@ -12,6 +12,10 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking/golang"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking/python"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/chunking/typescript"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
@@ -20,6 +24,7 @@ import (
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/hybrid"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
@@ -56,80 +61,53 @@ type RetrievalStats struct {
|
||||
|
||||
// Service is the main worker service orchestrator.
|
||||
type Service struct {
|
||||
// Version of the worker binary
|
||||
version string
|
||||
|
||||
// Configuration
|
||||
config *config.Config
|
||||
|
||||
// Database
|
||||
store *gorm.Store
|
||||
sessionStore *gorm.SessionStore
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
conflictStore *gorm.ConflictStore
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
|
||||
// Pattern detection
|
||||
patternDetector *pattern.Detector
|
||||
|
||||
// Domain services
|
||||
sessionManager *session.Manager
|
||||
sseBroadcaster *sse.Broadcaster
|
||||
processor *sdk.Processor
|
||||
|
||||
// Vector database (sqlite-vec with local embeddings)
|
||||
embedSvc *embedding.Service
|
||||
vectorClient *sqlitevec.Client
|
||||
vectorSync *sqlitevec.Sync
|
||||
|
||||
// Cross-encoder reranking (for improved search relevance)
|
||||
reranker *reranking.Service
|
||||
|
||||
// Query expansion (for improved search recall)
|
||||
queryExpander *expansion.Expander
|
||||
|
||||
// Importance scoring
|
||||
scoreCalculator *scoring.Calculator
|
||||
recalculator *scoring.Recalculator
|
||||
|
||||
// HTTP server
|
||||
router *chi.Mux
|
||||
server *http.Server
|
||||
startTime time.Time
|
||||
|
||||
// Retrieval statistics (per-project)
|
||||
retrievalStats map[string]*RetrievalStats
|
||||
retrievalStatsMu sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Initialization state (for deferred init)
|
||||
ready atomic.Bool
|
||||
initError error
|
||||
initMu sync.RWMutex
|
||||
|
||||
// Background verification queue for stale observations
|
||||
staleQueue chan staleVerifyRequest
|
||||
staleQueueOnce sync.Once
|
||||
|
||||
// File watchers for auto-recreation on deletion
|
||||
dbWatcher *watcher.Watcher
|
||||
configWatcher *watcher.Watcher
|
||||
|
||||
// Self-updater
|
||||
updater *update.Updater
|
||||
startTime time.Time
|
||||
initError error
|
||||
ctx context.Context
|
||||
patternDetector *pattern.Detector
|
||||
queryExpander *expansion.Expander
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
conflictStore *gorm.ConflictStore
|
||||
patternStore *gorm.PatternStore
|
||||
relationStore *gorm.RelationStore
|
||||
updater *update.Updater
|
||||
sessionManager *session.Manager
|
||||
scoreCalculator *scoring.Calculator
|
||||
processor *sdk.Processor
|
||||
embedSvc *embedding.Service
|
||||
vectorClient *sqlitevec.Client
|
||||
vectorSync *sqlitevec.Sync
|
||||
graphSearchClient *hybrid.GraphSearchClient
|
||||
hybridMetrics *hybrid.Metrics
|
||||
graphRebuildTicker *time.Ticker
|
||||
chunkingManager *chunking.Manager
|
||||
observationStore *gorm.ObservationStore
|
||||
reranker *reranking.Service
|
||||
sseBroadcaster *sse.Broadcaster
|
||||
recalculator *scoring.Recalculator
|
||||
router *chi.Mux
|
||||
server *http.Server
|
||||
sessionStore *gorm.SessionStore
|
||||
retrievalStats map[string]*RetrievalStats
|
||||
configWatcher *watcher.Watcher
|
||||
store *gorm.Store
|
||||
cancel context.CancelFunc
|
||||
dbWatcher *watcher.Watcher
|
||||
staleQueue chan staleVerifyRequest
|
||||
config *config.Config
|
||||
version string
|
||||
wg sync.WaitGroup
|
||||
initMu sync.RWMutex
|
||||
retrievalStatsMu sync.RWMutex
|
||||
staleQueueOnce sync.Once
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
// staleVerifyRequest represents a request to verify a stale observation in background
|
||||
type staleVerifyRequest struct {
|
||||
observationID int64
|
||||
cwd string
|
||||
observationID int64
|
||||
}
|
||||
|
||||
// NewService creates a new worker service with deferred initialization.
|
||||
@@ -210,6 +188,9 @@ func (s *Service) initializeAsync() {
|
||||
var embedSvc *embedding.Service
|
||||
var vectorClient *sqlitevec.Client
|
||||
var vectorSync *sqlitevec.Sync
|
||||
var graphSearchClient *hybrid.GraphSearchClient
|
||||
var hybridMetrics *hybrid.Metrics
|
||||
var chunkingManager *chunking.Manager
|
||||
|
||||
var reranker *reranking.Service
|
||||
|
||||
@@ -218,18 +199,51 @@ func (s *Service) initializeAsync() {
|
||||
log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
|
||||
} else {
|
||||
embedSvc = emb
|
||||
// Create sqlite-vec client using the same DB connection
|
||||
client, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
// Create base sqlite-vec client using the same DB connection
|
||||
baseClient, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
DB: store.GetRawDB(),
|
||||
}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sqlite-vec client creation failed - vector search disabled")
|
||||
} else {
|
||||
vectorClient = client
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
vectorClient = baseClient
|
||||
|
||||
// Wrap with LEANN hybrid storage client
|
||||
strategy := hybrid.ParseStrategy(s.config.VectorStorageStrategy)
|
||||
hybridClient := hybrid.NewClient(hybrid.Config{
|
||||
BaseClient: baseClient,
|
||||
DB: store.GetRawDB(),
|
||||
EmbedSvc: embedSvc,
|
||||
Strategy: strategy,
|
||||
HubThreshold: s.config.HubThreshold,
|
||||
})
|
||||
|
||||
// Wrap with graph-aware search client
|
||||
graphConfig := hybrid.GraphConfig{
|
||||
Enabled: s.config.GraphEnabled,
|
||||
MaxHops: s.config.GraphMaxHops,
|
||||
BranchFactor: s.config.GraphBranchFactor,
|
||||
EdgeWeight: s.config.GraphEdgeWeight,
|
||||
}
|
||||
graphSearchClient = hybrid.NewGraphSearchClient(hybridClient, nil, graphConfig)
|
||||
hybridMetrics = hybrid.NewMetrics()
|
||||
|
||||
vectorSync = sqlitevec.NewSync(baseClient)
|
||||
|
||||
// Initialize AST-aware code chunking
|
||||
chunkOpts := chunking.DefaultChunkOptions()
|
||||
chunkers := []chunking.Chunker{
|
||||
golang.NewChunker(chunkOpts),
|
||||
python.NewChunker(chunkOpts),
|
||||
typescript.NewChunker(chunkOpts),
|
||||
}
|
||||
chunkingManager = chunking.NewManager(chunkers, chunkOpts)
|
||||
|
||||
log.Info().
|
||||
Str("model", embedSvc.Version()).
|
||||
Msg("sqlite-vec vector search enabled")
|
||||
Str("storage_strategy", s.config.VectorStorageStrategy).
|
||||
Bool("graph_enabled", s.config.GraphEnabled).
|
||||
Msg("LEANN hybrid vector storage and graph search enabled")
|
||||
}
|
||||
|
||||
// Create cross-encoder reranking service if enabled
|
||||
@@ -284,6 +298,9 @@ func (s *Service) initializeAsync() {
|
||||
s.embedSvc = embedSvc
|
||||
s.vectorClient = vectorClient
|
||||
s.vectorSync = vectorSync
|
||||
s.graphSearchClient = graphSearchClient
|
||||
s.hybridMetrics = hybridMetrics
|
||||
s.chunkingManager = chunkingManager
|
||||
s.reranker = reranker
|
||||
s.initMu.Unlock()
|
||||
|
||||
@@ -411,6 +428,18 @@ func (s *Service) initializeAsync() {
|
||||
s.ready.Store(true)
|
||||
log.Info().Msg("Async initialization complete - service ready")
|
||||
|
||||
// Build initial observation graph if graph search is enabled
|
||||
if graphSearchClient != nil && s.config.GraphEnabled {
|
||||
s.wg.Add(1)
|
||||
go s.buildInitialGraph(observationStore)
|
||||
|
||||
// Start periodic graph rebuild timer
|
||||
if s.config.GraphRebuildIntervalMin > 0 {
|
||||
s.wg.Add(1)
|
||||
go s.startGraphRebuildTimer(observationStore)
|
||||
}
|
||||
}
|
||||
|
||||
// Start queue processor if SDK processor is available
|
||||
if processor != nil {
|
||||
s.wg.Add(1)
|
||||
@@ -1136,6 +1165,10 @@ func (s *Service) setupRoutes() {
|
||||
r.Get("/api/observations/{id}/relations", s.handleGetRelations)
|
||||
r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph)
|
||||
r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations)
|
||||
|
||||
// LEANN Phase 2: Graph-based search and hybrid vector storage
|
||||
r.Get("/api/graph/stats", s.handleGetGraphStats)
|
||||
r.Get("/api/vector/metrics", s.handleGetVectorMetrics)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1346,6 +1379,87 @@ func (s *Service) processAllSessions() {
|
||||
s.broadcastProcessingStatus()
|
||||
}
|
||||
|
||||
// buildInitialGraph builds the observation relationship graph in the background.
|
||||
func (s *Service) buildInitialGraph(observationStore *gorm.ObservationStore) {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info().Msg("Building initial observation graph...")
|
||||
start := time.Now()
|
||||
|
||||
// Fetch all observations
|
||||
observations, err := observationStore.GetAllObservations(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for graph building")
|
||||
return
|
||||
}
|
||||
|
||||
if len(observations) == 0 {
|
||||
log.Info().Msg("No observations to build graph from")
|
||||
return
|
||||
}
|
||||
|
||||
// Build graph using RebuildGraph method
|
||||
if err := s.graphSearchClient.RebuildGraph(s.ctx, observations); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to build observation graph")
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
stats := s.graphSearchClient.GetGraphStats()
|
||||
|
||||
log.Info().
|
||||
Int("observations", len(observations)).
|
||||
Int("nodes", stats.NodeCount).
|
||||
Int("edges", stats.EdgeCount).
|
||||
Float64("avg_degree", stats.AvgDegree).
|
||||
Int("max_degree", stats.MaxDegree).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Initial observation graph built successfully")
|
||||
}
|
||||
|
||||
// startGraphRebuildTimer starts a periodic ticker to rebuild the observation graph.
|
||||
func (s *Service) startGraphRebuildTimer(observationStore *gorm.ObservationStore) {
|
||||
defer s.wg.Done()
|
||||
|
||||
interval := time.Duration(s.config.GraphRebuildIntervalMin) * time.Minute
|
||||
s.graphRebuildTicker = time.NewTicker(interval)
|
||||
|
||||
log.Info().
|
||||
Dur("interval", interval).
|
||||
Msg("Started periodic graph rebuild timer")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.graphRebuildTicker.Stop()
|
||||
log.Info().Msg("Stopped graph rebuild timer")
|
||||
return
|
||||
|
||||
case <-s.graphRebuildTicker.C:
|
||||
log.Info().Msg("Periodic graph rebuild triggered")
|
||||
start := time.Now()
|
||||
|
||||
observations, err := observationStore.GetAllObservations(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for graph rebuild")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.graphSearchClient.RebuildGraph(s.ctx, observations); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to rebuild observation graph")
|
||||
continue
|
||||
}
|
||||
|
||||
stats := s.graphSearchClient.GetGraphStats()
|
||||
log.Info().
|
||||
Int("nodes", stats.NodeCount).
|
||||
Int("edges", stats.EdgeCount).
|
||||
Dur("elapsed", time.Since(start)).
|
||||
Msg("Periodic graph rebuild complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the service.
|
||||
func (s *Service) Shutdown(ctx context.Context) error {
|
||||
s.cancel()
|
||||
|
||||
@@ -21,11 +21,11 @@ const (
|
||||
|
||||
// ObservationData contains data for a tool observation.
|
||||
type ObservationData struct {
|
||||
ToolName string
|
||||
ToolInput interface{}
|
||||
ToolResponse interface{}
|
||||
PromptNumber int
|
||||
ToolName string
|
||||
CWD string
|
||||
PromptNumber int
|
||||
}
|
||||
|
||||
// SummarizeData contains data for a summarize request.
|
||||
@@ -36,30 +36,28 @@ type SummarizeData struct {
|
||||
|
||||
// PendingMessage represents a message queued for SDK processing.
|
||||
type PendingMessage struct {
|
||||
Type MessageType
|
||||
Observation *ObservationData
|
||||
Summarize *SummarizeData
|
||||
Type MessageType
|
||||
}
|
||||
|
||||
// ActiveSession represents an in-memory active session being processed.
|
||||
type ActiveSession struct {
|
||||
SessionDBID int64
|
||||
ClaudeSessionID string
|
||||
SDKSessionID string
|
||||
StartTime time.Time
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
notify chan struct{}
|
||||
Project string
|
||||
UserPrompt string
|
||||
SDKSessionID string
|
||||
ClaudeSessionID string
|
||||
pendingMessages []PendingMessage
|
||||
LastPromptNumber int
|
||||
StartTime time.Time
|
||||
CumulativeInputTokens int64
|
||||
CumulativeOutputTokens int64
|
||||
|
||||
// Concurrency control
|
||||
pendingMessages []PendingMessage
|
||||
messageMu sync.Mutex
|
||||
notify chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
generatorActive atomic.Bool
|
||||
SessionDBID int64
|
||||
messageMu sync.Mutex
|
||||
generatorActive atomic.Bool
|
||||
}
|
||||
|
||||
// SessionTimeout is how long an inactive session can exist before cleanup.
|
||||
@@ -70,15 +68,14 @@ const CleanupInterval = 5 * time.Minute
|
||||
|
||||
// Manager manages active session lifecycles.
|
||||
type Manager struct {
|
||||
sessionStore *gorm.SessionStore
|
||||
sessions map[int64]*ActiveSession
|
||||
mu sync.RWMutex
|
||||
onCreated func(int64)
|
||||
onDeleted func(int64)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Global notification channel for immediate processing
|
||||
ctx context.Context
|
||||
sessionStore *gorm.SessionStore
|
||||
sessions map[int64]*ActiveSession
|
||||
onCreated func(int64)
|
||||
onDeleted func(int64)
|
||||
cancel context.CancelFunc
|
||||
ProcessNotify chan struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager.
|
||||
|
||||
@@ -669,16 +669,16 @@ func TestActiveSessionCWD(t *testing.T) {
|
||||
// TestToolInputResponse tests various tool input/response types.
|
||||
func TestToolInputResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
response interface{}
|
||||
name string
|
||||
}{
|
||||
{"nil_values", nil, nil},
|
||||
{"string_values", "input string", "response string"},
|
||||
{"map_values", map[string]string{"key": "value"}, map[string]interface{}{"result": true}},
|
||||
{"slice_values", []string{"a", "b"}, []int{1, 2, 3}},
|
||||
{"int_values", 42, 100},
|
||||
{"bool_values", true, false},
|
||||
{name: "nil_values", input: nil, response: nil},
|
||||
{name: "string_values", input: "input string", response: "response string"},
|
||||
{name: "map_values", input: map[string]string{"key": "value"}, response: map[string]interface{}{"result": true}},
|
||||
{name: "slice_values", input: []string{"a", "b"}, response: []int{1, 2, 3}},
|
||||
{name: "int_values", input: 42, response: 100},
|
||||
{name: "bool_values", input: true, response: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -19,10 +19,10 @@ const (
|
||||
|
||||
// Client represents a connected SSE client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Writer http.ResponseWriter
|
||||
Flusher http.Flusher
|
||||
Done chan struct{}
|
||||
ID string
|
||||
}
|
||||
|
||||
// Broadcaster manages SSE client connections and message broadcasting.
|
||||
|
||||
@@ -256,8 +256,8 @@ func TestHandleSSE(t *testing.T) {
|
||||
// TestBroadcastJSON tests broadcasting various JSON types.
|
||||
func TestBroadcastJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -62,11 +62,11 @@ type BaseInput struct {
|
||||
// HookContext provides common context for hook handlers.
|
||||
type HookContext struct {
|
||||
HookName string
|
||||
Port int
|
||||
Project string
|
||||
SessionID string
|
||||
CWD string
|
||||
RawInput []byte
|
||||
Port int
|
||||
}
|
||||
|
||||
// HookHandler is a function that handles hook-specific logic.
|
||||
|
||||
@@ -320,11 +320,11 @@ func TestExtractBaseVersion(t *testing.T) {
|
||||
// TestPOST tests the POST function with a mock server.
|
||||
func TestPOST(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
body interface{}
|
||||
expectError bool
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
expectedResult map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful POST with JSON response",
|
||||
@@ -393,10 +393,10 @@ func TestPOST(t *testing.T) {
|
||||
// TestGET tests the GET function with a mock server.
|
||||
func TestGET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
expectError bool
|
||||
expectedResult map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful GET with JSON response",
|
||||
@@ -532,8 +532,8 @@ func TestExitCodes(t *testing.T) {
|
||||
func TestHookResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response HookResponse
|
||||
expected string
|
||||
response HookResponse
|
||||
}{
|
||||
{
|
||||
name: "continue true",
|
||||
@@ -597,8 +597,8 @@ func TestHookContext(t *testing.T) {
|
||||
// TestIsWorkerRunning_WithServer tests IsWorkerRunning with actual server.
|
||||
func TestIsWorkerRunning_WithServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverHandler func(w http.ResponseWriter, r *http.Request)
|
||||
name string
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
@@ -828,8 +828,8 @@ func TestBaseInput_PartialFields(t *testing.T) {
|
||||
func TestHookResponse_Marshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response HookResponse
|
||||
contains []string
|
||||
response HookResponse
|
||||
}{
|
||||
{
|
||||
name: "continue true",
|
||||
|
||||
@@ -33,25 +33,25 @@ const (
|
||||
|
||||
// ObservationConflict tracks conflicting observations.
|
||||
type ObservationConflict struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"`
|
||||
OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"`
|
||||
ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"`
|
||||
ConflictType ConflictType `db:"conflict_type" json:"conflict_type"`
|
||||
Resolution ConflictResolution `db:"resolution" json:"resolution"`
|
||||
Reason string `db:"reason" json:"reason"`
|
||||
DetectedAt string `db:"detected_at" json:"detected_at"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"`
|
||||
OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"`
|
||||
DetectedAtEpoch int64 `db:"detected_at_epoch" json:"detected_at_epoch"`
|
||||
Resolved bool `db:"resolved" json:"resolved"`
|
||||
ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"`
|
||||
}
|
||||
|
||||
// ConflictDetectionResult contains the result of conflict detection.
|
||||
type ConflictDetectionResult struct {
|
||||
HasConflict bool
|
||||
Type ConflictType
|
||||
Resolution ConflictResolution
|
||||
Reason string
|
||||
OlderObsIDs []int64 // IDs of observations that conflict with the new one
|
||||
OlderObsIDs []int64
|
||||
HasConflict bool
|
||||
}
|
||||
|
||||
// NewObservationConflict creates a new conflict record.
|
||||
|
||||
@@ -51,8 +51,8 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expectMatch bool
|
||||
expectPattern string
|
||||
expectMatch bool
|
||||
}{
|
||||
{
|
||||
name: "actually that was wrong",
|
||||
@@ -128,9 +128,9 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
|
||||
// TestDetectOpposingFileChanges_TableDriven tests opposing file change detection.
|
||||
func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
newerObs *Observation
|
||||
olderObs *Observation
|
||||
name string
|
||||
expectConflict bool
|
||||
}{
|
||||
{
|
||||
@@ -202,9 +202,9 @@ func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
|
||||
// TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection.
|
||||
func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
newerObs *Observation
|
||||
olderObs *Observation
|
||||
name string
|
||||
expectConflict bool
|
||||
}{
|
||||
{
|
||||
|
||||
+28
-36
@@ -121,48 +121,44 @@ func (j JSONInt64Map) Value() (driver.Value, error) {
|
||||
|
||||
// Observation represents a learning extracted from a Claude Code session.
|
||||
type Observation struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
Project string `db:"project" json:"project"`
|
||||
Scope ObservationScope `db:"scope" json:"scope"`
|
||||
Type ObservationType `db:"type" json:"type"`
|
||||
Title sql.NullString `db:"title" json:"title,omitempty"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
Subtitle sql.NullString `db:"subtitle" json:"subtitle,omitempty"`
|
||||
Facts JSONStringArray `db:"facts" json:"facts,omitempty"`
|
||||
Title sql.NullString `db:"title" json:"title,omitempty"`
|
||||
Narrative sql.NullString `db:"narrative" json:"narrative,omitempty"`
|
||||
Concepts JSONStringArray `db:"concepts" json:"concepts,omitempty"`
|
||||
FilesRead JSONStringArray `db:"files_read" json:"files_read,omitempty"`
|
||||
FilesModified JSONStringArray `db:"files_modified" json:"files_modified,omitempty"`
|
||||
FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"`
|
||||
Facts JSONStringArray `db:"facts" json:"facts,omitempty"`
|
||||
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
|
||||
LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
|
||||
ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`
|
||||
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
ImportanceScore float64 `db:"importance_score" json:"importance_score"`
|
||||
UserFeedback int `db:"user_feedback" json:"user_feedback"`
|
||||
RetrievalCount int `db:"retrieval_count" json:"retrieval_count"`
|
||||
IsStale bool `db:"-" json:"is_stale,omitempty"`
|
||||
|
||||
// Importance scoring fields
|
||||
ImportanceScore float64 `db:"importance_score" json:"importance_score"`
|
||||
UserFeedback int `db:"user_feedback" json:"user_feedback"`
|
||||
RetrievalCount int `db:"retrieval_count" json:"retrieval_count"`
|
||||
LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
|
||||
ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`
|
||||
|
||||
// Conflict detection fields
|
||||
IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
|
||||
IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
|
||||
}
|
||||
|
||||
// ParsedObservation represents an observation parsed from SDK response XML.
|
||||
type ParsedObservation struct {
|
||||
FileMtimes map[string]int64
|
||||
Type ObservationType
|
||||
Title string
|
||||
Subtitle string
|
||||
Facts []string
|
||||
Narrative string
|
||||
Scope ObservationScope
|
||||
Facts []string
|
||||
Concepts []string
|
||||
FilesRead []string
|
||||
FilesModified []string
|
||||
FileMtimes map[string]int64 // File path -> mtime epoch ms
|
||||
Scope ObservationScope // Optional: if empty, will be auto-determined
|
||||
}
|
||||
|
||||
// ToStoredObservation converts a ParsedObservation to the stored Observation format.
|
||||
@@ -197,34 +193,30 @@ func DetermineScope(concepts []string) ObservationScope {
|
||||
// ObservationJSON is a JSON-friendly representation of Observation.
|
||||
// It converts sql.NullString to plain strings for clean JSON output.
|
||||
type ObservationJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
FileMtimes map[string]int64 `json:"file_mtimes,omitempty"`
|
||||
Subtitle string `json:"subtitle,omitempty"`
|
||||
SDKSessionID string `json:"sdk_session_id"`
|
||||
Project string `json:"project"`
|
||||
Scope ObservationScope `json:"scope"`
|
||||
Type ObservationType `json:"type"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Subtitle string `json:"subtitle,omitempty"`
|
||||
Facts []string `json:"facts,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Narrative string `json:"narrative,omitempty"`
|
||||
Project string `json:"project"`
|
||||
Concepts []string `json:"concepts,omitempty"`
|
||||
Facts []string `json:"facts,omitempty"`
|
||||
FilesRead []string `json:"files_read,omitempty"`
|
||||
FilesModified []string `json:"files_modified,omitempty"`
|
||||
FileMtimes map[string]int64 `json:"file_mtimes,omitempty"`
|
||||
PromptNumber int64 `json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `json:"discovery_tokens"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||
DiscoveryTokens int64 `json:"discovery_tokens"`
|
||||
ID int64 `json:"id"`
|
||||
PromptNumber int64 `json:"prompt_number,omitempty"`
|
||||
ImportanceScore float64 `json:"importance_score"`
|
||||
UserFeedback int `json:"user_feedback"`
|
||||
RetrievalCount int `json:"retrieval_count"`
|
||||
LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"`
|
||||
ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"`
|
||||
IsStale bool `json:"is_stale,omitempty"`
|
||||
|
||||
// Importance scoring fields
|
||||
ImportanceScore float64 `json:"importance_score"`
|
||||
UserFeedback int `json:"user_feedback"`
|
||||
RetrievalCount int `json:"retrieval_count"`
|
||||
LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"`
|
||||
ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"`
|
||||
|
||||
// Conflict detection fields
|
||||
IsSuperseded bool `json:"is_superseded,omitempty"`
|
||||
IsSuperseded bool `json:"is_superseded,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Observation.
|
||||
|
||||
@@ -50,8 +50,8 @@ func (s *ObservationSuite) TestGlobalizableConcepts() {
|
||||
func (s *ObservationSuite) TestDetermineScope_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
concepts []string
|
||||
expected ObservationScope
|
||||
concepts []string
|
||||
}{
|
||||
{
|
||||
name: "empty concepts - project scope",
|
||||
@@ -121,9 +121,9 @@ func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() {
|
||||
// TestObservation_CheckStaleness_TableDriven tests staleness checking.
|
||||
func (s *ObservationSuite) TestObservation_CheckStaleness_TableDriven() {
|
||||
tests := []struct {
|
||||
name string
|
||||
storedMtimes map[string]int64
|
||||
currentMtimes map[string]int64
|
||||
name string
|
||||
expectedStale bool
|
||||
}{
|
||||
{
|
||||
@@ -300,10 +300,10 @@ func TestParsedObservation_ToStoredObservation(t *testing.T) {
|
||||
// TestJSONStringArray tests JSONStringArray scanning.
|
||||
func TestJSONStringArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
name string
|
||||
expected JSONStringArray
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
@@ -348,10 +348,10 @@ func TestJSONStringArray(t *testing.T) {
|
||||
// TestJSONInt64Map tests JSONInt64Map scanning.
|
||||
func TestJSONInt64Map(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
expected JSONInt64Map
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
|
||||
+25
-25
@@ -39,21 +39,21 @@ const (
|
||||
// Pattern represents a recurring pattern detected across observations.
|
||||
// This enables Claude to reference historical insights: "I've encountered this pattern 12 times."
|
||||
type Pattern struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"` // e.g., "State Management Anti-Pattern"
|
||||
Type PatternType `db:"type" json:"type"` // bug, refactor, architecture, etc.
|
||||
Description sql.NullString `db:"description" json:"description"` // Detailed description
|
||||
Signature JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection
|
||||
Recommendation sql.NullString `db:"recommendation" json:"recommendation"` // What works for this pattern
|
||||
Frequency int `db:"frequency" json:"frequency"` // How many times encountered
|
||||
Projects JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen
|
||||
ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` // Source observation IDs
|
||||
Status PatternStatus `db:"status" json:"status"` // active, deprecated, merged
|
||||
MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"`
|
||||
Confidence float64 `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0)
|
||||
LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
|
||||
LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
|
||||
Status PatternStatus `db:"status" json:"status"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Type PatternType `db:"type" json:"type"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
LastSeenAt string `db:"last_seen_at" json:"last_seen_at"`
|
||||
Signature JSONStringArray `db:"signature" json:"signature"`
|
||||
Projects JSONStringArray `db:"projects" json:"projects"`
|
||||
ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"`
|
||||
Recommendation sql.NullString `db:"recommendation" json:"recommendation"`
|
||||
Description sql.NullString `db:"description" json:"description"`
|
||||
MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"`
|
||||
Frequency int `db:"frequency" json:"frequency"`
|
||||
Confidence float64 `db:"confidence" json:"confidence"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
@@ -95,21 +95,21 @@ func (j JSONInt64Array) Value() (driver.Value, error) {
|
||||
|
||||
// PatternJSON is a JSON-friendly representation of Pattern.
|
||||
type PatternJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Status PatternStatus `json:"status"`
|
||||
Name string `json:"name"`
|
||||
Type PatternType `json:"type"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Signature []string `json:"signature,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Recommendation string `json:"recommendation,omitempty"`
|
||||
Frequency int `json:"frequency"`
|
||||
Projects []string `json:"projects,omitempty"`
|
||||
LastSeenAt string `json:"last_seen_at"`
|
||||
Signature []string `json:"signature,omitempty"`
|
||||
ObservationIDs []int64 `json:"observation_ids,omitempty"`
|
||||
Status PatternStatus `json:"status"`
|
||||
Projects []string `json:"projects,omitempty"`
|
||||
MergedIntoID int64 `json:"merged_into_id,omitempty"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
LastSeenAt string `json:"last_seen_at"`
|
||||
Frequency int `json:"frequency"`
|
||||
LastSeenEpoch int64 `json:"last_seen_at_epoch"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ID int64 `json:"id"`
|
||||
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
@@ -214,11 +214,11 @@ func (p *Pattern) updateConfidence() {
|
||||
|
||||
// PatternMatch represents a match between an observation and a potential pattern.
|
||||
type PatternMatch struct {
|
||||
PatternID int64 `json:"pattern_id"`
|
||||
Score float64 `json:"score"` // Match score (0.0-1.0)
|
||||
MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
|
||||
IsNew bool `json:"is_new"` // Whether this would create a new pattern
|
||||
MatchedOn string `json:"matched_on"`
|
||||
SuggestedName string `json:"suggested_name,omitempty"`
|
||||
PatternID int64 `json:"pattern_id"`
|
||||
Score float64 `json:"score"`
|
||||
IsNew bool `json:"is_new"`
|
||||
}
|
||||
|
||||
// PatternSignatureKeywords are common keywords used in pattern detection.
|
||||
|
||||
@@ -116,18 +116,18 @@ func TestPattern_ConfidenceCalculation(t *testing.T) {
|
||||
|
||||
func TestPatternType_Detection(t *testing.T) {
|
||||
tests := []struct {
|
||||
concepts []string
|
||||
title string
|
||||
narrative string
|
||||
expected PatternType
|
||||
concepts []string
|
||||
}{
|
||||
{[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern},
|
||||
{[]string{"best-practice"}, "", "", PatternTypeBestPractice},
|
||||
{[]string{"architecture"}, "", "", PatternTypeArchitecture},
|
||||
{[]string{"refactor"}, "", "", PatternTypeRefactor},
|
||||
{[]string{}, "nil pointer bug", "", PatternTypeBug},
|
||||
{[]string{}, "Deadlock in concurrent code", "", PatternTypeBug},
|
||||
{[]string{}, "Extract interface", "", PatternTypeRefactor},
|
||||
{title: "", narrative: "", expected: PatternTypeAntiPattern, concepts: []string{"anti-pattern"}},
|
||||
{title: "", narrative: "", expected: PatternTypeBestPractice, concepts: []string{"best-practice"}},
|
||||
{title: "", narrative: "", expected: PatternTypeArchitecture, concepts: []string{"architecture"}},
|
||||
{title: "", narrative: "", expected: PatternTypeRefactor, concepts: []string{"refactor"}},
|
||||
{title: "nil pointer bug", narrative: "", expected: PatternTypeBug, concepts: []string{}},
|
||||
{title: "Deadlock in concurrent code", narrative: "", expected: PatternTypeBug, concepts: []string{}},
|
||||
{title: "Extract interface", narrative: "", expected: PatternTypeRefactor, concepts: []string{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -3,18 +3,18 @@ package models
|
||||
|
||||
// UserPrompt represents a user prompt captured during a session.
|
||||
type UserPrompt struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`
|
||||
PromptNumber int `db:"prompt_number" json:"prompt_number"`
|
||||
PromptText string `db:"prompt_text" json:"prompt_text"`
|
||||
MatchedObservations int `db:"matched_observations" json:"matched_observations"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
PromptNumber int `db:"prompt_number" json:"prompt_number"`
|
||||
MatchedObservations int `db:"matched_observations" json:"matched_observations"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
// UserPromptWithSession includes session context for search results.
|
||||
type UserPromptWithSession struct {
|
||||
UserPrompt
|
||||
Project string `db:"project" json:"project"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
UserPrompt
|
||||
}
|
||||
|
||||
@@ -60,14 +60,14 @@ const (
|
||||
|
||||
// ObservationRelation represents a directed relationship between two observations.
|
||||
type ObservationRelation struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
SourceID int64 `db:"source_id" json:"source_id"`
|
||||
TargetID int64 `db:"target_id" json:"target_id"`
|
||||
RelationType RelationType `db:"relation_type" json:"relation_type"`
|
||||
Confidence float64 `db:"confidence" json:"confidence"`
|
||||
DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"`
|
||||
Reason string `db:"reason" json:"reason,omitempty"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
SourceID int64 `db:"source_id" json:"source_id"`
|
||||
TargetID int64 `db:"target_id" json:"target_id"`
|
||||
Confidence float64 `db:"confidence" json:"confidence"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
@@ -88,12 +88,12 @@ func NewObservationRelation(sourceID, targetID int64, relType RelationType, conf
|
||||
|
||||
// RelationDetectionResult contains the result of relation detection.
|
||||
type RelationDetectionResult struct {
|
||||
SourceID int64
|
||||
TargetID int64
|
||||
RelationType RelationType
|
||||
Confidence float64
|
||||
DetectionSource RelationDetectionSource
|
||||
Reason string
|
||||
SourceID int64
|
||||
TargetID int64
|
||||
Confidence float64
|
||||
}
|
||||
|
||||
// DetectFileOverlapRelation checks if observations share file references and determines relationship type.
|
||||
@@ -484,6 +484,6 @@ type RelationWithDetails struct {
|
||||
|
||||
// RelationGraph represents a graph of related observations.
|
||||
type RelationGraph struct {
|
||||
CenterID int64 `json:"center_id"`
|
||||
Relations []*RelationWithDetails `json:"relations"`
|
||||
CenterID int64 `json:"center_id"`
|
||||
}
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
|
||||
func TestDetectFileOverlapRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newer *Observation
|
||||
older *Observation
|
||||
wantRelation bool
|
||||
name string
|
||||
wantRelType RelationType
|
||||
wantMinConfid float64
|
||||
wantRelation bool
|
||||
}{
|
||||
{
|
||||
name: "no file overlap",
|
||||
@@ -105,11 +105,11 @@ func TestDetectFileOverlapRelation(t *testing.T) {
|
||||
|
||||
func TestDetectConceptOverlapRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newer *Observation
|
||||
older *Observation
|
||||
wantRelation bool
|
||||
name string
|
||||
wantMinConfid float64
|
||||
wantRelation bool
|
||||
}{
|
||||
{
|
||||
name: "no concept overlap",
|
||||
@@ -179,8 +179,8 @@ func TestDetectTypeProgressionRelation(t *testing.T) {
|
||||
name string
|
||||
newerType ObservationType
|
||||
olderType ObservationType
|
||||
wantRelation bool
|
||||
wantRelType RelationType
|
||||
wantRelation bool
|
||||
}{
|
||||
{
|
||||
name: "bugfix fixes discovery",
|
||||
@@ -314,8 +314,8 @@ func TestDetectNarrativeMentionRelation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
narrative string
|
||||
wantRelation bool
|
||||
wantRelType RelationType
|
||||
wantRelation bool
|
||||
}{
|
||||
{
|
||||
name: "fixes language",
|
||||
|
||||
+7
-23
@@ -4,8 +4,8 @@ package models
|
||||
// ConceptWeight represents a configurable weight for a concept.
|
||||
type ConceptWeight struct {
|
||||
Concept string `db:"concept" json:"concept"`
|
||||
Weight float64 `db:"weight" json:"weight"`
|
||||
UpdatedAt string `db:"updated_at" json:"updated_at"`
|
||||
Weight float64 `db:"weight" json:"weight"`
|
||||
}
|
||||
|
||||
// UserFeedbackType represents the type of user feedback.
|
||||
@@ -62,28 +62,12 @@ var TypeBaseScores = map[ObservationType]float64{
|
||||
|
||||
// ScoringConfig contains all scoring weights and parameters.
|
||||
type ScoringConfig struct {
|
||||
// RecencyHalfLifeDays is the number of days for the importance score to halve.
|
||||
// With 7 days, a 7-day old observation has 50% of a new observation's recency score.
|
||||
RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
|
||||
|
||||
// FeedbackWeight scales the user feedback contribution to final score.
|
||||
// With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30.
|
||||
FeedbackWeight float64 `json:"feedback_weight"`
|
||||
|
||||
// ConceptWeight scales the concept boost contribution.
|
||||
// The sum of matching concept weights is multiplied by this.
|
||||
ConceptWeight float64 `json:"concept_weight"`
|
||||
|
||||
// RetrievalWeight scales the retrieval boost contribution.
|
||||
// Popular observations get a logarithmic bonus.
|
||||
RetrievalWeight float64 `json:"retrieval_weight"`
|
||||
|
||||
// ConceptWeights maps concept names to their importance weights.
|
||||
ConceptWeights map[string]float64 `json:"concept_weights"`
|
||||
|
||||
// MinScore is the minimum allowed importance score.
|
||||
// Prevents observations from completely disappearing.
|
||||
MinScore float64 `json:"min_score"`
|
||||
ConceptWeights map[string]float64 `json:"concept_weights"`
|
||||
RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
|
||||
FeedbackWeight float64 `json:"feedback_weight"`
|
||||
ConceptWeight float64 `json:"concept_weight"`
|
||||
RetrievalWeight float64 `json:"retrieval_weight"`
|
||||
MinScore float64 `json:"min_score"`
|
||||
}
|
||||
|
||||
// DefaultScoringConfig returns the default scoring configuration.
|
||||
|
||||
@@ -17,29 +17,29 @@ const (
|
||||
|
||||
// SDKSession represents a Claude Code session tracked by the memory system.
|
||||
type SDKSession struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`
|
||||
SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"`
|
||||
Project string `db:"project" json:"project"`
|
||||
UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"`
|
||||
WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"`
|
||||
PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"`
|
||||
Status SessionStatus `db:"status" json:"status"`
|
||||
StartedAt string `db:"started_at" json:"started_at"`
|
||||
StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"`
|
||||
SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"`
|
||||
UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"`
|
||||
CompletedAt sql.NullString `db:"completed_at" json:"completed_at,omitempty"`
|
||||
WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"`
|
||||
CompletedAtEpoch sql.NullInt64 `db:"completed_at_epoch" json:"completed_at_epoch,omitempty"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"`
|
||||
StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"`
|
||||
}
|
||||
|
||||
// ActiveSession represents an in-memory active session being processed.
|
||||
type ActiveSession struct {
|
||||
SessionDBID int64
|
||||
StartTime time.Time
|
||||
ClaudeSessionID string
|
||||
SDKSessionID string
|
||||
Project string
|
||||
UserPrompt string
|
||||
SessionDBID int64
|
||||
LastPromptNumber int
|
||||
StartTime time.Time
|
||||
CumulativeInputTokens int64
|
||||
CumulativeOutputTokens int64
|
||||
}
|
||||
|
||||
@@ -9,18 +9,18 @@ import (
|
||||
|
||||
// SessionSummary represents a summary of a Claude Code session.
|
||||
type SessionSummary struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
Project string `db:"project" json:"project"`
|
||||
Request sql.NullString `db:"request" json:"request,omitempty"`
|
||||
Completed sql.NullString `db:"completed" json:"completed,omitempty"`
|
||||
Investigated sql.NullString `db:"investigated" json:"investigated,omitempty"`
|
||||
Learned sql.NullString `db:"learned" json:"learned,omitempty"`
|
||||
Completed sql.NullString `db:"completed" json:"completed,omitempty"`
|
||||
NextSteps sql.NullString `db:"next_steps" json:"next_steps,omitempty"`
|
||||
Notes sql.NullString `db:"notes" json:"notes,omitempty"`
|
||||
Request sql.NullString `db:"request" json:"request,omitempty"`
|
||||
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
|
||||
ID int64 `db:"id" json:"id"`
|
||||
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
@@ -56,18 +56,18 @@ func NewSessionSummary(sdkSessionID, project string, parsed *ParsedSummary, prom
|
||||
// SessionSummaryJSON is a JSON-friendly representation of SessionSummary.
|
||||
// It converts sql.NullString to plain strings for clean JSON output.
|
||||
type SessionSummaryJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Completed string `json:"completed,omitempty"`
|
||||
SDKSessionID string `json:"sdk_session_id"`
|
||||
Project string `json:"project"`
|
||||
Request string `json:"request,omitempty"`
|
||||
Investigated string `json:"investigated,omitempty"`
|
||||
Learned string `json:"learned,omitempty"`
|
||||
Completed string `json:"completed,omitempty"`
|
||||
NextSteps string `json:"next_steps,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ID int64 `json:"id"`
|
||||
PromptNumber int64 `json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `json:"discovery_tokens"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
|
||||
func TestJaccardSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set1 map[string]bool
|
||||
set2 map[string]bool
|
||||
name string
|
||||
expected float64
|
||||
}{
|
||||
{
|
||||
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"dependencies": {
|
||||
"vis-data": "^7.1.9",
|
||||
"vis-network": "^9.1.9",
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { ref, computed } from 'vue'
|
||||
import type { Stats, SelfCheckResponse } from '@/types'
|
||||
import ProjectFilter from './ProjectFilter.vue'
|
||||
import { useGraphMetrics } from '@/composables'
|
||||
|
||||
const props = defineProps<{
|
||||
stats: Stats | null
|
||||
@@ -18,12 +19,21 @@ defineEmits<{
|
||||
|
||||
// Collapse state - persisted in localStorage
|
||||
const isCollapsed = ref(localStorage.getItem('sidebar-collapsed') === 'true')
|
||||
const metricsExpanded = ref(localStorage.getItem('metrics-expanded') === 'true')
|
||||
|
||||
// Graph metrics composable
|
||||
const { graphStats, vectorMetrics, loading: metricsLoading, refresh: refreshMetrics } = useGraphMetrics()
|
||||
|
||||
function toggleCollapse() {
|
||||
isCollapsed.value = !isCollapsed.value
|
||||
localStorage.setItem('sidebar-collapsed', String(isCollapsed.value))
|
||||
}
|
||||
|
||||
function toggleMetrics() {
|
||||
metricsExpanded.value = !metricsExpanded.value
|
||||
localStorage.setItem('metrics-expanded', String(metricsExpanded.value))
|
||||
}
|
||||
|
||||
function formatNumber(n: number): string {
|
||||
if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M'
|
||||
if (n >= 1000) return (n / 1000).toFixed(1) + 'K'
|
||||
@@ -205,6 +215,99 @@ function getStatusColor(status: string): string {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Advanced Metrics -->
|
||||
<div class="bg-slate-800/50 rounded-lg border border-slate-700/50">
|
||||
<button
|
||||
@click="toggleMetrics"
|
||||
class="w-full flex items-center justify-between p-4 hover:bg-slate-700/30 transition-colors rounded-lg"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-chart-line text-violet-400" />
|
||||
<h3 class="text-sm font-semibold text-white">Advanced Metrics</h3>
|
||||
</div>
|
||||
<i
|
||||
:class="[
|
||||
'fas text-slate-400 transition-transform duration-200',
|
||||
metricsExpanded ? 'fa-chevron-up' : 'fa-chevron-down'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
|
||||
<Transition name="expand">
|
||||
<div v-show="metricsExpanded" class="px-4 pb-4 space-y-4">
|
||||
<!-- Loading State -->
|
||||
<div v-if="metricsLoading" class="text-center py-4">
|
||||
<i class="fas fa-spinner fa-spin text-slate-400" />
|
||||
<p class="text-slate-500 text-sm mt-2">Loading metrics...</p>
|
||||
</div>
|
||||
|
||||
<!-- Graph Stats -->
|
||||
<div v-else-if="graphStats?.enabled">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<span class="text-xs text-slate-400 uppercase tracking-wide">Graph</span>
|
||||
<button
|
||||
@click="refreshMetrics"
|
||||
class="text-xs text-violet-400 hover:text-violet-300 transition-colors"
|
||||
title="Refresh metrics"
|
||||
>
|
||||
<i class="fas fa-sync-alt" />
|
||||
</button>
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Nodes</span>
|
||||
<span class="text-white font-medium">{{ formatNumber(graphStats.nodeCount) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Edges</span>
|
||||
<span class="text-white font-medium">{{ formatNumber(graphStats.edgeCount) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Avg Degree</span>
|
||||
<span class="text-white font-medium">{{ graphStats.avgDegree.toFixed(1) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Max Degree</span>
|
||||
<span class="text-white font-medium">{{ graphStats.maxDegree }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Vector Metrics -->
|
||||
<div v-if="vectorMetrics?.enabled" class="mt-4 pt-4 border-t border-slate-700/50">
|
||||
<div class="text-xs text-slate-400 uppercase tracking-wide mb-2">Vector Storage</div>
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Savings</span>
|
||||
<span class="text-green-400 font-medium">
|
||||
{{ vectorMetrics.storage.savingsPercent.toFixed(1) }}%
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Queries</span>
|
||||
<span class="text-white font-medium">{{ formatNumber(vectorMetrics.queries.total) }}</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Cache Hit</span>
|
||||
<span class="text-cyan-400 font-medium">
|
||||
{{ (vectorMetrics.cache.hitRate * 100).toFixed(1) }}%
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-slate-400 text-sm">Avg Latency</span>
|
||||
<span class="text-white font-medium text-xs">{{ vectorMetrics.latency.avg }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Disabled State -->
|
||||
<div v-else class="text-slate-500 text-sm py-2">
|
||||
{{ graphStats?.message || 'Metrics not available' }}
|
||||
</div>
|
||||
</div>
|
||||
</Transition>
|
||||
</div>
|
||||
|
||||
<!-- Session Info -->
|
||||
<div v-if="stats" class="bg-slate-800/50 rounded-lg p-4 border border-slate-700/50">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
@@ -260,6 +363,30 @@ function getStatusColor(status: string): string {
|
||||
>
|
||||
<i class="fas fa-search text-cyan-400" />
|
||||
</div>
|
||||
|
||||
<!-- Metrics indicator -->
|
||||
<div
|
||||
v-if="graphStats?.enabled"
|
||||
class="bg-slate-800/50 rounded-lg p-3 border border-slate-700/50 flex justify-center"
|
||||
:title="`${graphStats.nodeCount} nodes, ${graphStats.edgeCount} edges`"
|
||||
>
|
||||
<i class="fas fa-chart-line text-violet-400" />
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.expand-enter-active,
|
||||
.expand-leave-active {
|
||||
transition: all 0.3s ease;
|
||||
overflow: hidden;
|
||||
max-height: 500px;
|
||||
}
|
||||
|
||||
.expand-enter-from,
|
||||
.expand-leave-to {
|
||||
max-height: 0;
|
||||
opacity: 0;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -3,3 +3,4 @@ export { useStats } from './useStats'
|
||||
export { useTimeline } from './useTimeline'
|
||||
export { useUpdate } from './useUpdate'
|
||||
export { useHealth } from './useHealth'
|
||||
export { useGraphMetrics } from './useGraphMetrics'
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import { ref, onMounted } from 'vue'
|
||||
import type { GraphStats, VectorMetrics } from '@/types'
|
||||
import { fetchGraphStats, fetchVectorMetrics } from '@/utils/api'
|
||||
|
||||
export function useGraphMetrics() {
|
||||
const graphStats = ref<GraphStats | null>(null)
|
||||
const vectorMetrics = ref<VectorMetrics | null>(null)
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
const refresh = async () => {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
// Fetch both in parallel
|
||||
const [graph, vector] = await Promise.all([
|
||||
fetchGraphStats(),
|
||||
fetchVectorMetrics()
|
||||
])
|
||||
|
||||
graphStats.value = graph
|
||||
vectorMetrics.value = vector
|
||||
} catch (err) {
|
||||
error.value = err instanceof Error ? err.message : 'Failed to fetch metrics'
|
||||
console.error('[GraphMetrics] Error:', err)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
refresh()
|
||||
})
|
||||
|
||||
return {
|
||||
graphStats,
|
||||
vectorMetrics,
|
||||
loading,
|
||||
error,
|
||||
refresh
|
||||
}
|
||||
}
|
||||
@@ -63,3 +63,58 @@ export interface SelfCheckResponse {
|
||||
uptime: string
|
||||
components: ComponentHealth[]
|
||||
}
|
||||
|
||||
export interface GraphStats {
|
||||
enabled: boolean
|
||||
nodeCount: number
|
||||
edgeCount: number
|
||||
avgDegree: number
|
||||
maxDegree: number
|
||||
minDegree: number
|
||||
medianDegree: number
|
||||
edgeTypes: Record<string, number>
|
||||
config: {
|
||||
maxHops: number
|
||||
branchFactor: number
|
||||
edgeWeight: number
|
||||
rebuildIntervalMin: number
|
||||
}
|
||||
message?: string
|
||||
}
|
||||
|
||||
export interface VectorMetrics {
|
||||
enabled: boolean
|
||||
queries: {
|
||||
total: number
|
||||
hubOnly: number
|
||||
hybrid: number
|
||||
onDemand: number
|
||||
graph: number
|
||||
}
|
||||
latency: {
|
||||
avg: string
|
||||
p50: string
|
||||
p95: string
|
||||
p99: string
|
||||
avgHub: string
|
||||
avgRecompute: string
|
||||
}
|
||||
storage: {
|
||||
totalDocuments: number
|
||||
hubDocuments: number
|
||||
storedEmbeddings: number
|
||||
savingsPercent: number
|
||||
recomputedTotal: number
|
||||
}
|
||||
cache: {
|
||||
hits: number
|
||||
misses: number
|
||||
hitRate: number
|
||||
}
|
||||
graph: {
|
||||
traversals: number
|
||||
avgDepth: number
|
||||
}
|
||||
uptime: string
|
||||
message?: string
|
||||
}
|
||||
|
||||
+9
-1
@@ -1,4 +1,4 @@
|
||||
import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem, RelationWithDetails, RelationGraph, RelationStats } from '@/types'
|
||||
import type { Observation, UserPrompt, SessionSummary, Stats, FeedItem, ObservationFeedItem, PromptFeedItem, SummaryFeedItem, RelationWithDetails, RelationGraph, RelationStats, GraphStats, VectorMetrics } from '@/types'
|
||||
|
||||
const API_BASE = '/api'
|
||||
const DEFAULT_TIMEOUT = 10000 // 10 seconds
|
||||
@@ -164,3 +164,11 @@ export async function fetchRelatedObservations(observationId: number, minConfide
|
||||
export async function fetchRelationStats(signal?: AbortSignal): Promise<RelationStats> {
|
||||
return fetchWithRetry<RelationStats>(`${API_BASE}/relations/stats`, { signal })
|
||||
}
|
||||
|
||||
export async function fetchGraphStats(signal?: AbortSignal): Promise<GraphStats> {
|
||||
return fetchWithRetry<GraphStats>(`${API_BASE}/graph/stats`, { signal })
|
||||
}
|
||||
|
||||
export async function fetchVectorMetrics(signal?: AbortSignal): Promise<VectorMetrics> {
|
||||
return fetchWithRetry<VectorMetrics>(`${API_BASE}/vector/metrics`, { signal })
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usegraphmetrics.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
|
||||
Reference in New Issue
Block a user