-
+
Go
Single binary. Fast startup, low memory. Zero runtime dependencies.
@@ -315,12 +315,20 @@
sqlite-vec
-
Embedded vector database. No external services required.
+
Hybrid vector storage with LEANN-inspired selective embeddings. 60-80% storage reduction.
BGE
Two-stage retrieval: bi-encoder embeddings + cross-encoder reranking for high accuracy.
+
+
Tree-sitter
+
AST-aware code chunking respects function boundaries for Go, Python, and TypeScript.
+
+
+
CSR Graph
+
Memory-efficient observation relationship graph with edge detection and hub identification.
+
@@ -417,9 +425,12 @@ const activeTab = ref('macos')
const features = [
{ icon: 'fas fa-brain', title: 'Learns as you work', description: 'Every bug fix, every architecture decision, every "aha moment" - captured automatically without breaking your flow.' },
{ icon: 'fas fa-search', title: 'Two-stage retrieval', description: 'Cross-encoder reranking delivers highly relevant results. Finds what you need even with vague queries like "that auth thing".' },
- { icon: 'fas fa-project-diagram', title: 'Knowledge graph', description: 'Automatically discovers relationships between memories. See how concepts connect in the visual graph dashboard.' },
+ { icon: 'fas fa-project-diagram', title: 'Graph-based search', description: 'LEANN Phase 2: Graph relationships between observations (file overlap, semantic similarity, temporal proximity) for smarter context retrieval.' },
+ { icon: 'fas fa-microchip', title: 'AST-aware chunking', description: 'Intelligent code splitting respects function boundaries. Go, Python, and TypeScript code is chunked at semantic boundaries, not arbitrary line counts.' },
+ { icon: 'fas fa-database', title: 'Hybrid vector storage', description: 'LEANN-inspired selective storage: frequently-accessed "hub" observations store embeddings, others recompute on-demand. 60-80% storage savings with <50ms latency.' },
{ icon: 'fas fa-folder-tree', title: 'Project-aware context', description: 'Your React knowledge stays in React projects. Your Go patterns stay in Go projects. No context pollution.' },
{ icon: 'fas fa-chart-line', title: 'Smart scoring', description: 'Importance decay, pattern detection, and conflict resolution ensure the most valuable memories surface first.' },
+ { icon: 'fas fa-gauge-high', title: 'Auto-tuning', description: 'Dynamic hub threshold adjustment based on query performance. Automatically balances storage efficiency with search latency for your workload.' },
{ icon: 'fas fa-lock', title: '100% private', description: 'Your code context never leaves your machine. No telemetry. No cloud sync. Your memories are yours.' },
]
@@ -447,6 +458,10 @@ const configOptions = [
{ name: 'CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS', description: 'Maximum observations injected per session (default: 100)', icon: 'fas fa-layer-group' },
{ name: 'CLAUDE_MNEMONIC_RERANKING_ENABLED', description: 'Enable cross-encoder reranking for improved search relevance (default: true)', icon: 'fas fa-sort-amount-down' },
{ name: 'CLAUDE_MNEMONIC_CONTEXT_RELEVANCE_THRESHOLD', description: 'Minimum similarity score for inclusion, 0.0-1.0 (default: 0.3)', icon: 'fas fa-filter' },
+ { name: 'CLAUDE_MNEMONIC_VECTOR_STORAGE_STRATEGY', description: 'Storage strategy: "hub" (default), "always", or "on_demand"', icon: 'fas fa-database' },
+ { name: 'CLAUDE_MNEMONIC_GRAPH_ENABLED', description: 'Enable graph-based search with observation relationships (default: true)', icon: 'fas fa-project-diagram' },
+ { name: 'CLAUDE_MNEMONIC_GRAPH_MAX_HOPS', description: 'Maximum graph traversal depth for search expansion (default: 2)', icon: 'fas fa-route' },
+ { name: 'CLAUDE_MNEMONIC_GRAPH_REBUILD_INTERVAL_MIN', description: 'How often to rebuild the observation graph in minutes (default: 60)', icon: 'fas fa-clock' },
]
const requiredDeps = [
@@ -457,9 +472,11 @@ const requiredDeps = [
const faqs = [
{ question: 'Will it confuse Claude with wrong context?', answer: 'No. Mnemonic uses project isolation and semantic relevance scoring. Only memories from the current project (or global best practices) are injected, and only when they\'re actually relevant to your prompt.' },
{ question: 'What exactly gets saved?', answer: 'Bug fixes with context ("Fixed race condition by adding mutex"), architecture decisions ("Using repository pattern for data access"), conventions ("All API routes prefixed with /api/v1"), and learnings you want to preserve.' },
- { question: 'Can I delete or edit memories?', answer: 'Yes. The web dashboard at localhost:37777 lets you browse, search, edit, and delete any memory. You\'re always in control.' },
+ { question: 'How does hybrid vector storage work?', answer: 'LEANN-inspired selective storage: frequently-accessed "hub" observations (identified by access patterns and graph centrality) store embeddings. Infrequently-accessed observations recompute embeddings on-demand during search. This reduces storage by 60-80% with minimal latency impact (<50ms).' },
+ { question: 'Can I delete or edit memories?', answer: 'Yes. The web dashboard at localhost:37777 lets you browse, search, edit, and delete any memory. You can also view graph relationships, storage metrics, and performance analytics. You\'re always in control.' },
{ question: 'Does it work with my existing Claude Code setup?', answer: 'Yes. Mnemonic installs as a Claude Code plugin with hooks. Your existing workflows, settings, and shortcuts remain unchanged.' },
{ question: 'What if I switch between projects frequently?', answer: 'That\'s the point. Each project has isolated memories. Switch from your Python ML project to your TypeScript app - context switches automatically.' },
- { question: 'Is there a performance impact?', answer: 'Minimal. The Go worker is lightweight (typically under 30MB RAM). Context injection at session start takes milliseconds for most projects.' },
+ { question: 'Is there a performance impact?', answer: 'Minimal. The Go worker is lightweight (typically under 30MB RAM). Hybrid storage and auto-tuning optimize for your workload. Context injection at session start takes milliseconds for most projects.' },
+ { question: 'What is AST-aware chunking?', answer: 'When processing code observations, Mnemonic uses Tree-sitter parsers to respect function and class boundaries instead of arbitrary line limits. Go, Python, and TypeScript code is chunked at semantic boundaries for better search accuracy.' },
]
diff --git a/go.mod b/go.mod
index 28de2f9..703046a 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
github.com/goccy/go-json v0.10.5
github.com/mattn/go-sqlite3 v1.14.33
github.com/rs/zerolog v1.34.0
+ github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
github.com/stretchr/testify v1.11.1
github.com/sugarme/tokenizer v0.3.0
github.com/yalue/onnxruntime_go v1.25.0
diff --git a/go.sum b/go.sum
index 19ddd7a..496a385 100644
--- a/go.sum
+++ b/go.sum
@@ -44,6 +44,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
+github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
+github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
diff --git a/internal/chunking/golang/chunker.go b/internal/chunking/golang/chunker.go
new file mode 100644
index 0000000..c267cf5
--- /dev/null
+++ b/internal/chunking/golang/chunker.go
@@ -0,0 +1,285 @@
+// Package golang provides AST-aware chunking for Go source files.
+package golang
+
+import (
+ "context"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "strings"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
+)
+
+// Chunker implements AST-aware chunking for Go files.
+// It parses sources with the standard library go/parser package.
+type Chunker struct {
+ // options controls which declarations are emitted (IncludePrivate) and
+ // whether doc comments are attached to chunks (IncludeDocComments).
+ options chunking.ChunkOptions
+}
+
+// NewChunker builds a Go chunker configured with the supplied options.
+func NewChunker(options chunking.ChunkOptions) *Chunker {
+	goChunker := new(Chunker)
+	goChunker.options = options
+	return goChunker
+}
+
+// Language reports the language handled by this chunker, which is always
+// chunking.LanguageGo.
+func (c *Chunker) Language() chunking.Language {
+ return chunking.LanguageGo
+}
+
+// SupportedExtensions returns the file extensions this chunker handles.
+// Only ".go" is reported; a fresh slice is returned on every call.
+func (c *Chunker) SupportedExtensions() []string {
+ return []string{".go"}
+}
+
+// Chunk parses a Go source file and returns semantic code chunks.
+func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
+ // Read file content
+ content, err := os.ReadFile(filePath)
+ if err != nil {
+ return nil, fmt.Errorf("read file: %w", err)
+ }
+
+ // Parse the Go file
+ fset := token.NewFileSet()
+ file, err := parser.ParseFile(fset, filePath, content, parser.ParseComments)
+ if err != nil {
+ return nil, fmt.Errorf("parse Go file: %w", err)
+ }
+
+ chunks := make([]chunking.Chunk, 0)
+ sourceLines := strings.Split(string(content), "\n")
+
+ // Extract chunks from declarations
+ for _, decl := range file.Decls {
+ switch d := decl.(type) {
+ case *ast.FuncDecl:
+ chunk := c.extractFunction(fset, d, sourceLines, filePath)
+ if chunk != nil {
+ chunks = append(chunks, *chunk)
+ }
+ case *ast.GenDecl:
+ extracted := c.extractGenDecl(fset, d, sourceLines, filePath)
+ chunks = append(chunks, extracted...)
+ }
+ }
+
+ return chunks, nil
+}
+
+// extractFunction turns a function or method declaration into a chunk.
+//
+// The chunk covers the lines from the `func` keyword to the end of the
+// declaration (the doc comment is not part of Content; it is stored in
+// DocComment when IncludeDocComments is set). Declarations with a receiver
+// are typed as methods and carry the receiver type in ParentName.
+//
+// Returns nil when the function is unexported and IncludePrivate is disabled.
+func (c *Chunker) extractFunction(fset *token.FileSet, fn *ast.FuncDecl, sourceLines []string, filePath string) *chunking.Chunk {
+	if !fn.Name.IsExported() && !c.options.IncludePrivate {
+		return nil
+	}
+
+	firstLine := fset.Position(fn.Pos()).Line
+	lastLine := fset.Position(fn.End()).Line
+
+	result := chunking.Chunk{
+		FilePath:  filePath,
+		Language:  chunking.LanguageGo,
+		Name:      fn.Name.Name,
+		StartLine: firstLine,
+		EndLine:   lastLine,
+		Content:   c.extractLines(sourceLines, firstLine, lastLine),
+		Signature: c.extractFunctionSignature(fn, fset, sourceLines),
+	}
+
+	// A non-empty receiver list distinguishes a method from a plain function.
+	if hasReceiver := fn.Recv != nil && len(fn.Recv.List) > 0; hasReceiver {
+		result.Type = chunking.ChunkTypeMethod
+		result.ParentName = c.extractReceiverType(fn.Recv)
+	} else {
+		result.Type = chunking.ChunkTypeFunction
+	}
+
+	if fn.Doc != nil && c.options.IncludeDocComments {
+		result.DocComment = strings.TrimSpace(fn.Doc.Text())
+	}
+
+	return &result
+}
+
+// extractGenDecl converts the specs of a general declaration into chunks.
+// Type specs and value specs (const/var) are handled; any other spec kind,
+// such as an import, is ignored. The result is nil when nothing is extracted.
+func (c *Chunker) extractGenDecl(fset *token.FileSet, gd *ast.GenDecl, sourceLines []string, filePath string) []chunking.Chunk {
+	var collected []chunking.Chunk
+
+	for _, spec := range gd.Specs {
+		var extracted *chunking.Chunk
+
+		if typeSpec, isType := spec.(*ast.TypeSpec); isType {
+			extracted = c.extractTypeSpec(fset, gd, typeSpec, sourceLines, filePath)
+		} else if valueSpec, isValue := spec.(*ast.ValueSpec); isValue {
+			// const and var declarations
+			extracted = c.extractValueSpec(fset, gd, valueSpec, sourceLines, filePath)
+		}
+
+		if extracted != nil {
+			collected = append(collected, *extracted)
+		}
+	}
+
+	return collected
+}
+
+// extractTypeSpec extracts a single type declaration (struct, interface, or
+// any other named type / alias) as a chunk.
+//
+// For a standalone declaration (`type T ...`) the chunk spans the whole
+// declaration, including the `type` keyword, and uses the declaration's doc
+// comment. For a spec inside a parenthesized group (`type ( A ...; B ... )`)
+// the chunk spans only that spec and prefers the spec's own doc comment,
+// falling back to the group's. Without this distinction every member of a
+// group would be emitted with the line range and content of the entire group.
+//
+// Returns nil when the type is unexported and IncludePrivate is disabled.
+func (c *Chunker) extractTypeSpec(fset *token.FileSet, gd *ast.GenDecl, ts *ast.TypeSpec, sourceLines []string, filePath string) *chunking.Chunk {
+	// Skip unexported if configured
+	if !c.options.IncludePrivate && !ts.Name.IsExported() {
+		return nil
+	}
+
+	// Lparen is only valid for parenthesized (grouped) declarations.
+	from, to, doc := gd.Pos(), gd.End(), gd.Doc
+	if gd.Lparen.IsValid() {
+		from, to = ts.Pos(), ts.End()
+		if ts.Doc != nil {
+			doc = ts.Doc
+		}
+	}
+
+	startPos := fset.Position(from)
+	endPos := fset.Position(to)
+
+	chunk := &chunking.Chunk{
+		FilePath:  filePath,
+		Language:  chunking.LanguageGo,
+		Name:      ts.Name.Name,
+		StartLine: startPos.Line,
+		EndLine:   endPos.Line,
+		Content:   c.extractLines(sourceLines, startPos.Line, endPos.Line),
+	}
+
+	// Determine chunk type based on type expression
+	switch ts.Type.(type) {
+	case *ast.StructType:
+		chunk.Type = chunking.ChunkTypeClass // Treat struct as class
+	case *ast.InterfaceType:
+		chunk.Type = chunking.ChunkTypeInterface
+	default:
+		chunk.Type = chunking.ChunkTypeType
+	}
+
+	// Extract doc comment
+	if c.options.IncludeDocComments && doc != nil {
+		chunk.DocComment = strings.TrimSpace(doc.Text())
+	}
+
+	return chunk
+}
+
+// extractValueSpec extracts a const or var spec as a chunk.
+//
+// The chunk name is the comma-separated list of identifiers declared by the
+// spec. For a standalone declaration (`var x = ...`) the chunk spans the
+// whole declaration and uses its doc comment. For a spec inside a
+// parenthesized group (`const ( A = 1; B = 2 )`) the chunk spans only that
+// spec and prefers the spec's own doc comment, falling back to the group's.
+// Without this distinction every member of a group would be emitted with the
+// line range and content of the entire group.
+//
+// Returns nil when IncludePrivate is disabled and none of the declared names
+// is exported.
+func (c *Chunker) extractValueSpec(fset *token.FileSet, gd *ast.GenDecl, vs *ast.ValueSpec, sourceLines []string, filePath string) *chunking.Chunk {
+	// Collect the names once; they decide both visibility and the chunk name.
+	names := make([]string, len(vs.Names))
+	anyExported := false
+	for i, ident := range vs.Names {
+		names[i] = ident.Name
+		if ident.IsExported() {
+			anyExported = true
+		}
+	}
+
+	// Skip if all names are unexported and we're excluding private
+	if !c.options.IncludePrivate && !anyExported {
+		return nil
+	}
+
+	// Lparen is only valid for parenthesized (grouped) declarations.
+	from, to, doc := gd.Pos(), gd.End(), gd.Doc
+	if gd.Lparen.IsValid() {
+		from, to = vs.Pos(), vs.End()
+		if vs.Doc != nil {
+			doc = vs.Doc
+		}
+	}
+
+	startPos := fset.Position(from)
+	endPos := fset.Position(to)
+
+	chunk := &chunking.Chunk{
+		FilePath:  filePath,
+		Language:  chunking.LanguageGo,
+		Name:      strings.Join(names, ", "),
+		StartLine: startPos.Line,
+		EndLine:   endPos.Line,
+		Content:   c.extractLines(sourceLines, startPos.Line, endPos.Line),
+	}
+
+	// Set type based on token
+	if gd.Tok == token.CONST {
+		chunk.Type = chunking.ChunkTypeConst
+	} else {
+		chunk.Type = chunking.ChunkTypeVar
+	}
+
+	// Extract doc comment
+	if c.options.IncludeDocComments && doc != nil {
+		chunk.DocComment = strings.TrimSpace(doc.Text())
+	}
+
+	return chunk
+}
+
+// extractReceiverType returns the base type name of a method receiver.
+//
+// Pointer, parenthesized and generic receivers are unwrapped, so `T`, `*T`,
+// `(*T)`, `*T[K]` and `T[K, V]` all yield "T". An empty string is returned
+// when there is no receiver or its type is not a plain identifier after
+// unwrapping.
+func (c *Chunker) extractReceiverType(recv *ast.FieldList) string {
+	if recv == nil || len(recv.List) == 0 {
+		return ""
+	}
+
+	// Peel wrappers one level at a time until an identifier is reached.
+	expr := recv.List[0].Type
+	for {
+		switch t := expr.(type) {
+		case *ast.Ident:
+			return t.Name
+		case *ast.StarExpr:
+			expr = t.X
+		case *ast.ParenExpr:
+			expr = t.X
+		case *ast.IndexExpr: // generic receiver with one type parameter
+			expr = t.X
+		case *ast.IndexListExpr: // generic receiver with several type parameters
+			expr = t.X
+		default:
+			return ""
+		}
+	}
+}
+
+// extractFunctionSignature returns the declaration text of fn without its
+// body, with surrounding whitespace trimmed.
+//
+// The cut is made at the exact column of the body's opening brace as
+// reported by the parser, not at the first "{" found in the text, so
+// signatures that themselves contain braces (for example `interface{}` or
+// `struct{ ... }` parameters) are preserved in full. Multi-line signatures
+// keep their line breaks.
+//
+// For a declaration without a body the whole declaration is returned. An
+// empty string is returned if the reported positions fall outside
+// sourceLines.
+func (c *Chunker) extractFunctionSignature(fn *ast.FuncDecl, fset *token.FileSet, sourceLines []string) string {
+	startPos := fset.Position(fn.Pos())
+
+	if fn.Body == nil {
+		// No body, return entire declaration
+		endPos := fset.Position(fn.End())
+		return c.extractLines(sourceLines, startPos.Line, endPos.Line)
+	}
+
+	bodyPos := fset.Position(fn.Body.Pos())
+
+	// Guard against positions that do not map onto sourceLines instead of
+	// panicking on an out-of-range index.
+	if startPos.Line < 1 || bodyPos.Line < startPos.Line || bodyPos.Line > len(sourceLines) {
+		return ""
+	}
+
+	// Work on a copy so the caller's slice of lines is never modified.
+	sigLines := make([]string, bodyPos.Line-startPos.Line+1)
+	copy(sigLines, sourceLines[startPos.Line-1:bodyPos.Line])
+
+	// Columns are 1-based byte offsets: drop the opening brace and everything
+	// after it on the last line.
+	last := len(sigLines) - 1
+	if cut := bodyPos.Column - 1; cut >= 0 && cut <= len(sigLines[last]) {
+		sigLines[last] = sigLines[last][:cut]
+	}
+
+	// Drop anything that precedes the declaration on its first line.
+	if skip := startPos.Column - 1; skip > 0 && skip <= len(sigLines[0]) {
+		sigLines[0] = sigLines[0][skip:]
+	}
+
+	return strings.TrimSpace(strings.Join(sigLines, "\n"))
+}
+
+// extractLines joins lines[start..end] with "\n", where start and end are
+// 1-indexed and inclusive. An end beyond the last line is clamped; an empty
+// string is returned when start is below 1, beyond the last line, or after
+// end.
+func (c *Chunker) extractLines(lines []string, start, end int) string {
+	total := len(lines)
+
+	switch {
+	case start < 1, end < start, start > total:
+		return ""
+	}
+
+	// A 1-indexed inclusive end equals the 0-indexed exclusive upper bound.
+	upper := end
+	if upper > total {
+		upper = total
+	}
+
+	return strings.Join(lines[start-1:upper], "\n")
+}
diff --git a/internal/chunking/golang/chunker_test.go b/internal/chunking/golang/chunker_test.go
new file mode 100644
index 0000000..f09adc9
--- /dev/null
+++ b/internal/chunking/golang/chunker_test.go
@@ -0,0 +1,214 @@
+package golang
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
+)
+
+// TestGoChunker_BasicFunctions writes a small Go file with two exported
+// functions and one unexported function, chunks it with the default options,
+// and checks that all three come back as function chunks with non-empty
+// content and signature.
+//
+// NOTE(review): expecting "helper" in the result presumes that
+// DefaultChunkOptions enables IncludePrivate; that default is not visible here.
+func TestGoChunker_BasicFunctions(t *testing.T) {
+ // Create temp test file
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "test.go")
+
+ testCode := `package main
+
+import "fmt"
+
+// Greet prints a greeting message
+func Greet(name string) {
+ fmt.Printf("Hello, %s!\n", name)
+}
+
+// Add adds two numbers
+func Add(a, b int) int {
+ return a + b
+}
+
+// unexported function should be included by default
+func helper() string {
+ return "helper"
+}
+`
+
+ if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ // Create chunker with default options
+ chunker := NewChunker(chunking.DefaultChunkOptions())
+
+ // Chunk the file
+ chunks, err := chunker.Chunk(context.Background(), testFile)
+ if err != nil {
+ t.Fatalf("Chunk() failed: %v", err)
+ }
+
+ // Verify we got all functions
+ if len(chunks) != 3 {
+ t.Errorf("Expected 3 chunks (Greet, Add, helper), got %d", len(chunks))
+ }
+
+ // Verify chunk details
+ // Each value flips to true once a chunk with that name has been seen.
+ expectedNames := map[string]bool{
+ "Greet": false,
+ "Add": false,
+ "helper": false,
+ }
+
+ for _, chunk := range chunks {
+ if chunk.Type != chunking.ChunkTypeFunction {
+ t.Errorf("Expected chunk type 'function', got '%s'", chunk.Type)
+ }
+
+ if chunk.Language != chunking.LanguageGo {
+ t.Errorf("Expected language 'go', got '%s'", chunk.Language)
+ }
+
+ if _, ok := expectedNames[chunk.Name]; !ok {
+ t.Errorf("Unexpected function name: %s", chunk.Name)
+ } else {
+ expectedNames[chunk.Name] = true
+ }
+
+ // Verify content is non-empty
+ if chunk.Content == "" {
+ t.Errorf("Chunk %s has empty content", chunk.Name)
+ }
+
+ // Verify signature is present for functions
+ if chunk.Signature == "" {
+ t.Errorf("Chunk %s has empty signature", chunk.Name)
+ }
+ }
+
+ // Verify all expected functions were found
+ for name, found := range expectedNames {
+ if !found {
+ t.Errorf("Expected function %s not found", name)
+ }
+ }
+}
+
+// TestGoChunker_StructsAndMethods chunks a file containing one struct and two
+// pointer-receiver methods. It checks that the struct is reported as
+// ChunkTypeClass and that both methods are reported as ChunkTypeMethod with
+// ParentName "User".
+func TestGoChunker_StructsAndMethods(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "test.go")
+
+ testCode := `package main
+
+// User represents a user
+type User struct {
+ ID int
+ Name string
+}
+
+// GetName returns the user's name
+func (u *User) GetName() string {
+ return u.Name
+}
+
+// SetName sets the user's name
+func (u *User) SetName(name string) {
+ u.Name = name
+}
+`
+
+ if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ chunker := NewChunker(chunking.DefaultChunkOptions())
+ chunks, err := chunker.Chunk(context.Background(), testFile)
+ if err != nil {
+ t.Fatalf("Chunk() failed: %v", err)
+ }
+
+ // Should have 1 struct + 2 methods = 3 chunks
+ if len(chunks) != 3 {
+ t.Errorf("Expected 3 chunks (User struct, GetName, SetName), got %d", len(chunks))
+ }
+
+ // Find the struct and methods
+ // Pointers index into the chunks slice so no chunk is copied.
+ var structChunk, getNameChunk, setNameChunk *chunking.Chunk
+ for i := range chunks {
+ switch chunks[i].Name {
+ case "User":
+ structChunk = &chunks[i]
+ case "GetName":
+ getNameChunk = &chunks[i]
+ case "SetName":
+ setNameChunk = &chunks[i]
+ }
+ }
+
+ // Verify struct
+ if structChunk == nil {
+ t.Fatal("User struct not found")
+ }
+ if structChunk.Type != chunking.ChunkTypeClass {
+ t.Errorf("Expected User to be ChunkTypeClass, got %s", structChunk.Type)
+ }
+
+ // Verify methods
+ if getNameChunk == nil {
+ t.Fatal("GetName method not found")
+ }
+ if getNameChunk.Type != chunking.ChunkTypeMethod {
+ t.Errorf("Expected GetName to be ChunkTypeMethod, got %s", getNameChunk.Type)
+ }
+ if getNameChunk.ParentName != "User" {
+ t.Errorf("Expected GetName parent to be 'User', got '%s'", getNameChunk.ParentName)
+ }
+
+ if setNameChunk == nil {
+ t.Fatal("SetName method not found")
+ }
+ if setNameChunk.Type != chunking.ChunkTypeMethod {
+ t.Errorf("Expected SetName to be ChunkTypeMethod, got %s", setNameChunk.Type)
+ }
+ if setNameChunk.ParentName != "User" {
+ t.Errorf("Expected SetName parent to be 'User', got '%s'", setNameChunk.ParentName)
+ }
+}
+
+// TestGoChunker_DocComments checks that a two-line doc comment above a
+// function is attached to its chunk with the comment markers removed and the
+// line break preserved.
+//
+// NOTE(review): this presumes DefaultChunkOptions enables IncludeDocComments;
+// that default is not visible here.
+func TestGoChunker_DocComments(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "test.go")
+
+ testCode := `package main
+
+// Calculate performs a calculation.
+// It takes two integers and returns their sum.
+func Calculate(a, b int) int {
+ return a + b
+}
+`
+
+ if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ chunker := NewChunker(chunking.DefaultChunkOptions())
+ chunks, err := chunker.Chunk(context.Background(), testFile)
+ if err != nil {
+ t.Fatalf("Chunk() failed: %v", err)
+ }
+
+ // Fatal rather than Error: chunks[0] below would panic on an empty result.
+ if len(chunks) != 1 {
+ t.Fatalf("Expected 1 chunk, got %d", len(chunks))
+ }
+
+ chunk := chunks[0]
+ if chunk.DocComment == "" {
+ t.Error("Expected doc comment to be present")
+ }
+
+ // Doc comment should contain the comment text
+ expectedComment := "Calculate performs a calculation.\nIt takes two integers and returns their sum."
+ if chunk.DocComment != expectedComment {
+ t.Errorf("Expected doc comment '%s', got '%s'", expectedComment, chunk.DocComment)
+ }
+}
diff --git a/internal/chunking/manager.go b/internal/chunking/manager.go
new file mode 100644
index 0000000..1115a54
--- /dev/null
+++ b/internal/chunking/manager.go
@@ -0,0 +1,106 @@
+package chunking
+
+import (
+ "context"
+ "fmt"
+ "path/filepath"
+ "strings"
+)
+
+// Manager dispatches files to appropriate language-specific chunkers.
+type Manager struct {
+ // chunkers maps a file extension as reported by Chunker.SupportedExtensions
+ // (for example ".go") to the chunker registered for it.
+ chunkers map[string]Chunker // extension -> chunker
+ // options supplies the MinLines and MaxChunkSize limits applied by ChunkFile.
+ options ChunkOptions
+}
+
+// NewManager creates a chunking manager and registers every given chunker
+// under each extension it reports via SupportedExtensions.
+//
+// Extensions are stored lowercased because ChunkFile and SupportsFile
+// lowercase the file extension before looking it up; a chunker reporting
+// ".PY" would otherwise never be matched. When several chunkers report the
+// same extension, the one that appears last in chunkers wins. Nil entries in
+// chunkers are skipped.
+func NewManager(chunkers []Chunker, options ChunkOptions) *Manager {
+	m := &Manager{
+		chunkers: make(map[string]Chunker),
+		options:  options,
+	}
+
+	// Register chunkers by their supported extensions
+	for _, chunker := range chunkers {
+		if chunker == nil {
+			continue
+		}
+		for _, ext := range chunker.SupportedExtensions() {
+			m.chunkers[strings.ToLower(ext)] = chunker
+		}
+	}
+
+	return m
+}
+
+// ChunkFile chunks one file with the chunker registered for its (lowercased)
+// extension and then filters the result by the manager's options: chunks
+// spanning fewer than MinLines lines and chunks whose content is longer than
+// MaxChunkSize are dropped. A limit of zero or less disables that filter.
+//
+// Returns an error if no chunker is registered for the extension or if the
+// chunker itself fails.
+func (m *Manager) ChunkFile(ctx context.Context, filePath string) ([]Chunk, error) {
+	ext := strings.ToLower(filepath.Ext(filePath))
+	chunker, registered := m.chunkers[ext]
+	if !registered {
+		return nil, fmt.Errorf("no chunker for extension %s", ext)
+	}
+
+	produced, err := chunker.Chunk(ctx, filePath)
+	if err != nil {
+		return nil, fmt.Errorf("chunk %s: %w", filePath, err)
+	}
+
+	minLines, maxSize := m.options.MinLines, m.options.MaxChunkSize
+	kept := make([]Chunk, 0, len(produced))
+	for i := range produced {
+		lineCount := produced[i].EndLine - produced[i].StartLine + 1
+		tooShort := minLines > 0 && lineCount < minLines
+		// TODO: Consider splitting large chunks intelligently
+		// For now, skip chunks that are too large
+		tooLarge := maxSize > 0 && len(produced[i].Content) > maxSize
+		if tooShort || tooLarge {
+			continue
+		}
+		kept = append(kept, produced[i])
+	}
+
+	return kept, nil
+}
+
+// ChunkFiles chunks multiple files sequentially, in the order given.
+//
+// It returns a map of file path to chunks plus every error encountered. A
+// failure on one file does not stop processing of the others, and files that
+// yield no chunks are left out of the map. If ctx is cancelled, processing
+// stops before the next file and the context error is appended to the
+// returned errors.
+func (m *Manager) ChunkFiles(ctx context.Context, filePaths []string) (map[string][]Chunk, []error) {
+	results := make(map[string][]Chunk)
+	var errs []error
+
+	for _, filePath := range filePaths {
+		// Stop early instead of starting work nobody is waiting for.
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			errs = append(errs, ctxErr)
+			break
+		}
+
+		chunks, err := m.ChunkFile(ctx, filePath)
+		if err != nil {
+			errs = append(errs, fmt.Errorf("%s: %w", filePath, err))
+			continue
+		}
+		if len(chunks) > 0 {
+			results[filePath] = chunks
+		}
+	}
+
+	return results, errs
+}
+
+// SupportsFile reports whether a chunker is registered for the file's
+// extension. The extension is lowercased before the lookup; the file itself
+// is not accessed.
+func (m *Manager) SupportsFile(filePath string) bool {
+	if _, registered := m.chunkers[strings.ToLower(filepath.Ext(filePath))]; registered {
+		return true
+	}
+	return false
+}
+
+// SupportedExtensions lists every extension that has a registered chunker.
+// The order follows map iteration and is therefore unspecified.
+func (m *Manager) SupportedExtensions() []string {
+	registered := make([]string, len(m.chunkers))
+	next := 0
+	for ext := range m.chunkers {
+		registered[next] = ext
+		next++
+	}
+	return registered
+}
diff --git a/internal/chunking/manager_test.go b/internal/chunking/manager_test.go
new file mode 100644
index 0000000..6c8e867
--- /dev/null
+++ b/internal/chunking/manager_test.go
@@ -0,0 +1,162 @@
+package chunking
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+// mockChunker is a test chunker that returns dummy chunks.
+// It never reads the file it is given and claims the .go, .py and .ts
+// extensions.
+type mockChunker struct{}
+
+// Chunk returns one fixed single-line function chunk for any path.
+func (m *mockChunker) Chunk(ctx context.Context, filePath string) ([]Chunk, error) {
+ // Return a single placeholder chunk regardless of the file's content
+ return []Chunk{
+ {
+ FilePath: filePath,
+ Language: LanguageGo,
+ Type: ChunkTypeFunction,
+ Name: "TestFunc",
+ StartLine: 1,
+ EndLine: 1,
+ Content: "test",
+ },
+ }, nil
+}
+
+// Language always reports LanguageGo, whatever extension was matched.
+func (m *mockChunker) Language() Language {
+ return LanguageGo
+}
+
+// SupportedExtensions makes the mock handle Go, Python and TypeScript files.
+func (m *mockChunker) SupportedExtensions() []string {
+ return []string{".go", ".py", ".ts"}
+}
+
+// TestManager_ChunkMultipleFiles exercises extension-based dispatch in
+// Manager: SupportsFile for three registered extensions and one unregistered
+// one, then ChunkFiles over a Go, a Python and a TypeScript file.
+//
+// The only registered chunker is mockChunker, which ignores file content, so
+// the sources written below are never parsed.
+//
+// NOTE(review): the mock's chunk is one line long; the test presumes
+// DefaultChunkOptions does not set MinLines above 1 — confirm.
+func TestManager_ChunkMultipleFiles(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create a Go file
+ goFile := filepath.Join(tmpDir, "test.go")
+ goCode := `package main
+
+func Hello() string {
+ return "hello"
+}
+`
+ if err := os.WriteFile(goFile, []byte(goCode), 0600); err != nil {
+ t.Fatalf("Failed to create Go file: %v", err)
+ }
+
+ // Create a Python file
+ pyFile := filepath.Join(tmpDir, "test.py")
+ pyCode := `def greet(name):
+ return f"Hello, {name}!"
+
+class User:
+ def __init__(self, name):
+ self.name = name
+`
+ if err := os.WriteFile(pyFile, []byte(pyCode), 0600); err != nil {
+ t.Fatalf("Failed to create Python file: %v", err)
+ }
+
+ // Create a TypeScript file
+ tsFile := filepath.Join(tmpDir, "test.ts")
+ tsCode := `function add(a: number, b: number): number {
+ return a + b;
+}
+
+class Calculator {
+ multiply(a: number, b: number): number {
+ return a * b;
+ }
+}
+`
+ if err := os.WriteFile(tsFile, []byte(tsCode), 0600); err != nil {
+ t.Fatalf("Failed to create TypeScript file: %v", err)
+ }
+
+ // Create manager
+ manager := NewManager([]Chunker{&mockChunker{}}, DefaultChunkOptions())
+
+ // Test SupportsFile
+ if !manager.SupportsFile(goFile) {
+ t.Error("Manager should support .go files")
+ }
+ if !manager.SupportsFile(pyFile) {
+ t.Error("Manager should support .py files")
+ }
+ if !manager.SupportsFile(tsFile) {
+ t.Error("Manager should support .ts files")
+ }
+
+ // This file is never created; SupportsFile is only given its path.
+ unsupportedFile := filepath.Join(tmpDir, "test.txt")
+ if manager.SupportsFile(unsupportedFile) {
+ t.Error("Manager should not support .txt files")
+ }
+
+ // Test ChunkFiles
+ results, errs := manager.ChunkFiles(context.Background(), []string{goFile, pyFile, tsFile})
+ if len(errs) > 0 {
+ t.Errorf("ChunkFiles returned errors: %v", errs)
+ }
+
+ if len(results) != 3 {
+ t.Errorf("Expected results for 3 files, got %d", len(results))
+ }
+
+ // Verify each file has chunks
+ for _, file := range []string{goFile, pyFile, tsFile} {
+ if chunks, ok := results[file]; !ok || len(chunks) == 0 {
+ t.Errorf("No chunks found for file %s", file)
+ }
+ }
+}
+
+// mockChunkerWithExts is a test chunker with configurable extensions.
+// It produces no chunks; only its extension list matters to the tests.
+type mockChunkerWithExts struct {
+ // exts is returned verbatim by SupportedExtensions.
+ exts []string
+}
+
+// Chunk returns no chunks and no error for any input.
+func (m *mockChunkerWithExts) Chunk(ctx context.Context, filePath string) ([]Chunk, error) {
+ return nil, nil
+}
+
+// Language always reports LanguageGo, regardless of the configured extensions.
+func (m *mockChunkerWithExts) Language() Language {
+ return LanguageGo
+}
+
+// SupportedExtensions returns the extensions the mock was constructed with.
+func (m *mockChunkerWithExts) SupportedExtensions() []string {
+ return m.exts
+}
+
+// TestManager_SupportedExtensions registers two mock chunkers and checks that
+// the manager reports exactly the union of their extensions, in any order.
+func TestManager_SupportedExtensions(t *testing.T) {
+	goOnly := &mockChunkerWithExts{exts: []string{".go"}}
+	pyFamily := &mockChunkerWithExts{exts: []string{".py", ".pyw"}}
+	manager := NewManager([]Chunker{goOnly, pyFamily}, DefaultChunkOptions())
+
+	// Each value flips to true once the extension has been reported.
+	seen := map[string]bool{".go": false, ".py": false, ".pyw": false}
+
+	for _, got := range manager.SupportedExtensions() {
+		if _, wanted := seen[got]; !wanted {
+			t.Errorf("Unexpected extension: %s", got)
+			continue
+		}
+		seen[got] = true
+	}
+
+	for ext, reported := range seen {
+		if reported {
+			continue
+		}
+		t.Errorf("Expected extension %s not found", ext)
+	}
+}
diff --git a/internal/chunking/python/chunker.go b/internal/chunking/python/chunker.go
new file mode 100644
index 0000000..c4906f9
--- /dev/null
+++ b/internal/chunking/python/chunker.go
@@ -0,0 +1,291 @@
+// Package python provides AST-aware chunking for Python source files using tree-sitter.
+package python
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+
+ sitter "github.com/smacker/go-tree-sitter"
+ "github.com/smacker/go-tree-sitter/python"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
+)
+
+// Chunker implements AST-aware chunking for Python files.
+//
+// It pairs a tree-sitter parser (configured with the Python grammar in
+// NewChunker) with the chunking options supplied by the caller. Within this
+// file only options.IncludePrivate and options.IncludeDocComments are read.
+//
+// NOTE(review): the single parser instance is reused by every Chunk call;
+// confirm callers never invoke Chunk concurrently on the same Chunker.
+type Chunker struct {
+ parser *sitter.Parser
+ options chunking.ChunkOptions
+}
+
+// NewChunker builds a Python chunker: it allocates a fresh tree-sitter parser,
+// points it at the Python grammar, and stores the given options unchanged.
+func NewChunker(options chunking.ChunkOptions) *Chunker {
+	chunker := &Chunker{
+		parser:  sitter.NewParser(),
+		options: options,
+	}
+	chunker.parser.SetLanguage(python.GetLanguage())
+	return chunker
+}
+
+// Language returns the language this chunker supports.
+func (c *Chunker) Language() chunking.Language {
+ return chunking.LanguagePython
+}
+
+// SupportedExtensions lists the file extensions handled here; a new slice is
+// returned on every call, so callers may modify it freely.
+func (c *Chunker) SupportedExtensions() []string {
+	extensions := []string{".py"}
+	return extensions
+}
+
+// Chunk reads filePath from disk, parses it with the Python grammar and
+// returns one chunk per function, method and class found in the syntax tree.
+//
+// The returned slice is never nil on success (it is empty when nothing was
+// extracted). Read and parse failures are returned wrapped with a short
+// prefix; ctx is forwarded to the parser.
+func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
+	raw, readErr := os.ReadFile(filePath)
+	if readErr != nil {
+		return nil, fmt.Errorf("read file: %w", readErr)
+	}
+
+	tree, parseErr := c.parser.ParseCtx(ctx, nil, raw)
+	if parseErr != nil {
+		return nil, fmt.Errorf("parse Python file: %w", parseErr)
+	}
+	defer tree.Close()
+
+	// Lines are split once up front; the extractors slice them by row number.
+	lines := strings.Split(string(raw), "\n")
+	collected := []chunking.Chunk{}
+	c.walkNode(tree.RootNode(), raw, lines, filePath, "", &collected)
+
+	return collected, nil
+}
+
+// walkNode visits node and its descendants depth-first, appending a chunk to
+// *chunks for every function and class definition it can extract.
+//
+// parentName is the name of the enclosing class ("" at module level). It is
+// replaced only when descending into a class body, so a function nested inside
+// a method is recorded with the class as its parent as well. When a class is
+// skipped by extractClass (nil result) its body is not visited at all.
+func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) {
+	kind := node.Type()
+	childCount := int(node.ChildCount())
+
+	if kind == "class_definition" {
+		classChunk := c.extractClass(node, source, sourceLines, filePath)
+		if classChunk == nil {
+			return
+		}
+		*chunks = append(*chunks, *classChunk)
+
+		// Only the class body is descended into, with the class as parent.
+		for idx := 0; idx < childCount; idx++ {
+			if body := node.Child(idx); body.Type() == "block" {
+				c.walkNode(body, source, sourceLines, filePath, classChunk.Name, chunks)
+			}
+		}
+		return
+	}
+
+	if kind == "function_definition" {
+		if fn := c.extractFunction(node, source, sourceLines, filePath, parentName); fn != nil {
+			*chunks = append(*chunks, *fn)
+		}
+	}
+
+	// Functions, blocks and every other node type: recurse into all children
+	// while keeping the current parent.
+	for idx := 0; idx < childCount; idx++ {
+		c.walkNode(node.Child(idx), source, sourceLines, filePath, parentName, chunks)
+	}
+}
+
+// extractFunction extracts a function definition chunk.
+func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
+ // Find function name
+ var nameNode *sitter.Node
+ for i := 0; i < int(node.ChildCount()); i++ {
+ child := node.Child(i)
+ if child.Type() == "identifier" {
+ nameNode = child
+ break
+ }
+ }
+
+ if nameNode == nil {
+ return nil
+ }
+
+ name := nameNode.Content(source)
+
+ // Skip private functions if configured
+ if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") {
+ return nil
+ }
+
+ startLine := int(node.StartPoint().Row) + 1
+ endLine := int(node.EndPoint().Row) + 1
+
+ chunk := &chunking.Chunk{
+ FilePath: filePath,
+ Language: chunking.LanguagePython,
+ Name: name,
+ ParentName: parentName,
+ StartLine: startLine,
+ EndLine: endLine,
+ Content: c.extractLines(sourceLines, startLine, endLine),
+ }
+
+ // Determine if this is a method or function
+ if parentName != "" {
+ chunk.Type = chunking.ChunkTypeMethod
+ } else {
+ chunk.Type = chunking.ChunkTypeFunction
+ }
+
+ // Extract signature (def line)
+ chunk.Signature = c.extractFunctionSignature(node, source, sourceLines)
+
+ // Extract docstring as doc comment
+ if c.options.IncludeDocComments {
+ chunk.DocComment = c.extractDocstring(node, source)
+ }
+
+ return chunk
+}
+
+// extractClass extracts a class definition chunk.
+func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
+ // Find class name
+ var nameNode *sitter.Node
+ for i := 0; i < int(node.ChildCount()); i++ {
+ child := node.Child(i)
+ if child.Type() == "identifier" {
+ nameNode = child
+ break
+ }
+ }
+
+ if nameNode == nil {
+ return nil
+ }
+
+ name := nameNode.Content(source)
+
+ // Skip private classes if configured
+ if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") {
+ return nil
+ }
+
+ startLine := int(node.StartPoint().Row) + 1
+ endLine := int(node.EndPoint().Row) + 1
+
+ chunk := &chunking.Chunk{
+ FilePath: filePath,
+ Language: chunking.LanguagePython,
+ Type: chunking.ChunkTypeClass,
+ Name: name,
+ StartLine: startLine,
+ EndLine: endLine,
+ Content: c.extractLines(sourceLines, startLine, endLine),
+ }
+
+ // Extract class signature (class line)
+ chunk.Signature = c.extractClassSignature(node, source, sourceLines)
+
+ // Extract docstring as doc comment
+ if c.options.IncludeDocComments {
+ chunk.DocComment = c.extractDocstring(node, source)
+ }
+
+ return chunk
+}
+
+// extractFunctionSignature returns the whole source lines from the start of
+// the definition through the line holding the ':' that ends the header, with
+// surrounding whitespace trimmed. Without a ':' child only the first line is
+// returned. The cut is line-based, so a one-line definition yields that
+// entire line.
+func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string {
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := firstLine // used as-is when no ':' child exists
+
+	for idx, total := 0, int(node.ChildCount()); idx < total; idx++ {
+		if token := node.Child(idx); token.Type() == ":" {
+			lastLine = int(token.EndPoint().Row) + 1
+			break
+		}
+	}
+
+	return strings.TrimSpace(c.extractLines(sourceLines, firstLine, lastLine))
+}
+
+// extractClassSignature returns the class header: every source line from the
+// `class` line through the line containing the terminating ':', trimmed.
+// If the node has no ':' child, just the first line is used.
+func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string {
+	headerStart := int(node.StartPoint().Row) + 1
+	headerEnd := headerStart
+
+	total := int(node.ChildCount())
+	for pos := 0; pos < total; pos++ {
+		child := node.Child(pos)
+		if child.Type() != ":" {
+			continue
+		}
+		headerEnd = int(child.EndPoint().Row) + 1
+		break
+	}
+
+	header := c.extractLines(sourceLines, headerStart, headerEnd)
+	return strings.TrimSpace(header)
+}
+
+// extractDocstring returns the docstring of a function or class definition,
+// or "" when it has none.
+//
+// A docstring is recognised only when the first statement of the body block is
+// an expression statement consisting of a single string literal. Comments that
+// precede it are skipped. A string literal appearing after other statements is
+// not a docstring and is ignored.
+//
+// The literal's prefix letters (r, u, b, f and combinations) and its
+// surrounding quote characters are removed, then whitespace is trimmed.
+func (c *Chunker) extractDocstring(node *sitter.Node, source []byte) string {
+	// Locate the body block of the definition.
+	var body *sitter.Node
+	for i := 0; i < int(node.ChildCount()); i++ {
+		if child := node.Child(i); child.Type() == "block" {
+			body = child
+			break
+		}
+	}
+	if body == nil {
+		return ""
+	}
+
+	for i := 0; i < int(body.ChildCount()); i++ {
+		stmt := body.Child(i)
+		if !stmt.IsNamed() || stmt.Type() == "comment" {
+			continue // comments and punctuation are not statements
+		}
+
+		// stmt is the first real statement: it either is the docstring or
+		// the definition has none.
+		if stmt.Type() != "expression_statement" {
+			return ""
+		}
+
+		// The statement must hold exactly one string literal (comments aside).
+		var literal *sitter.Node
+		for j := 0; j < int(stmt.NamedChildCount()); j++ {
+			part := stmt.NamedChild(j)
+			if part.Type() == "comment" {
+				continue
+			}
+			if literal != nil || part.Type() != "string" {
+				return ""
+			}
+			literal = part
+		}
+		if literal == nil {
+			return ""
+		}
+
+		text := literal.Content(source)
+		// Drop a string prefix such as r, b, u or f so the quotes that follow
+		// it can be stripped as well.
+		if quote := strings.IndexAny(text, `"'`); quote > 0 {
+			text = text[quote:]
+		}
+		text = strings.Trim(text, `"'`)
+		return strings.TrimSpace(text)
+	}
+
+	return ""
+}
+
+// extractLines joins lines[start..end] (1-indexed, inclusive) with "\n".
+// An invalid range (start below 1, end before start, or start past the last
+// line) yields ""; an end beyond the last line is clamped to it.
+func (c *Chunker) extractLines(lines []string, start, end int) string {
+	total := len(lines)
+
+	switch {
+	case start < 1, end < start, start > total:
+		return ""
+	}
+
+	upper := end
+	if upper > total {
+		upper = total
+	}
+
+	selected := lines[start-1 : upper]
+	return strings.Join(selected, "\n")
+}
diff --git a/internal/chunking/types.go b/internal/chunking/types.go
new file mode 100644
index 0000000..ce639b1
--- /dev/null
+++ b/internal/chunking/types.go
@@ -0,0 +1,140 @@
+// Package chunking provides AST-aware code chunking for semantic code search.
+// Chunks code files into logical units (functions, classes, methods) that preserve
+// semantic boundaries for better vector embedding and retrieval.
+package chunking
+
+import (
+ "context"
+ "fmt"
+ "strings"
+)
+
+// ChunkType represents the type of code chunk.
+// Each value is a lowercase string label stored in Chunk.Type.
+//
+// NOTE(review): the Python and TypeScript chunkers added with this type emit
+// function, method, class, interface and type chunks only; const and var are
+// presumably produced by another chunker - confirm.
+type ChunkType string
+
+const (
+ // ChunkTypeFunction represents a standalone function.
+ ChunkTypeFunction ChunkType = "function"
+ // ChunkTypeMethod represents a method on a class/struct/type.
+ ChunkTypeMethod ChunkType = "method"
+ // ChunkTypeClass represents a class or struct definition.
+ ChunkTypeClass ChunkType = "class"
+ // ChunkTypeInterface represents an interface definition.
+ ChunkTypeInterface ChunkType = "interface"
+ // ChunkTypeType represents a type alias or type definition.
+ ChunkTypeType ChunkType = "type"
+ // ChunkTypeConst represents constant declarations.
+ ChunkTypeConst ChunkType = "const"
+ // ChunkTypeVar represents variable declarations.
+ ChunkTypeVar ChunkType = "var"
+)
+
+// Language represents a programming language.
+// Each value is a lowercase string label stored in Chunk.Language and
+// returned by Chunker.Language.
+type Language string
+
+const (
+ // LanguageGo represents the Go programming language.
+ LanguageGo Language = "go"
+ // LanguagePython represents the Python programming language.
+ LanguagePython Language = "python"
+ // LanguageTypeScript represents the TypeScript programming language.
+ LanguageTypeScript Language = "typescript"
+ // LanguageJavaScript represents the JavaScript programming language.
+ // NOTE(review): the TypeScript chunker also handles .js/.jsx files but
+ // labels every chunk LanguageTypeScript; confirm whether this value is
+ // meant to be used there.
+ LanguageJavaScript Language = "javascript"
+)
+
+// Chunk represents a semantic code chunk with AST-derived boundaries.
+// The field notes below describe how the Python and TypeScript chunkers fill
+// a Chunk; other producers may differ.
+type Chunk struct {
+ // Metadata holds optional extra key/value data.
+ // NOTE(review): the Python and TypeScript chunkers never set it - confirm
+ // which producer, if any, does.
+ Metadata map[string]interface{}
+ // FilePath is the path that was passed to Chunker.Chunk.
+ FilePath string
+ // Language is the language label assigned by the producing chunker.
+ Language Language
+ // Type is the kind of symbol this chunk covers.
+ Type ChunkType
+ // Name is the symbol's own identifier.
+ Name string
+ // ParentName is the enclosing class name for methods; empty for top-level symbols.
+ ParentName string
+ // Content is the source text of lines StartLine..EndLine joined with "\n".
+ Content string
+ // Signature is the declaration header; it may be empty.
+ Signature string
+ // DocComment is the docstring or preceding comment with its markers
+ // stripped; empty when absent or when IncludeDocComments is false.
+ DocComment string
+ // StartLine is the 1-indexed first line of the chunk.
+ StartLine int
+ // EndLine is the 1-indexed last line of the chunk (inclusive).
+ EndLine int
+}
+
+// Identifier returns a human-readable identifier for this chunk:
+// "ParentName.Name" when a parent is set, plain "Name" otherwise.
+func (c *Chunk) Identifier() string {
+	if c.ParentName == "" {
+		return c.Name
+	}
+	return c.ParentName + "." + c.Name
+}
+
+// LineRange formats the chunk's span as "L<start>-L<end>", e.g. "L123-L456".
+func (c *Chunk) LineRange() string {
+	first, last := c.StartLine, c.EndLine
+	return fmt.Sprintf("L%d-L%d", first, last)
+}
+
+// SearchableContent returns content optimized for semantic search.
+// Combines signature, doc comment, and content in a structured format.
+func (c *Chunk) SearchableContent() string {
+ var parts []string
+
+ // Include signature for functions/methods
+ if c.Signature != "" {
+ parts = append(parts, c.Signature)
+ }
+
+ // Include doc comment
+ if c.DocComment != "" {
+ parts = append(parts, c.DocComment)
+ }
+
+ // Include actual content
+ if c.Content != "" {
+ parts = append(parts, c.Content)
+ }
+
+ return strings.Join(parts, "\n\n")
+}
+
+// Chunker is the interface for language-specific code chunkers.
+// An implementation handles one language and reads the file itself: Chunk
+// receives a path, not file contents.
+type Chunker interface {
+ // Chunk parses a source file and returns semantic code chunks.
+ // Returns an error if the file cannot be parsed or read.
+ // ctx is handed to the implementation; the Python and TypeScript
+ // chunkers forward it to the tree-sitter parser.
+ Chunk(ctx context.Context, filePath string) ([]Chunk, error)
+
+ // Language returns the language this chunker supports.
+ Language() Language
+
+ // SupportedExtensions returns file extensions this chunker handles.
+ // Extensions include the leading dot.
+ // Example: []string{".go"} for Go chunker
+ SupportedExtensions() []string
+}
+
+// ChunkOptions provides options for chunking behavior.
+//
+// NOTE(review): the Python and TypeScript chunkers only consult
+// IncludeDocComments and IncludePrivate; confirm that MaxChunkSize and
+// MinLines are enforced elsewhere (for example by the Manager).
+type ChunkOptions struct {
+ // MaxChunkSize is the maximum size of a chunk in bytes.
+ // Chunks larger than this will be split (respecting boundaries where possible).
+ // 0 means no limit.
+ MaxChunkSize int
+
+ // IncludeDocComments controls whether to include documentation comments.
+ // When false, Chunk.DocComment is left empty.
+ IncludeDocComments bool
+
+ // IncludePrivate controls whether to include private/unexported symbols.
+ // In the Python and TypeScript chunkers "private" means a name with a
+ // leading underscore.
+ IncludePrivate bool
+
+ // MinLines is the minimum number of lines for a chunk to be included.
+ // Chunks smaller than this will be skipped.
+ // 0 means no minimum.
+ MinLines int
+}
+
+// DefaultChunkOptions returns sensible default options.
+func DefaultChunkOptions() ChunkOptions {
+ return ChunkOptions{
+ MaxChunkSize: 8192, // ~8KB per chunk (well under token limit)
+ IncludeDocComments: true,
+ IncludePrivate: true, // Include all symbols for comprehensive search
+ MinLines: 0, // No minimum - include even single-line functions
+ }
+}
diff --git a/internal/chunking/typescript/chunker.go b/internal/chunking/typescript/chunker.go
new file mode 100644
index 0000000..44029cd
--- /dev/null
+++ b/internal/chunking/typescript/chunker.go
@@ -0,0 +1,403 @@
+// Package typescript provides AST-aware chunking for TypeScript and JavaScript source files using tree-sitter.
+package typescript
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+
+ sitter "github.com/smacker/go-tree-sitter"
+ "github.com/smacker/go-tree-sitter/typescript/typescript"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
+)
+
+// Chunker implements AST-aware chunking for TypeScript/JavaScript files.
+//
+// It pairs a tree-sitter parser (configured with the TypeScript grammar in
+// NewChunker) with the chunking options supplied by the caller. Within this
+// file only options.IncludePrivate and options.IncludeDocComments are read.
+//
+// NOTE(review): the single parser instance is reused by every Chunk call;
+// confirm callers never invoke Chunk concurrently on the same Chunker.
+type Chunker struct {
+ parser *sitter.Parser
+ options chunking.ChunkOptions
+}
+
+// NewChunker builds a TypeScript chunker: it allocates a fresh tree-sitter
+// parser, points it at the TypeScript grammar, and stores the given options
+// unchanged.
+func NewChunker(options chunking.ChunkOptions) *Chunker {
+	chunker := &Chunker{
+		parser:  sitter.NewParser(),
+		options: options,
+	}
+	chunker.parser.SetLanguage(typescript.GetLanguage())
+	return chunker
+}
+
+// Language returns the language this chunker supports.
+func (c *Chunker) Language() chunking.Language {
+ return chunking.LanguageTypeScript
+}
+
+// SupportedExtensions lists the TypeScript and JavaScript extensions handled
+// here; a new slice is returned on every call.
+//
+// NOTE(review): all four are parsed with the plain TypeScript grammar set in
+// NewChunker; confirm that JSX syntax in .tsx/.jsx files parses acceptably.
+func (c *Chunker) SupportedExtensions() []string {
+	extensions := []string{".ts", ".tsx", ".js", ".jsx"}
+	return extensions
+}
+
+// Chunk reads filePath from disk, parses it with the TypeScript grammar and
+// returns one chunk per function, method, class, interface and type alias
+// found in the syntax tree.
+//
+// The returned slice is never nil on success (it is empty when nothing was
+// extracted). Read and parse failures are returned wrapped with a short
+// prefix; ctx is forwarded to the parser.
+func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) {
+	raw, readErr := os.ReadFile(filePath)
+	if readErr != nil {
+		return nil, fmt.Errorf("read file: %w", readErr)
+	}
+
+	tree, parseErr := c.parser.ParseCtx(ctx, nil, raw)
+	if parseErr != nil {
+		return nil, fmt.Errorf("parse TypeScript file: %w", parseErr)
+	}
+	defer tree.Close()
+
+	// Lines are split once up front; the extractors slice them by row number.
+	lines := strings.Split(string(raw), "\n")
+	collected := []chunking.Chunk{}
+	c.walkNode(tree.RootNode(), raw, lines, filePath, "", &collected)
+
+	return collected, nil
+}
+
+// walkNode visits node and its descendants depth-first, appending a chunk to
+// *chunks for every declaration it can extract: functions, methods, named
+// arrow/function expressions, classes, interfaces and type aliases.
+//
+// parentName is the name of the enclosing class ("" at top level) and is
+// replaced only when descending into a class body. When extractClass returns
+// nil for a class, that class's subtree is not visited at all. Every other
+// node type is recursed into after any extraction.
+func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) {
+	// record appends an extracted chunk, ignoring nil results.
+	record := func(found *chunking.Chunk) {
+		if found != nil {
+			*chunks = append(*chunks, *found)
+		}
+	}
+	childCount := int(node.ChildCount())
+
+	switch node.Type() {
+	case "class_declaration":
+		classChunk := c.extractClass(node, source, sourceLines, filePath)
+		if classChunk != nil {
+			record(classChunk)
+
+			// Only the class body is descended into, with the class as parent.
+			for idx := 0; idx < childCount; idx++ {
+				if body := node.Child(idx); body.Type() == "class_body" {
+					c.walkNode(body, source, sourceLines, filePath, classChunk.Name, chunks)
+				}
+			}
+		}
+		return // the generic recursion below must not run for classes
+
+	case "function_declaration":
+		record(c.extractFunction(node, source, sourceLines, filePath, parentName))
+
+	case "method_definition":
+		record(c.extractMethod(node, source, sourceLines, filePath, parentName))
+
+	case "arrow_function", "function_expression":
+		// Only functions bound to a name yield a chunk; anonymous ones give nil.
+		record(c.extractFunctionExpression(node, source, sourceLines, filePath, parentName))
+
+	case "interface_declaration":
+		record(c.extractInterface(node, source, sourceLines, filePath))
+
+	case "type_alias_declaration":
+		record(c.extractTypeAlias(node, source, sourceLines, filePath))
+	}
+
+	// Recurse into every child, keeping the current parent.
+	for idx := 0; idx < childCount; idx++ {
+		c.walkNode(node.Child(idx), source, sourceLines, filePath, parentName, chunks)
+	}
+}
+
+// extractFunction turns a function_declaration node into a function chunk.
+// It returns nil when the node has no identifier child. Lines are 1-indexed
+// and inclusive; the preceding comment is attached as DocComment when
+// options.IncludeDocComments is set.
+func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
+	name := c.findChildContent(node, "identifier", source)
+	if name == "" {
+		return nil
+	}
+
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := int(node.EndPoint().Row) + 1
+
+	doc := ""
+	if c.options.IncludeDocComments {
+		doc = c.extractComment(node, source)
+	}
+
+	return &chunking.Chunk{
+		FilePath:   filePath,
+		Language:   chunking.LanguageTypeScript,
+		Type:       chunking.ChunkTypeFunction,
+		Name:       name,
+		ParentName: parentName,
+		StartLine:  firstLine,
+		EndLine:    lastLine,
+		Content:    c.extractLines(sourceLines, firstLine, lastLine),
+		Signature:  c.extractFunctionSignature(node, source, sourceLines),
+		DocComment: doc,
+	}
+}
+
+// extractMethod turns a method_definition node into a method chunk owned by
+// parentName.
+//
+// It returns nil when the node has no property_identifier child, or when the
+// name starts with an underscore and options.IncludePrivate is false.
+func (c *Chunker) extractMethod(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
+	name := c.findChildContent(node, "property_identifier", source)
+	if name == "" {
+		return nil
+	}
+
+	// An underscore prefix is treated as "private" by convention.
+	underscored := strings.HasPrefix(name, "_")
+	if underscored && !c.options.IncludePrivate {
+		return nil
+	}
+
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := int(node.EndPoint().Row) + 1
+
+	doc := ""
+	if c.options.IncludeDocComments {
+		doc = c.extractComment(node, source)
+	}
+
+	return &chunking.Chunk{
+		FilePath:   filePath,
+		Language:   chunking.LanguageTypeScript,
+		Type:       chunking.ChunkTypeMethod,
+		Name:       name,
+		ParentName: parentName,
+		StartLine:  firstLine,
+		EndLine:    lastLine,
+		Content:    c.extractLines(sourceLines, firstLine, lastLine),
+		Signature:  c.extractMethodSignature(node, source, sourceLines),
+		DocComment: doc,
+	}
+}
+
+// extractFunctionExpression extracts arrow functions and function expressions.
+func (c *Chunker) extractFunctionExpression(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk {
+ // Try to find the variable name from parent
+ parent := node.Parent()
+ if parent == nil {
+ return nil
+ }
+
+ var name string
+ if parent.Type() == "variable_declarator" {
+ name = c.findChildContent(parent, "identifier", source)
+ } else if parent.Type() == "assignment_expression" {
+ // Handle const foo = () => {}
+ for i := 0; i < int(parent.ChildCount()); i++ {
+ child := parent.Child(i)
+ if child.Type() == "identifier" || child.Type() == "member_expression" {
+ name = child.Content(source)
+ break
+ }
+ }
+ }
+
+ if name == "" {
+ return nil // Anonymous function, skip
+ }
+
+ startLine := int(node.StartPoint().Row) + 1
+ endLine := int(node.EndPoint().Row) + 1
+
+ chunk := &chunking.Chunk{
+ FilePath: filePath,
+ Language: chunking.LanguageTypeScript,
+ Type: chunking.ChunkTypeFunction,
+ Name: name,
+ ParentName: parentName,
+ StartLine: startLine,
+ EndLine: endLine,
+ Content: c.extractLines(sourceLines, startLine, endLine),
+ }
+
+ return chunk
+}
+
+// extractClass turns a class_declaration node into a class chunk spanning the
+// whole class. It returns nil when the node has no type_identifier child.
+// ParentName is left empty.
+func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
+	name := c.findChildContent(node, "type_identifier", source)
+	if name == "" {
+		return nil
+	}
+
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := int(node.EndPoint().Row) + 1
+
+	result := chunking.Chunk{
+		FilePath:  filePath,
+		Language:  chunking.LanguageTypeScript,
+		Type:      chunking.ChunkTypeClass,
+		Name:      name,
+		StartLine: firstLine,
+		EndLine:   lastLine,
+	}
+	result.Content = c.extractLines(sourceLines, firstLine, lastLine)
+	result.Signature = c.extractClassSignature(node, source, sourceLines)
+
+	// Attach the preceding comment when doc comments are requested.
+	if c.options.IncludeDocComments {
+		result.DocComment = c.extractComment(node, source)
+	}
+
+	return &result
+}
+
+// extractInterface turns an interface_declaration node into an interface
+// chunk. It returns nil when the node has no type_identifier child. No
+// Signature is produced; the preceding comment becomes DocComment when
+// options.IncludeDocComments is set.
+func (c *Chunker) extractInterface(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
+	name := c.findChildContent(node, "type_identifier", source)
+	if name == "" {
+		return nil
+	}
+
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := int(node.EndPoint().Row) + 1
+
+	doc := ""
+	if c.options.IncludeDocComments {
+		doc = c.extractComment(node, source)
+	}
+
+	return &chunking.Chunk{
+		FilePath:   filePath,
+		Language:   chunking.LanguageTypeScript,
+		Type:       chunking.ChunkTypeInterface,
+		Name:       name,
+		StartLine:  firstLine,
+		EndLine:    lastLine,
+		Content:    c.extractLines(sourceLines, firstLine, lastLine),
+		DocComment: doc,
+	}
+}
+
+// extractTypeAlias turns a type_alias_declaration node into a type chunk.
+// It returns nil when the node has no type_identifier child. Neither a
+// Signature nor a DocComment is produced for aliases.
+func (c *Chunker) extractTypeAlias(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk {
+	name := c.findChildContent(node, "type_identifier", source)
+	if name == "" {
+		return nil
+	}
+
+	firstLine := int(node.StartPoint().Row) + 1
+	lastLine := int(node.EndPoint().Row) + 1
+
+	return &chunking.Chunk{
+		FilePath:  filePath,
+		Language:  chunking.LanguageTypeScript,
+		Type:      chunking.ChunkTypeType,
+		Name:      name,
+		StartLine: firstLine,
+		EndLine:   lastLine,
+		Content:   c.extractLines(sourceLines, firstLine, lastLine),
+	}
+}
+
+// findChildContent returns the source text of node's first direct child whose
+// type equals childType, or "" when there is no such child. Only direct
+// children are examined.
+func (c *Chunker) findChildContent(node *sitter.Node, childType string, source []byte) string {
+	total := int(node.ChildCount())
+	for idx := 0; idx < total; idx++ {
+		candidate := node.Child(idx)
+		if candidate.Type() != childType {
+			continue
+		}
+		return candidate.Content(source)
+	}
+	return ""
+}
+
+// extractFunctionSignature returns the header of a function declaration: the
+// source text from the start of the node up to, but not including, the
+// opening brace of its body, with surrounding whitespace trimmed.
+//
+// The header is cut by byte offset rather than by whole lines. A line-based
+// cut that stops on the line before the body yields an empty string whenever
+// the brace shares a line with the declaration (`function f() {`), and loses
+// the last line of a multi-line parameter list.
+//
+// If the node has no statement_block child, the first line of the node is
+// returned instead.
+func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string {
+	for i := 0; i < int(node.ChildCount()); i++ {
+		child := node.Child(i)
+		if child.Type() != "statement_block" {
+			continue
+		}
+
+		from, to := node.StartByte(), child.StartByte()
+		if from <= to && int(to) <= len(source) {
+			return strings.TrimSpace(string(source[from:to]))
+		}
+		break // offsets do not fit source; use the fallback below
+	}
+
+	// Fallback: just return the first line.
+	startLine := int(node.StartPoint().Row) + 1
+	return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
+}
+
+// extractMethodSignature returns the header of a method definition: the
+// source text from the start of the node up to, but not including, the
+// opening brace of its body, with surrounding whitespace trimmed.
+//
+// The header is cut by byte offset rather than by whole lines. A line-based
+// cut that stops on the line before the body yields an empty string whenever
+// the brace shares a line with the method name (`run(): void {`), and loses
+// the last line of a multi-line parameter list.
+//
+// If the node has no statement_block child, the first line of the node is
+// returned instead.
+func (c *Chunker) extractMethodSignature(node *sitter.Node, source []byte, sourceLines []string) string {
+	for i := 0; i < int(node.ChildCount()); i++ {
+		child := node.Child(i)
+		if child.Type() != "statement_block" {
+			continue
+		}
+
+		from, to := node.StartByte(), child.StartByte()
+		if from <= to && int(to) <= len(source) {
+			return strings.TrimSpace(string(source[from:to]))
+		}
+		break // offsets do not fit source; use the fallback below
+	}
+
+	// Fallback: just return the first line.
+	startLine := int(node.StartPoint().Row) + 1
+	return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
+}
+
+// extractClassSignature returns the header of a class declaration: the source
+// text from the start of the node up to, but not including, the opening brace
+// of the class body, with surrounding whitespace trimmed.
+//
+// The header is cut by byte offset rather than by whole lines. A line-based
+// cut that stops on the line before the body yields an empty string whenever
+// the brace shares a line with the `class` keyword (`class A extends B {`).
+//
+// If the node has no class_body child, the first line of the node is returned
+// instead.
+func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string {
+	for i := 0; i < int(node.ChildCount()); i++ {
+		child := node.Child(i)
+		if child.Type() != "class_body" {
+			continue
+		}
+
+		from, to := node.StartByte(), child.StartByte()
+		if from <= to && int(to) <= len(source) {
+			return strings.TrimSpace(string(source[from:to]))
+		}
+		break // offsets do not fit source; use the fallback below
+	}
+
+	// Fallback: just return the first line.
+	startLine := int(node.StartPoint().Row) + 1
+	return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine))
+}
+
+// extractComment returns the comment that directly precedes a declaration,
+// with its markers removed, or "" when there is none.
+//
+// The node's previous sibling is checked first. For an exported declaration
+// (`export function f() {}`) the declaration is wrapped in an
+// export_statement and its previous sibling is the `export` keyword, so the
+// comment written above the statement is looked up on the wrapper instead.
+//
+// Only the outer markers (`/**`, `/*`, `*/`, `//`) are stripped; the leading
+// `*` on the inner lines of a multi-line block comment is kept.
+func (c *Chunker) extractComment(node *sitter.Node, source []byte) string {
+	prev := node.PrevSibling()
+	if prev == nil || prev.Type() != "comment" {
+		// Fall back to the comment placed before an enclosing `export`.
+		parent := node.Parent()
+		if parent == nil || parent.Type() != "export_statement" {
+			return ""
+		}
+		prev = parent.PrevSibling()
+		if prev == nil || prev.Type() != "comment" {
+			return ""
+		}
+	}
+
+	comment := prev.Content(source)
+	// Remove comment markers.
+	comment = strings.TrimPrefix(comment, "/**")
+	comment = strings.TrimPrefix(comment, "/*")
+	comment = strings.TrimSuffix(comment, "*/")
+	comment = strings.TrimPrefix(comment, "//")
+	return strings.TrimSpace(comment)
+}
+
+// extractLines returns lines start through end (1-indexed, inclusive) joined
+// with "\n". It returns "" when start is below 1, end is before start, or
+// start lies past the last line; an end past the last line is clamped.
+func (c *Chunker) extractLines(lines []string, start, end int) string {
+	lineCount := len(lines)
+	if !(start >= 1 && end >= start && start <= lineCount) {
+		return ""
+	}
+
+	stop := lineCount
+	if end <= lineCount {
+		stop = end
+	}
+
+	return strings.Join(lines[start-1:stop], "\n")
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 13f952e..d2f8027 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -36,41 +36,38 @@ var CriticalConcepts = []string{
}
// Config holds the application configuration.
+// Field order optimized for memory alignment (fieldalignment).
type Config struct {
- // Worker settings
- WorkerPort int `json:"worker_port"`
-
- // Database settings
- DBPath string `json:"db_path"`
- MaxConns int `json:"max_conns"`
-
- // SDK Agent settings
- Model string `json:"model"`
- ClaudeCodePath string `json:"claude_code_path"`
-
- // Embedding settings
- EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
-
- // Reranking settings (cross-encoder)
- RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking
- RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100)
- RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10)
- RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7)
- RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank)
- RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false)
-
- // Context injection settings
+ ContextFullField string `json:"context_full_field"`
+ DBPath string `json:"db_path"`
+ Model string `json:"model"`
+ ClaudeCodePath string `json:"claude_code_path"`
+ EmbeddingModel string `json:"embedding_model"`
+ VectorStorageStrategy string `json:"vector_storage_strategy"`
+ ContextObsConcepts []string `json:"context_obs_concepts"`
+ ContextObsTypes []string `json:"context_obs_types"`
+ ContextMaxPromptResults int `json:"context_max_prompt_results"`
+ RerankingResults int `json:"reranking_results"`
+ GraphEdgeWeight float64 `json:"graph_edge_weight"`
+ ContextRelevanceThreshold float64 `json:"context_relevance_threshold"`
+ RerankingCandidates int `json:"reranking_candidates"`
+ WorkerPort int `json:"worker_port"`
+ RerankingMinImprovement float64 `json:"reranking_min_improvement"`
ContextObservations int `json:"context_observations"`
ContextFullCount int `json:"context_full_count"`
ContextSessionCount int `json:"context_session_count"`
- ContextShowReadTokens bool `json:"context_show_read_tokens"`
- ContextShowWorkTokens bool `json:"context_show_work_tokens"`
- ContextFullField string `json:"context_full_field"`
+ MaxConns int `json:"max_conns"`
+ RerankingAlpha float64 `json:"reranking_alpha"`
+ GraphMaxHops int `json:"graph_max_hops"`
+ GraphBranchFactor int `json:"graph_branch_factor"`
+ GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"`
+ HubThreshold int `json:"hub_threshold"`
ContextShowLastSummary bool `json:"context_show_last_summary"`
- ContextObsTypes []string `json:"context_obs_types"`
- ContextObsConcepts []string `json:"context_obs_concepts"`
- ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
- ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
+ RerankingEnabled bool `json:"reranking_enabled"`
+ ContextShowWorkTokens bool `json:"context_show_work_tokens"`
+ ContextShowReadTokens bool `json:"context_show_read_tokens"`
+ RerankingPureMode bool `json:"reranking_pure_mode"`
+ GraphEnabled bool `json:"graph_enabled"`
}
var (
@@ -143,11 +140,18 @@ func Default() *Config {
MaxConns: 4,
Model: DefaultModel,
EmbeddingModel: DefaultEmbeddingModel,
- RerankingEnabled: true, // Enable by default for improved relevance
- RerankingCandidates: 100, // Retrieve top 100 candidates
- RerankingResults: 10, // Return top 10 after reranking
- RerankingAlpha: 0.7, // Favor cross-encoder score
- RerankingMinImprovement: 0, // Always apply reranking
+ RerankingEnabled: true, // Enable by default for improved relevance
+ RerankingCandidates: 100, // Retrieve top 100 candidates
+ RerankingResults: 10, // Return top 10 after reranking
+ RerankingAlpha: 0.7, // Favor cross-encoder score
+ RerankingMinImprovement: 0, // Always apply reranking
+ GraphEnabled: true, // Enable graph-aware search by default
+ GraphMaxHops: 2, // Two-hop traversal
+ GraphBranchFactor: 5, // Expand top 5 neighbors per node
+ GraphEdgeWeight: 0.3, // Minimum edge weight to follow
+ GraphRebuildIntervalMin: 60, // Rebuild graph every 60 minutes
+ VectorStorageStrategy: "hub", // Hub storage strategy (LEANN-inspired)
+ HubThreshold: 5, // Require 5+ accesses to store embedding
ContextObservations: 100,
ContextFullCount: 25,
ContextSessionCount: 10,
@@ -233,6 +237,29 @@ func Load() (*Config, error) {
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 {
cfg.ContextMaxPromptResults = int(v)
}
+ // Graph settings
+ if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_ENABLED"].(bool); ok {
+ cfg.GraphEnabled = v
+ }
+ if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_MAX_HOPS"].(float64); ok && v > 0 {
+ cfg.GraphMaxHops = int(v)
+ }
+ if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_BRANCH_FACTOR"].(float64); ok && v > 0 {
+ cfg.GraphBranchFactor = int(v)
+ }
+ if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_EDGE_WEIGHT"].(float64); ok && v >= 0 && v <= 1 {
+ cfg.GraphEdgeWeight = v
+ }
+ if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_REBUILD_INTERVAL_MIN"].(float64); ok && v > 0 {
+ cfg.GraphRebuildIntervalMin = int(v)
+ }
+ // Vector storage settings (LEANN Phase 2)
+ if v, ok := settings["CLAUDE_MNEMONIC_VECTOR_STORAGE_STRATEGY"].(string); ok && v != "" {
+ cfg.VectorStorageStrategy = v
+ }
+ if v, ok := settings["CLAUDE_MNEMONIC_HUB_THRESHOLD"].(float64); ok && v > 0 {
+ cfg.HubThreshold = int(v)
+ }
return cfg, nil
}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index bc8bc4e..02dfd0c 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -121,8 +121,8 @@ func (s *ConfigSuite) TestLoad_TableDriven() {
tests := []struct {
name string
settingsJSON string
- expectedPort int
expectedModel string
+ expectedPort int
expectedObsObs int
}{
{
@@ -183,12 +183,12 @@ func (s *ConfigSuite) TestLoad_TableDriven() {
s.Require().NoError(err)
if tt.settingsJSON != "" {
- err := os.WriteFile(
+ writeErr := os.WriteFile(
filepath.Join(tempDir, ".claude-mnemonic", "settings.json"),
[]byte(tt.settingsJSON),
0600,
)
- s.Require().NoError(err)
+ s.Require().NoError(writeErr)
}
cfg, err := Load()
diff --git a/internal/db/gorm/conflict_store.go b/internal/db/gorm/conflict_store.go
index cae45b5..6535f7a 100644
--- a/internal/db/gorm/conflict_store.go
+++ b/internal/db/gorm/conflict_store.go
@@ -214,9 +214,9 @@ func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, proje
// GetConflictsWithDetails retrieves all conflicts with observation titles for display.
func (s *ConflictStore) GetConflictsWithDetails(ctx context.Context, project string, limit int) ([]*ConflictWithDetails, error) {
var results []struct {
- ObservationConflict
NewerTitle sql.NullString `gorm:"column:newer_title"`
OlderTitle sql.NullString `gorm:"column:older_title"`
+ ObservationConflict
}
err := s.db.WithContext(ctx).
diff --git a/internal/db/gorm/models.go b/internal/db/gorm/models.go
index e183a7a..f331561 100644
--- a/internal/db/gorm/models.go
+++ b/internal/db/gorm/models.go
@@ -17,18 +17,18 @@ import (
// SDKSession represents a Claude Code session.
type SDKSession struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
ClaudeSessionID string `gorm:"uniqueIndex;not null"`
- SDKSessionID sql.NullString `gorm:"uniqueIndex"`
Project string `gorm:"index;not null"`
+ Status string `gorm:"type:text;check:status IN ('active', 'completed', 'failed');default:'active';index"`
+ StartedAt string `gorm:"not null"`
+ SDKSessionID sql.NullString `gorm:"uniqueIndex"`
UserPrompt sql.NullString
- WorkerPort sql.NullInt64
- PromptCounter int `gorm:"default:0"`
- Status string `gorm:"type:text;check:status IN ('active', 'completed', 'failed');default:'active';index"`
- StartedAt string `gorm:"not null"`
- StartedAtEpoch int64 `gorm:"index:idx_sessions_started,sort:desc;not null"`
CompletedAt sql.NullString
+ WorkerPort sql.NullInt64
CompletedAtEpoch sql.NullInt64
+ ID int64 `gorm:"primaryKey;autoIncrement"`
+ PromptCounter int `gorm:"default:0"`
+ StartedAtEpoch int64 `gorm:"index:idx_sessions_started,sort:desc;not null"`
}
func (SDKSession) TableName() string { return "sdk_sessions" }
@@ -46,34 +46,28 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
// Observation represents a stored observation (learning).
type Observation struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
- SDKSessionID string `gorm:"index;not null"`
- Project string `gorm:"index;not null"`
- Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
- Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
-
- // Content fields
- Title sql.NullString `gorm:"type:text"`
- Subtitle sql.NullString `gorm:"type:text"`
- Facts models.JSONStringArray `gorm:"type:text"` // JSON array
- Narrative sql.NullString `gorm:"type:text"`
- Concepts models.JSONStringArray `gorm:"type:text"` // JSON array
- FilesRead models.JSONStringArray `gorm:"type:text"` // JSON array
- FilesModified models.JSONStringArray `gorm:"type:text"` // JSON array
- FileMtimes models.JSONInt64Map `gorm:"type:text"` // JSON object
-
- // Metadata
+ FileMtimes models.JSONInt64Map `gorm:"type:text"`
+ SDKSessionID string `gorm:"index;not null"`
+ Project string `gorm:"index;not null"`
+ Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
+ Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
+ CreatedAt string `gorm:"not null"`
+ Title sql.NullString `gorm:"type:text"`
+ Narrative sql.NullString `gorm:"type:text"`
+ Concepts models.JSONStringArray `gorm:"type:text"`
+ FilesRead models.JSONStringArray `gorm:"type:text"`
+ FilesModified models.JSONStringArray `gorm:"type:text"`
+ Subtitle sql.NullString `gorm:"type:text"`
+ Facts models.JSONStringArray `gorm:"type:text"`
+ LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
PromptNumber sql.NullInt64
- DiscoveryTokens int64 `gorm:"default:0"`
- CreatedAt string `gorm:"not null"`
- CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
-
- // Importance scoring fields
+ ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
+ ID int64 `gorm:"primaryKey;autoIncrement"`
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
UserFeedback int `gorm:"default:0"`
RetrievalCount int `gorm:"default:0"`
- LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
- ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
+ CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
+ DiscoveryTokens int64 `gorm:"default:0"`
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
}
@@ -95,23 +89,19 @@ func (o *Observation) BeforeCreate(tx *gorm.DB) error {
// SessionSummary represents a session summary.
type SessionSummary struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
- SDKSessionID string `gorm:"index;not null"`
- Project string `gorm:"index;not null"`
-
- // Summary fields (nullable TEXT)
- Request sql.NullString
- Investigated sql.NullString
- Learned sql.NullString
- Completed sql.NullString
- NextSteps sql.NullString `gorm:"column:next_steps"`
- Notes sql.NullString
-
- // Metadata
- PromptNumber sql.NullInt64
- DiscoveryTokens int64 `gorm:"default:0"`
CreatedAt string `gorm:"not null"`
- CreatedAtEpoch int64 `gorm:"index:idx_summaries_created,sort:desc;not null"`
+ SDKSessionID string `gorm:"index;not null"`
+ Project string `gorm:"index;not null"`
+ Completed sql.NullString
+ Investigated sql.NullString
+ Learned sql.NullString
+ NextSteps sql.NullString `gorm:"column:next_steps"`
+ Notes sql.NullString
+ Request sql.NullString
+ PromptNumber sql.NullInt64
+ ID int64 `gorm:"primaryKey;autoIncrement"`
+ DiscoveryTokens int64 `gorm:"default:0"`
+ CreatedAtEpoch int64 `gorm:"index:idx_summaries_created,sort:desc;not null"`
}
func (SessionSummary) TableName() string { return "session_summaries" }
@@ -129,12 +119,12 @@ func (s *SessionSummary) BeforeCreate(tx *gorm.DB) error {
// UserPrompt represents a user prompt.
type UserPrompt struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
ClaudeSessionID string `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:1"`
- PromptNumber int `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:2"`
PromptText string `gorm:"type:text;not null"`
- MatchedObservations int `gorm:"default:0"`
CreatedAt string `gorm:"not null"`
+ ID int64 `gorm:"primaryKey;autoIncrement"`
+ PromptNumber int `gorm:"index;not null;uniqueIndex:idx_user_prompts_session_number_unique,priority:2"`
+ MatchedObservations int `gorm:"default:0"`
CreatedAtEpoch int64 `gorm:"index:idx_prompts_created,sort:desc;not null"`
}
@@ -153,16 +143,16 @@ func (p *UserPrompt) BeforeCreate(tx *gorm.DB) error {
// ObservationConflict tracks conflicts between observations.
type ObservationConflict struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
- NewerObsID int64 `gorm:"index:idx_conflicts_newer;not null"`
- OlderObsID int64 `gorm:"index:idx_conflicts_older;not null"`
ConflictType models.ConflictType `gorm:"type:text;check:conflict_type IN ('superseded', 'contradicts', 'outdated_pattern');not null"`
Resolution models.ConflictResolution `gorm:"type:text;check:resolution IN ('prefer_newer', 'prefer_older', 'manual');not null"`
- Reason sql.NullString `gorm:"type:text"`
DetectedAt string `gorm:"not null"`
- DetectedAtEpoch int64 `gorm:"index:idx_conflicts_unresolved,priority:2,sort:desc;not null"`
- Resolved int `gorm:"default:0;index:idx_conflicts_unresolved,priority:1"`
+ Reason sql.NullString `gorm:"type:text"`
ResolvedAt sql.NullString
+ ID int64 `gorm:"primaryKey;autoIncrement"`
+ NewerObsID int64 `gorm:"index:idx_conflicts_newer;not null"`
+ OlderObsID int64 `gorm:"index:idx_conflicts_older;not null"`
+ DetectedAtEpoch int64 `gorm:"index:idx_conflicts_unresolved,priority:2,sort:desc;not null"`
+ Resolved int `gorm:"default:0;index:idx_conflicts_unresolved,priority:1"`
}
func (ObservationConflict) TableName() string { return "observation_conflicts" }
@@ -180,14 +170,14 @@ func (c *ObservationConflict) BeforeCreate(tx *gorm.DB) error {
// ObservationRelation tracks relationships between observations.
type ObservationRelation struct {
+ RelationType models.RelationType `gorm:"type:text;check:relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from');index:idx_relations_type;uniqueIndex:idx_relations_unique,priority:3;not null"`
+ DetectionSource models.RelationDetectionSource `gorm:"type:text;check:detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression');not null"`
+ CreatedAt string `gorm:"not null"`
+ Reason sql.NullString `gorm:"type:text"`
ID int64 `gorm:"primaryKey;autoIncrement"`
SourceID int64 `gorm:"index:idx_relations_source;index:idx_relations_both,priority:1;uniqueIndex:idx_relations_unique,priority:1;not null"`
TargetID int64 `gorm:"index:idx_relations_target;index:idx_relations_both,priority:2;uniqueIndex:idx_relations_unique,priority:2;not null"`
- RelationType models.RelationType `gorm:"type:text;check:relation_type IN ('causes', 'fixes', 'supersedes', 'depends_on', 'relates_to', 'evolves_from');index:idx_relations_type;uniqueIndex:idx_relations_unique,priority:3;not null"`
Confidence float64 `gorm:"type:real;default:0.5;index:idx_relations_confidence,sort:desc;not null"`
- DetectionSource models.RelationDetectionSource `gorm:"type:text;check:detection_source IN ('file_overlap', 'embedding_similarity', 'temporal_proximity', 'narrative_mention', 'concept_overlap', 'type_progression');not null"`
- Reason sql.NullString `gorm:"type:text"`
- CreatedAt string `gorm:"not null"`
CreatedAtEpoch int64 `gorm:"not null"`
}
@@ -209,21 +199,21 @@ func (r *ObservationRelation) BeforeCreate(tx *gorm.DB) error {
// Pattern represents a detected recurring pattern.
type Pattern struct {
- ID int64 `gorm:"primaryKey;autoIncrement"`
+ Status models.PatternStatus `gorm:"type:text;default:'active';check:status IN ('active', 'deprecated', 'merged');index"`
Name string `gorm:"type:text;not null"`
Type models.PatternType `gorm:"type:text;check:type IN ('bug', 'refactor', 'architecture', 'anti-pattern', 'best-practice');index;not null"`
- Description sql.NullString `gorm:"type:text"`
- Signature models.JSONStringArray `gorm:"type:text"` // JSON array of keywords
+ CreatedAt string `gorm:"not null"`
+ LastSeenAt string `gorm:"not null"`
+ Signature models.JSONStringArray `gorm:"type:text"`
+ Projects models.JSONStringArray `gorm:"type:text"`
+ ObservationIDs models.JSONInt64Array `gorm:"type:text"`
Recommendation sql.NullString `gorm:"type:text"`
- Frequency int `gorm:"default:1;index:idx_patterns_frequency,sort:desc"`
- Projects models.JSONStringArray `gorm:"type:text"` // JSON array
- ObservationIDs models.JSONInt64Array `gorm:"type:text"` // JSON array
- Status models.PatternStatus `gorm:"type:text;default:'active';check:status IN ('active', 'deprecated', 'merged');index"`
+ Description sql.NullString `gorm:"type:text"`
MergedIntoID sql.NullInt64
+ Frequency int `gorm:"default:1;index:idx_patterns_frequency,sort:desc"`
Confidence float64 `gorm:"type:real;default:0.5;index:idx_patterns_confidence,sort:desc"`
- LastSeenAt string `gorm:"not null"`
+ ID int64 `gorm:"primaryKey;autoIncrement"`
LastSeenAtEpoch int64 `gorm:"index:idx_patterns_last_seen,sort:desc;not null"`
- CreatedAt string `gorm:"not null"`
CreatedAtEpoch int64 `gorm:"not null"`
}
@@ -256,8 +246,8 @@ func (p *Pattern) BeforeCreate(tx *gorm.DB) error {
// ConceptWeight stores configurable weights for importance scoring.
type ConceptWeight struct {
Concept string `gorm:"primaryKey;type:text"`
- Weight float64 `gorm:"type:real;not null;default:0.1"`
UpdatedAt string `gorm:"not null"`
+ Weight float64 `gorm:"type:real;not null;default:0.1"`
}
func (ConceptWeight) TableName() string { return "concept_weights" }
diff --git a/internal/db/gorm/prompt_store.go b/internal/db/gorm/prompt_store.go
index c377044..86519cf 100644
--- a/internal/db/gorm/prompt_store.go
+++ b/internal/db/gorm/prompt_store.go
@@ -145,9 +145,9 @@ func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy
}
var results []struct {
- UserPrompt
Project sql.NullString `gorm:"column:project"`
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
+ UserPrompt
}
query := s.db.WithContext(ctx).
@@ -184,9 +184,9 @@ func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy
// GetAllRecentUserPrompts retrieves recent user prompts across all projects.
func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error) {
var results []struct {
- UserPrompt
Project sql.NullString `gorm:"column:project"`
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
+ UserPrompt
}
query := s.db.WithContext(ctx).
@@ -211,9 +211,9 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
// GetAllPrompts retrieves all user prompts (for vector rebuild).
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
var results []struct {
- UserPrompt
Project sql.NullString `gorm:"column:project"`
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
+ UserPrompt
}
query := s.db.WithContext(ctx).
@@ -256,9 +256,9 @@ func (s *PromptStore) FindRecentPromptByText(ctx context.Context, claudeSessionI
// GetRecentUserPromptsByProject retrieves recent user prompts for a specific project.
func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error) {
var results []struct {
- UserPrompt
Project sql.NullString `gorm:"column:project"`
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
+ UserPrompt
}
query := s.db.WithContext(ctx).
@@ -283,9 +283,9 @@ func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project
// toModelUserPromptsWithSession converts query results to pkg/models.UserPromptWithSession.
func toModelUserPromptsWithSession(results []struct {
- UserPrompt
Project sql.NullString `gorm:"column:project"`
SDKSessionID sql.NullString `gorm:"column:sdk_session_id"`
+ UserPrompt
}) []*models.UserPromptWithSession {
prompts := make([]*models.UserPromptWithSession, len(results))
for i, r := range results {
diff --git a/internal/db/gorm/relation_store.go b/internal/db/gorm/relation_store.go
index 0f084bf..acef0b3 100644
--- a/internal/db/gorm/relation_store.go
+++ b/internal/db/gorm/relation_store.go
@@ -171,11 +171,11 @@ func (s *RelationStore) GetRelationsByType(ctx context.Context, relationType mod
// GetRelationsWithDetails retrieves relations with observation titles for display.
func (s *RelationStore) GetRelationsWithDetails(ctx context.Context, obsID int64) ([]*models.RelationWithDetails, error) {
var results []struct {
- ObservationRelation
- SourceTitle sql.NullString `gorm:"column:source_title"`
- TargetTitle sql.NullString `gorm:"column:target_title"`
SourceType string `gorm:"column:source_type"`
TargetType string `gorm:"column:target_type"`
+ SourceTitle sql.NullString `gorm:"column:source_title"`
+ TargetTitle sql.NullString `gorm:"column:target_title"`
+ ObservationRelation
}
err := s.db.WithContext(ctx).
diff --git a/internal/db/gorm/store.go b/internal/db/gorm/store.go
index aab5203..2e2dfa6 100644
--- a/internal/db/gorm/store.go
+++ b/internal/db/gorm/store.go
@@ -88,6 +88,11 @@ func NewStore(cfg Config) (*Store, error) {
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
return nil, fmt.Errorf("set synchronous mode: %w", err)
}
+ // Set busy timeout to 5 seconds to handle concurrent writes
+ // This allows SQLite to retry when database is locked instead of failing immediately
+ if _, err := sqlDB.Exec("PRAGMA busy_timeout=5000"); err != nil {
+ return nil, fmt.Errorf("set busy timeout: %w", err)
+ }
return store, nil
}
diff --git a/internal/embedding/model.go b/internal/embedding/model.go
index 3a40838..2c1124d 100644
--- a/internal/embedding/model.go
+++ b/internal/embedding/model.go
@@ -21,15 +21,10 @@ const (
// ONNXConfig describes ONNX-specific model configuration.
// This allows different models to specify their tensor names and pooling needs.
type ONNXConfig struct {
- // InputNames are the ONNX input tensor names in order.
- InputNames []string
- // OutputNames are the ONNX output tensor names.
+ Pooling PoolingStrategy
+ InputNames []string
OutputNames []string
- // Pooling specifies how to convert token embeddings to sentence embeddings.
- // If PoolingNone, the model outputs sentence embeddings directly.
- Pooling PoolingStrategy
- // HiddenSize is the embedding dimension (used for pooling calculations).
- HiddenSize int
+ HiddenSize int
}
// EmbeddingModel represents a text embedding model.
@@ -62,11 +57,11 @@ type ONNXConfigurer interface {
// ModelMetadata describes an embedding model for UI/config.
type ModelMetadata struct {
- Name string `json:"name"` // Human-readable name
- Version string `json:"version"` // Short ID for DB storage
- Dimensions int `json:"dimensions"` // Vector size
- Description string `json:"description"` // Brief description
- Default bool `json:"default"` // Is this the default model?
+ Name string `json:"name"`
+ Version string `json:"version"`
+ Description string `json:"description"`
+ Dimensions int `json:"dimensions"`
+ Default bool `json:"default"`
}
// ModelFactory creates a new instance of an embedding model.
@@ -74,10 +69,10 @@ type ModelFactory func() (EmbeddingModel, error)
// ModelRegistry provides model lookup by version.
type ModelRegistry struct {
- mu sync.RWMutex
models map[string]ModelFactory
metadata map[string]ModelMetadata
defaultModel string
+ mu sync.RWMutex
}
// NewModelRegistry creates a new model registry.
diff --git a/internal/embedding/service.go b/internal/embedding/service.go
index f6be984..94fd51c 100644
--- a/internal/embedding/service.go
+++ b/internal/embedding/service.go
@@ -46,9 +46,9 @@ var bgeONNXConfig = ONNXConfig{
type bgeModel struct {
tk *tokenizer.Tokenizer
session *ort.DynamicAdvancedSession
+ libDir string
+ config ONNXConfig
mu sync.Mutex
- libDir string // temp directory containing extracted libraries
- config ONNXConfig // ONNX configuration for this model
}
// Compile-time check that bgeModel implements EmbeddingModel
diff --git a/internal/graph/edge_detector.go b/internal/graph/edge_detector.go
new file mode 100644
index 0000000..0770010
--- /dev/null
+++ b/internal/graph/edge_detector.go
@@ -0,0 +1,417 @@
+package graph
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+ "github.com/rs/zerolog/log"
+)
+
+const (
+ // SemanticSimilarityThreshold is the minimum cosine similarity at which
+ // DetectSemanticEdges links two observations; pairs scoring below it are
+ // skipped, pairs at or above it get a semantic edge.
+ SemanticSimilarityThreshold = 0.85
+
+ // MinFileOverlapForEdge is the minimum file-set overlap ratio (as returned
+ // by calculateFileOverlap) required for detectFileOverlapEdges to emit an
+ // edge; pairs below it are skipped.
+ MinFileOverlapForEdge = 0.3
+
+ // MaxEdgesPerNode is the per-node edge cap that DetectEdges passes to
+ // pruneEdges to keep the graph from growing too dense.
+ MaxEdgesPerNode = 20
+)
+
+// DetectEdges derives relationship edges for a batch of observations.
+//
+// Three detectors run in a fixed order - same-session (temporal), shared
+// concept, and file overlap - and their concatenated output is passed through
+// pruneEdges with MaxEdgesPerNode. Batches with fewer than two observations
+// yield (nil, nil). The returned error is always nil, and ctx is not consulted
+// by the current implementation.
+func DetectEdges(ctx context.Context, observations []*models.Observation) ([]Edge, error) {
+	if len(observations) < 2 {
+		return nil, nil
+	}
+
+	// Index the batch once so each detector starts from a ready-made grouping.
+	bySession := buildSessionMap(observations)
+	byConcept := buildConceptMap(observations)
+	byFile := buildFileMap(observations)
+
+	log.Info().
+		Int("observations", len(observations)).
+		Int("sessions", len(bySession)).
+		Int("concepts", len(byConcept)).
+		Msg("Starting edge detection")
+
+	temporal := detectTemporalEdges(bySession)
+	shared := detectConceptEdges(byConcept)
+	overlapping := detectFileOverlapEdges(byFile, observations)
+
+	// Concatenation order matters to pruneEdges, which keeps the first edge it
+	// meets for any given pair: temporal, then concept, then file overlap.
+	combined := make([]Edge, 0, len(temporal)+len(shared)+len(overlapping))
+	combined = append(combined, temporal...)
+	combined = append(combined, shared...)
+	combined = append(combined, overlapping...)
+
+	result := pruneEdges(combined, MaxEdgesPerNode)
+
+	log.Info().
+		Int("temporal_edges", len(temporal)).
+		Int("concept_edges", len(shared)).
+		Int("file_edges", len(overlapping)).
+		Int("total_edges", len(result)).
+		Msg("Edge detection complete")
+
+	return result, nil
+}
+
+// buildSessionMap indexes observation IDs by their SDK session ID, preserving
+// input order within each session. Observations with an empty SDKSessionID are
+// left out.
+func buildSessionMap(observations []*models.Observation) map[string][]int64 {
+	grouped := make(map[string][]int64)
+
+	for _, o := range observations {
+		sid := o.SDKSessionID
+		if sid == "" {
+			continue
+		}
+		grouped[sid] = append(grouped[sid], o.ID)
+	}
+
+	return grouped
+}
+
+// buildConceptMap indexes observation IDs by concept tag. An observation is
+// appended once for every entry in its Concepts list, in input order.
+func buildConceptMap(observations []*models.Observation) map[string][]int64 {
+	grouped := make(map[string][]int64)
+
+	for _, o := range observations {
+		id := o.ID
+		for _, tag := range o.Concepts {
+			grouped[tag] = append(grouped[tag], id)
+		}
+	}
+
+	return grouped
+}
+
+// buildFileMap maps each file path to the IDs of the observations that
+// reference it, drawing on both FilesRead and FilesModified.
+//
+// An observation is registered at most once per file, even when the same path
+// appears in both lists (or several times within one list). Without that,
+// the observation would show up twice under the file and be paired with
+// itself by the file-overlap detector.
+func buildFileMap(observations []*models.Observation) map[string][]int64 {
+	fileMap := make(map[string][]int64)
+
+	for _, obs := range observations {
+		// Paths already recorded for this observation.
+		registered := make(map[string]struct{}, len(obs.FilesRead)+len(obs.FilesModified))
+
+		for _, file := range obs.FilesRead {
+			if _, dup := registered[file]; dup {
+				continue
+			}
+			registered[file] = struct{}{}
+			fileMap[file] = append(fileMap[file], obs.ID)
+		}
+		for _, file := range obs.FilesModified {
+			if _, dup := registered[file]; dup {
+				continue
+			}
+			registered[file] = struct{}{}
+			fileMap[file] = append(fileMap[file], obs.ID)
+		}
+	}
+
+	return fileMap
+}
+
+// detectTemporalEdges chains the observations of each session: every ID is
+// linked to the one that follows it in the session's list with a
+// RelationTemporal edge of weight 0.8. Sessions with fewer than two
+// observations contribute nothing.
+func detectTemporalEdges(sessionMap map[string][]int64) []Edge {
+	result := make([]Edge, 0)
+
+	for _, ids := range sessionMap {
+		// Starting at 1 pairs each ID with its predecessor; lists shorter
+		// than two never enter the loop.
+		for pos := 1; pos < len(ids); pos++ {
+			result = append(result, Edge{
+				FromID:   ids[pos-1],
+				ToID:     ids[pos],
+				Relation: RelationTemporal,
+				Weight:   0.8, // High weight for temporal proximity
+			})
+		}
+	}
+
+	return result
+}
+
+// detectConceptEdges creates one RelationConcept edge for every pair of
+// distinct observations that share at least one concept.
+//
+// A concept's weight grows with its length (longer = more specific), from 0.5
+// up to a cap of 0.8 at 20 characters. When a pair shares several concepts the
+// edge carries the highest of their weights, so the result does not depend on
+// Go's randomized map iteration order. Pairs whose two IDs are equal (an ID
+// repeated inside one concept's list) are skipped so no self-loops are emitted.
+func detectConceptEdges(conceptMap map[string][]int64) []Edge {
+	edges := make([]Edge, 0)
+	// emitted maps a pair key to the index of its edge in edges.
+	emitted := make(map[string]int)
+
+	for concept, obsIDs := range conceptMap {
+		if len(obsIDs) < 2 {
+			continue
+		}
+
+		// Weight based on concept specificity (longer = more specific).
+		// It is the same for every pair under this concept, so compute it once.
+		weight := float32(0.5 + 0.3*math.Min(1.0, float64(len(concept))/20.0))
+
+		// Create edges between all observations sharing this concept
+		for i := 0; i < len(obsIDs); i++ {
+			for j := i + 1; j < len(obsIDs); j++ {
+				if obsIDs[i] == obsIDs[j] {
+					// Same observation listed twice: not a relationship.
+					continue
+				}
+
+				// Use sorted pair as key to avoid duplicates
+				pairKey := edgeKey(obsIDs[i], obsIDs[j])
+				if idx, ok := emitted[pairKey]; ok {
+					// Pair already linked via another concept: keep the
+					// strongest weight seen.
+					if weight > edges[idx].Weight {
+						edges[idx].Weight = weight
+					}
+					continue
+				}
+				emitted[pairKey] = len(edges)
+
+				edges = append(edges, Edge{
+					FromID:   obsIDs[i],
+					ToID:     obsIDs[j],
+					Relation: RelationConcept,
+					Weight:   weight,
+				})
+			}
+		}
+	}
+
+	return edges
+}
+
+// detectFileOverlapEdges creates RelationFileOverlap edges between distinct
+// observations that reference at least one common file and whose combined
+// file sets (FilesRead plus FilesModified) overlap by at least
+// MinFileOverlapForEdge. The edge weight is the overlap ratio returned by
+// calculateFileOverlap.
+//
+// fileMap supplies the candidate pairs (observations listed under the same
+// file); observations supplies the file lists used for scoring. IDs present in
+// fileMap but absent from observations are ignored. A pair whose two IDs are
+// equal is skipped: an observation trivially overlaps itself with ratio 1.0
+// and would otherwise yield a self-loop.
+func detectFileOverlapEdges(fileMap map[string][]int64, observations []*models.Observation) []Edge {
+	edges := make([]Edge, 0)
+	seen := make(map[string]bool)
+
+	// Merge FilesRead and FilesModified once per observation instead of once
+	// per candidate pair.
+	filesByObs := make(map[int64][]string, len(observations))
+	for _, obs := range observations {
+		merged := make([]string, 0, len(obs.FilesRead)+len(obs.FilesModified))
+		merged = append(merged, obs.FilesRead...)
+		merged = append(merged, obs.FilesModified...)
+		filesByObs[obs.ID] = merged
+	}
+
+	for _, obsIDs := range fileMap {
+		if len(obsIDs) < 2 {
+			continue
+		}
+
+		// Create edges between observations referencing same files
+		for i := 0; i < len(obsIDs); i++ {
+			for j := i + 1; j < len(obsIDs); j++ {
+				fromID, toID := obsIDs[i], obsIDs[j]
+				if fromID == toID {
+					// Same observation listed twice under this file.
+					continue
+				}
+
+				// A pair is scored once, no matter how many files it shares.
+				pairKey := edgeKey(fromID, toID)
+				if seen[pairKey] {
+					continue
+				}
+				seen[pairKey] = true
+
+				files1, ok1 := filesByObs[fromID]
+				files2, ok2 := filesByObs[toID]
+				if !ok1 || !ok2 {
+					continue
+				}
+
+				overlap := calculateFileOverlap(files1, files2)
+				if overlap < MinFileOverlapForEdge {
+					continue
+				}
+
+				edges = append(edges, Edge{
+					FromID:   fromID,
+					ToID:     toID,
+					Relation: RelationFileOverlap,
+					Weight:   overlap,
+				})
+			}
+		}
+	}
+
+	return edges
+}
+
+// calculateFileOverlap returns the Jaccard similarity of two file lists:
+// the number of distinct paths present in both, divided by the number of
+// distinct paths present in either. Duplicate entries within a list count
+// once. The result is 0 when either list is empty.
+func calculateFileOverlap(files1, files2 []string) float32 {
+	if len(files1) == 0 || len(files2) == 0 {
+		return 0.0
+	}
+
+	// Collapse both lists into sets.
+	left := make(map[string]struct{}, len(files1))
+	for _, name := range files1 {
+		left[name] = struct{}{}
+	}
+	right := make(map[string]struct{}, len(files2))
+	for _, name := range files2 {
+		right[name] = struct{}{}
+	}
+
+	// Size of the intersection.
+	shared := 0
+	for name := range right {
+		if _, both := left[name]; both {
+			shared++
+		}
+	}
+
+	// |A ∪ B| = |A| + |B| - |A ∩ B|
+	total := len(left) + len(right) - shared
+	if total == 0 {
+		return 0.0
+	}
+
+	return float32(shared) / float32(total)
+}
+
+// pruneEdges limits edges per node to prevent graph explosion.
+//
+// Edges are visited in input order. Only the first edge seen for any
+// unordered node pair is considered; later edges for the same pair are
+// dropped. A considered edge is kept when both of the following hold:
+//   - its source has at most maxPerNode outgoing edges, or the edge ranks
+//     among the maxPerNode heaviest of them;
+//   - its target has at most maxPerNode incoming edges, or the edge ranks
+//     among the maxPerNode heaviest of them.
+//
+// The second condition is what actually enforces the cap on the target side:
+// ranking only against the source's outgoing edges lets an edge through
+// whenever the source is under the cap, however crowded the target is.
+//
+// A non-positive maxPerNode disables pruning and returns edges unchanged.
+func pruneEdges(edges []Edge, maxPerNode int) []Edge {
+	if maxPerNode <= 0 {
+		return edges
+	}
+
+	// Group edges by source and by target.
+	outEdges := make(map[int64][]Edge)
+	inEdges := make(map[int64][]Edge)
+
+	for _, edge := range edges {
+		outEdges[edge.FromID] = append(outEdges[edge.FromID], edge)
+		inEdges[edge.ToID] = append(inEdges[edge.ToID], edge)
+	}
+
+	// strongest returns the maxPerNode heaviest edges of an over-limit node,
+	// sorting that node's list only once and caching the result. Callers must
+	// only pass lists longer than maxPerNode.
+	topOut := make(map[int64][]Edge)
+	topIn := make(map[int64][]Edge)
+	strongest := func(cache map[int64][]Edge, nodeID int64, nodeEdges []Edge) []Edge {
+		if top, ok := cache[nodeID]; ok {
+			return top
+		}
+		ranked := make([]Edge, len(nodeEdges))
+		copy(ranked, nodeEdges)
+		sortEdgesByWeight(ranked)
+		top := ranked[:maxPerNode]
+		cache[nodeID] = top
+		return top
+	}
+
+	pruned := make([]Edge, 0, len(edges))
+	processed := make(map[string]bool)
+
+	for _, edge := range edges {
+		pairKey := edgeKey(edge.FromID, edge.ToID)
+		if processed[pairKey] {
+			continue
+		}
+		processed[pairKey] = true
+
+		// Source over the cap: the edge must be among its heaviest out-edges.
+		if from := outEdges[edge.FromID]; len(from) > maxPerNode {
+			if !shouldKeepEdge(edge, strongest(topOut, edge.FromID, from), maxPerNode) {
+				continue
+			}
+		}
+
+		// Target over the cap: the edge must be among its heaviest in-edges.
+		if to := inEdges[edge.ToID]; len(to) > maxPerNode {
+			if !shouldKeepEdge(edge, strongest(topIn, edge.ToID, to), maxPerNode) {
+				continue
+			}
+		}
+
+		pruned = append(pruned, edge)
+	}
+
+	if len(pruned) < len(edges) {
+		log.Debug().
+			Int("original", len(edges)).
+			Int("pruned", len(pruned)).
+			Int("removed", len(edges)-len(pruned)).
+			Msg("Pruned excessive edges")
+	}
+
+	return pruned
+}
+
+// shouldKeepEdge reports whether edge is among the maxPerNode heaviest entries
+// of nodeEdges. Entries are matched on FromID and ToID only. nodeEdges itself
+// is not reordered; ranking is done on a copy.
+func shouldKeepEdge(edge Edge, nodeEdges []Edge, maxPerNode int) bool {
+	ranked := make([]Edge, len(nodeEdges))
+	copy(ranked, nodeEdges)
+	sortEdgesByWeight(ranked)
+
+	for pos, candidate := range ranked {
+		if pos >= maxPerNode {
+			// Past the cut-off: the edge did not make the top maxPerNode.
+			break
+		}
+		if candidate.FromID == edge.FromID && candidate.ToID == edge.ToID {
+			return true
+		}
+	}
+
+	return false
+}
+
+// sortEdgesByWeight reorders edges in place by descending Weight. It is an
+// adjacent-exchange (bubble) sort that swaps only on a strict "less than", so
+// edges of equal weight keep their relative order.
+func sortEdgesByWeight(edges []Edge) {
+	// After each pass the lightest remaining edge has sunk to position last,
+	// so the unsorted prefix shrinks by one.
+	for last := len(edges) - 1; last > 0; last-- {
+		for k := 0; k < last; k++ {
+			if edges[k].Weight < edges[k+1].Weight {
+				edges[k], edges[k+1] = edges[k+1], edges[k]
+			}
+		}
+	}
+}
+
+// edgeKey builds a direction-independent key for a pair of observation IDs,
+// formatted as "<smaller>-<larger>", so (a, b) and (b, a) map to the same key.
+func edgeKey(id1, id2 int64) string {
+	lo, hi := id1, id2
+	if hi < lo {
+		lo, hi = hi, lo
+	}
+	return fmt.Sprintf("%d-%d", lo, hi)
+}
+
+// DetectSemanticEdges links observations whose embeddings are close to each
+// other. Every pair of observations that both have an entry in embeddings is
+// scored with cosineSimilarity; pairs scoring at or above
+// SemanticSimilarityThreshold receive a RelationSemantic edge weighted by that
+// score. Observations without an embedding are skipped.
+//
+// The comparison is quadratic in the number of observations. ctx is not
+// consulted by the current implementation.
+func DetectSemanticEdges(ctx context.Context, observations []*models.Observation, embeddings map[int64][]float32) []Edge {
+	result := make([]Edge, 0)
+	visited := make(map[string]bool)
+
+	for i, first := range observations {
+		vecA, found := embeddings[first.ID]
+		if !found {
+			continue
+		}
+
+		// Only look at later observations so each position pair is scored once.
+		for _, second := range observations[i+1:] {
+			vecB, found := embeddings[second.ID]
+			if !found {
+				continue
+			}
+
+			score := cosineSimilarity(vecA, vecB)
+			if score < SemanticSimilarityThreshold {
+				continue
+			}
+
+			// Guards against emitting a pair twice when an ID occurs more
+			// than once in observations.
+			key := edgeKey(first.ID, second.ID)
+			if visited[key] {
+				continue
+			}
+			visited[key] = true
+
+			result = append(result, Edge{
+				FromID:   first.ID,
+				ToID:     second.ID,
+				Relation: RelationSemantic,
+				Weight:   score,
+			})
+		}
+	}
+
+	log.Info().
+		Int("semantic_edges", len(result)).
+		Float32("threshold", SemanticSimilarityThreshold).
+		Msg("Detected semantic edges")
+
+	return result
+}
+
+// cosineSimilarity returns the cosine of the angle between a and b: their dot
+// product divided by the product of their Euclidean norms. It returns 0 when
+// the vectors differ in length or when either has zero norm (which includes
+// empty vectors). Sums are accumulated in float32.
+func cosineSimilarity(a, b []float32) float32 {
+	if len(a) != len(b) {
+		return 0.0
+	}
+
+	var dot, sqA, sqB float32
+	for i, x := range a {
+		y := b[i]
+		dot += x * y
+		sqA += x * x
+		sqB += y * y
+	}
+
+	// A zero-norm vector has no direction; avoid dividing by zero.
+	if sqA == 0 || sqB == 0 {
+		return 0.0
+	}
+
+	return dot / float32(math.Sqrt(float64(sqA))*math.Sqrt(float64(sqB)))
+}
diff --git a/internal/graph/observation_graph.go b/internal/graph/observation_graph.go
new file mode 100644
index 0000000..c86b6fa
--- /dev/null
+++ b/internal/graph/observation_graph.go
@@ -0,0 +1,423 @@
+// Package graph provides observation relationship graphs for LEANN Phase 2.
+//
+// This package implements graph-based selective recomputation where observation
+// relationships (file overlap, semantic similarity, temporal proximity) form a
+// graph structure. Hub nodes (high-degree observations) store embeddings, while
+// leaf nodes recompute on-demand.
+package graph
+
+import (
+ "context"
+ "fmt"
+ "math"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+ "github.com/rs/zerolog/log"
+)
+
+// RelationType identifies the kind of relationship an Edge records between two observations; String returns its name
+type RelationType int
+
+const (
+ // RelationFileOverlap indicates observations reference overlapping files
+ RelationFileOverlap RelationType = iota // iota starts at 0; the constants below follow in declaration order
+ // RelationSemantic indicates embedding cosine similarity at or above SemanticSimilarityThreshold (see DetectSemanticEdges)
+ RelationSemantic
+ // RelationTemporal indicates observations from same session (NOTE(review): package doc says "temporal proximity" - confirm against DetectEdges)
+ RelationTemporal
+ // RelationConcept indicates shared concept tags
+ RelationConcept
+)
+
+// Edge represents a relationship between two observations, identified by their observation IDs
+type Edge struct {
+ FromID int64 // ID of the observation the edge starts from; BuildCSR stores the edge in this node's row
+ ToID int64 // ID of the observation the edge points to
+ Relation RelationType // Why the two observations are considered related
+ Weight float32 // 0.0-1.0, higher = stronger relationship (semantic edges store the cosine similarity here)
+}
+
+// Node represents an observation in the graph; nodes are stored by pointer in ObservationGraph.nodes
+type Node struct {
+ Metadata NodeMetadata // Descriptive fields copied from the observation
+ LastAccess time.Time // Set to time.Now() by BuildFromObservations when the node is created
+ StoredEmb []float32 // Nil if recomputed on-demand
+ ID int64 // Observation ID; also the key under which AddNode stores the node
+ Degree int // Number of edges touching this node: AddEdge increments it for both endpoints, FindHubs ranks by it
+ AccessCount int // Starts at 0 in BuildFromObservations; NOTE(review): nothing in observation_graph.go updates it
+}
+
+// NodeMetadata contains observation metadata, filled in by BuildFromObservations
+type NodeMetadata struct {
+ CreatedAt time.Time // Derived from the observation's CreatedAtEpoch via time.UnixMilli
+ Project string // Copied from the observation's Project
+ Type string // String form of the observation's Type
+ Title string // Observation title; empty when the observation's Title is not Valid
+ IsSuperseded bool // Copied from the observation's IsSuperseded
+}
+
+// CSRGraph represents a graph in Compressed Sparse Row format for memory efficiency; BuildCSR fills it, GetNeighbors reads it
+type CSRGraph struct {
+ RowPtr []int32 // One entry per node plus one: node index i owns ColIdx/Weights positions RowPtr[i] up to (not including) RowPtr[i+1]
+ ColIdx []int32 // Destination node indices (rank of the node ID in ascending order), not raw observation IDs
+ Weights []float32 // Edge weights, parallel to ColIdx
+ mu sync.RWMutex // Guards RowPtr, ColIdx and Weights
+}
+
+// ObservationGraph manages the observation relationship graph: a node map, a raw edge list, and a CSR view derived from them
+type ObservationGraph struct {
+ nodes map[int64]*Node // Nodes keyed by observation ID; guarded by nodesMu
+ csr *CSRGraph // Traversal view rebuilt by BuildCSR; it carries its own lock
+ edges []Edge // Edges in insertion order; guarded by edgesMu
+ nodesMu sync.RWMutex // Guards nodes (AddEdge and BuildCSR acquire it after edgesMu)
+ edgesMu sync.RWMutex // Guards edges
+}
+
+// NewObservationGraph returns an empty graph ready for AddNode and AddEdge:
+// the node map and edge list are initialised, and the CSR view is an empty
+// placeholder until BuildCSR populates it.
+func NewObservationGraph() *ObservationGraph {
+	g := new(ObservationGraph)
+	g.nodes = map[int64]*Node{}
+	g.edges = []Edge{}
+	g.csr = new(CSRGraph)
+	return g
+}
+
+// AddNode stores node under its ID, replacing any node previously registered
+// with the same ID. The write happens under the nodes write lock.
+func (g *ObservationGraph) AddNode(node *Node) {
+	g.nodesMu.Lock()
+	defer g.nodesMu.Unlock()
+
+	key := node.ID
+	g.nodes[key] = node
+}
+
+// AddEdge appends edge to the edge list and bumps the Degree of each endpoint
+// that is already a known node. Endpoints without a node are left alone; the
+// edge is recorded either way. Locks are taken edges-first, then nodes.
+func (g *ObservationGraph) AddEdge(edge Edge) {
+	g.edgesMu.Lock()
+	defer g.edgesMu.Unlock()
+
+	g.edges = append(g.edges, edge)
+
+	// Update degree counts for both ends of the edge.
+	g.nodesMu.Lock()
+	for _, endpoint := range [2]int64{edge.FromID, edge.ToID} {
+		if n, found := g.nodes[endpoint]; found {
+			n.Degree++
+		}
+	}
+	g.nodesMu.Unlock()
+}
+
+// BuildCSR converts the edge list to CSR format for efficient traversal.
+//
+// Nodes are numbered by the rank of their ID in ascending order. Each edge is
+// stored once, in the row of its FromID, so the CSR view is directed even
+// though AddEdge counts an edge toward the Degree of both endpoints. Edges
+// whose FromID or ToID is not a known node are left out of the CSR entirely.
+//
+// Returns an error when the graph contains no nodes.
+func (g *ObservationGraph) BuildCSR() error {
+	g.edgesMu.RLock()
+	g.nodesMu.RLock()
+	defer g.edgesMu.RUnlock()
+	defer g.nodesMu.RUnlock()
+
+	if len(g.nodes) == 0 {
+		return fmt.Errorf("no nodes in graph")
+	}
+
+	// Create node ID to index mapping (index = rank of the ID when sorted).
+	nodeIDs := make([]int64, 0, len(g.nodes))
+	for id := range g.nodes {
+		nodeIDs = append(nodeIDs, id)
+	}
+	sort.Slice(nodeIDs, func(i, j int) bool {
+		return nodeIDs[i] < nodeIDs[j]
+	})
+
+	idToIdx := make(map[int64]int32, len(nodeIDs))
+	for idx, id := range nodeIDs {
+		// #nosec G115 - observation count will never exceed int32 max (2.1B) in practice
+		idToIdx[id] = int32(idx)
+	}
+
+	// Count edges per source node. Only edges with BOTH endpoints present are
+	// counted, mirroring the fill loop below. Counting on FromID alone would
+	// reserve slots that are never written, and those zero-valued slots would
+	// surface as phantom zero-weight neighbours pointing at node index 0.
+	edgeCounts := make([]int, len(nodeIDs))
+	for _, edge := range g.edges {
+		fromIdx, fromOk := idToIdx[edge.FromID]
+		_, toOk := idToIdx[edge.ToID]
+		if fromOk && toOk {
+			edgeCounts[fromIdx]++
+		}
+	}
+
+	// Build row pointers as a running total of the per-node counts.
+	rowPtr := make([]int32, len(nodeIDs)+1)
+	rowPtr[0] = 0
+	for i := 0; i < len(nodeIDs); i++ {
+		// #nosec G115 - edge counts per node will not exceed int32 max
+		rowPtr[i+1] = rowPtr[i] + int32(edgeCounts[i])
+	}
+
+	// Build column indices and weights
+	totalEdges := rowPtr[len(nodeIDs)]
+	colIdx := make([]int32, totalEdges)
+	weights := make([]float32, totalEdges)
+
+	// Next free slot in each node's row, starting at the row's first position.
+	currentPos := make([]int32, len(nodeIDs))
+	copy(currentPos, rowPtr[:len(nodeIDs)])
+
+	for _, edge := range g.edges {
+		fromIdx, fromOk := idToIdx[edge.FromID]
+		toIdx, toOk := idToIdx[edge.ToID]
+
+		if fromOk && toOk {
+			pos := currentPos[fromIdx]
+			colIdx[pos] = toIdx
+			weights[pos] = edge.Weight
+			currentPos[fromIdx]++
+		}
+	}
+
+	// Publish the new arrays together so readers never see a mixed state.
+	g.csr.mu.Lock()
+	g.csr.RowPtr = rowPtr
+	g.csr.ColIdx = colIdx
+	g.csr.Weights = weights
+	g.csr.mu.Unlock()
+
+	log.Info().
+		Int("nodes", len(nodeIDs)).
+		Int("edges", int(totalEdges)).
+		Msg("Built CSR graph representation")
+
+	return nil
+}
+
+// GetNeighbors returns the IDs of the nodes that nodeID has CSR edges to,
+// together with the matching edge weights (the two slices are parallel).
+//
+// The node's row is located by the rank of its ID among the current node IDs,
+// the same numbering BuildCSR uses. An error is returned when the node is
+// unknown, or when the CSR arrays do not match the current node set because
+// BuildCSR has not run yet or nodes were added since. That mismatch previously
+// caused an index-out-of-range panic.
+func (g *ObservationGraph) GetNeighbors(nodeID int64) ([]int64, []float32, error) {
+	// Snapshot the node IDs and release nodesMu BEFORE taking the CSR lock.
+	// BuildCSR acquires nodesMu and then csr.mu; waiting for nodesMu while
+	// already holding csr.mu would invert that order and can deadlock once a
+	// writer is queued on nodesMu.
+	g.nodesMu.RLock()
+	nodeIDs := make([]int64, 0, len(g.nodes))
+	for id := range g.nodes {
+		nodeIDs = append(nodeIDs, id)
+	}
+	g.nodesMu.RUnlock()
+
+	sort.Slice(nodeIDs, func(i, j int) bool {
+		return nodeIDs[i] < nodeIDs[j]
+	})
+
+	// Binary search for the node's index in the sorted ID list.
+	nodeIdx := sort.Search(len(nodeIDs), func(i int) bool {
+		return nodeIDs[i] >= nodeID
+	})
+
+	if nodeIdx >= len(nodeIDs) || nodeIDs[nodeIdx] != nodeID {
+		return nil, nil, fmt.Errorf("node %d not found", nodeID)
+	}
+
+	g.csr.mu.RLock()
+	defer g.csr.mu.RUnlock()
+
+	// A valid CSR has exactly one row pointer per node plus one. Anything
+	// else means it was never built or was built for a different node set.
+	if len(g.csr.RowPtr) != len(nodeIDs)+1 {
+		return nil, nil, fmt.Errorf("CSR graph is not built or is stale (%d nodes, %d row pointers)", len(nodeIDs), len(g.csr.RowPtr))
+	}
+
+	// Extract neighbors from CSR
+	startIdx := g.csr.RowPtr[nodeIdx]
+	endIdx := g.csr.RowPtr[nodeIdx+1]
+
+	neighborCount := endIdx - startIdx
+	neighbors := make([]int64, 0, neighborCount)
+	weights := make([]float32, 0, neighborCount)
+
+	for pos := startIdx; pos < endIdx; pos++ {
+		// ColIdx holds node indices; translate back to observation IDs.
+		neighborIdx := g.csr.ColIdx[pos]
+		neighbors = append(neighbors, nodeIDs[neighborIdx])
+		weights = append(weights, g.csr.Weights[pos])
+	}
+
+	return neighbors, weights, nil
+}
+
+// GetNode looks up the node registered under nodeID. It returns the stored
+// pointer, or an error naming the ID when no such node exists.
+func (g *ObservationGraph) GetNode(nodeID int64) (*Node, error) {
+	g.nodesMu.RLock()
+	defer g.nodesMu.RUnlock()
+
+	if found, exists := g.nodes[nodeID]; exists {
+		return found, nil
+	}
+	return nil, fmt.Errorf("node %d not found", nodeID)
+}
+
+// FindHubs identifies hub nodes (high degree) in the graph.
+//
+// percentile is the degree percentile above which a node counts as a hub: the
+// top ceil(n * (1 - percentile)) nodes by Degree are returned, so 0.9 yields
+// roughly the top 10%. Values outside [0, 1] are clamped (<= 0 returns every
+// node, >= 1 or NaN returns none) instead of panicking on a negative slice
+// length. Nodes with equal degree are ordered by ascending ID so the result is
+// deterministic. Returns nil for a graph without nodes.
+func (g *ObservationGraph) FindHubs(percentile float64) []int64 {
+	g.nodesMu.RLock()
+	defer g.nodesMu.RUnlock()
+
+	if len(g.nodes) == 0 {
+		return nil
+	}
+
+	type nodeDegree struct {
+		ID     int64
+		Degree int
+	}
+
+	nodeDegrees := make([]nodeDegree, 0, len(g.nodes))
+	for id, node := range g.nodes {
+		nodeDegrees = append(nodeDegrees, nodeDegree{ID: id, Degree: node.Degree})
+	}
+
+	// Highest degree first. Map iteration order is random, so break ties on
+	// ID to keep the selected hub set stable between calls.
+	sort.Slice(nodeDegrees, func(i, j int) bool {
+		if nodeDegrees[i].Degree != nodeDegrees[j].Degree {
+			return nodeDegrees[i].Degree > nodeDegrees[j].Degree
+		}
+		return nodeDegrees[i].ID < nodeDegrees[j].ID
+	})
+
+	// Fraction of nodes to keep, clamped to [0, 1] so the cutoff can never be
+	// negative (percentile > 1) or undefined (NaN).
+	fraction := 1.0 - percentile
+	if math.IsNaN(fraction) || fraction < 0 {
+		fraction = 0
+	} else if fraction > 1 {
+		fraction = 1
+	}
+
+	// Return top percentile
+	cutoff := int(math.Ceil(float64(len(nodeDegrees)) * fraction))
+	if cutoff > len(nodeDegrees) {
+		cutoff = len(nodeDegrees)
+	}
+
+	hubs := make([]int64, cutoff)
+	for i := 0; i < cutoff; i++ {
+		hubs[i] = nodeDegrees[i].ID
+	}
+
+	log.Info().
+		Int("total_nodes", len(g.nodes)).
+		Int("hubs", len(hubs)).
+		Float64("percentile", percentile).
+		Msg("Identified hub nodes")
+
+	return hubs
+}
+
+// Stats returns a snapshot of graph statistics: node and edge counts, the
+// average, median, minimum and maximum node Degree, and the number of edges
+// per RelationType. The degree figures stay zero for a graph without nodes.
+func (g *ObservationGraph) Stats() GraphStats {
+	// Lock order matters: AddEdge and BuildCSR take edgesMu before nodesMu.
+	// Acquiring them in the opposite order here could deadlock against a
+	// concurrent AddEdge, so use the same edges-then-nodes order.
+	g.edgesMu.RLock()
+	g.nodesMu.RLock()
+	defer g.edgesMu.RUnlock()
+	defer g.nodesMu.RUnlock()
+
+	stats := GraphStats{
+		NodeCount: len(g.nodes),
+		EdgeCount: len(g.edges),
+	}
+
+	if len(g.nodes) > 0 {
+		degrees := make([]int, 0, len(g.nodes))
+		for _, node := range g.nodes {
+			degrees = append(degrees, node.Degree)
+		}
+
+		// Sorting makes min, max and median simple index lookups.
+		sort.Ints(degrees)
+		stats.AvgDegree = float64(sum(degrees)) / float64(len(degrees))
+		stats.MaxDegree = degrees[len(degrees)-1]
+		stats.MinDegree = degrees[0]
+
+		// Median: mean of the two middle values for an even count.
+		mid := len(degrees) / 2
+		if len(degrees)%2 == 0 {
+			stats.MedianDegree = float64(degrees[mid-1]+degrees[mid]) / 2.0
+		} else {
+			stats.MedianDegree = float64(degrees[mid])
+		}
+	}
+
+	// Count edge types
+	stats.EdgeTypes = make(map[RelationType]int)
+	for _, edge := range g.edges {
+		stats.EdgeTypes[edge.Relation]++
+	}
+
+	return stats
+}
+
+// GraphStats contains graph statistics as computed by ObservationGraph.Stats
+type GraphStats struct {
+ EdgeTypes map[RelationType]int // Number of edges per relation type; always non-nil when produced by Stats
+ AvgDegree float64 // Mean node Degree; 0 when the graph has no nodes
+ MedianDegree float64 // Median node Degree (mean of the two middle values for an even node count)
+ NodeCount int // Number of nodes in the graph
+ EdgeCount int // Number of entries in the edge list
+ MaxDegree int // Largest node Degree
+ MinDegree int // Smallest node Degree
+}
+
+// BuildFromObservations constructs a graph from a list of observations.
+//
+// It adds one node per observation, asks DetectEdges for the relationships
+// between them, records every returned edge, and finally builds the CSR view.
+// Errors from DetectEdges or BuildCSR are wrapped and returned with a nil graph.
+func BuildFromObservations(ctx context.Context, observations []*models.Observation) (*ObservationGraph, error) {
+	graph := NewObservationGraph()
+
+	// One node per observation, starting with no edges and no recorded accesses.
+	for _, obs := range observations {
+		meta := NodeMetadata{
+			Project:      obs.Project,
+			Type:         string(obs.Type),
+			CreatedAt:    time.UnixMilli(obs.CreatedAtEpoch),
+			IsSuperseded: obs.IsSuperseded,
+		}
+		// Title is nullable: keep it empty unless it is marked valid.
+		if obs.Title.Valid {
+			meta.Title = obs.Title.String
+		}
+
+		graph.AddNode(&Node{
+			ID:         obs.ID,
+			Metadata:   meta,
+			LastAccess: time.Now(),
+		})
+	}
+
+	// Detect edges (will be implemented in edge_detector.go)
+	detected, err := DetectEdges(ctx, observations)
+	if err != nil {
+		return nil, fmt.Errorf("detect edges: %w", err)
+	}
+	for _, e := range detected {
+		graph.AddEdge(e)
+	}
+
+	// Build CSR representation
+	if err = graph.BuildCSR(); err != nil {
+		return nil, fmt.Errorf("build CSR: %w", err)
+	}
+
+	return graph, nil
+}
+
+// sum adds up all values; an empty or nil slice yields 0.
+func sum(values []int) int {
+	var acc int
+	for i := 0; i < len(values); i++ {
+		acc += values[i]
+	}
+	return acc
+}
+
+// String returns a human-readable name for the relation type, or "unknown"
+// for any value outside the declared constants.
+func (r RelationType) String() string {
+	names := [...]string{
+		RelationFileOverlap: "file_overlap",
+		RelationSemantic:    "semantic",
+		RelationTemporal:    "temporal",
+		RelationConcept:     "concept",
+	}
+	if r < 0 || int(r) >= len(names) {
+		return "unknown"
+	}
+	return names[r]
+}
diff --git a/internal/mcp/server.go b/internal/mcp/server.go
index 7deca69..c30fc5a 100644
--- a/internal/mcp/server.go
+++ b/internal/mcp/server.go
@@ -19,12 +19,9 @@ import (
// Server is the MCP server that exposes search tools.
type Server struct {
- searchMgr *search.Manager
- version string
- stdin io.Reader
- stdout io.Writer
-
- // Store dependencies for enhanced tools
+ stdin io.Reader
+ stdout io.Writer
+ searchMgr *search.Manager
observationStore *gorm.ObservationStore
patternStore *gorm.PatternStore
relationStore *gorm.RelationStore
@@ -32,6 +29,7 @@ type Server struct {
vectorClient *sqlitevec.Client
scoreCalculator *scoring.Calculator
recalculator *scoring.Recalculator
+ version string
}
// NewServer creates a new MCP server.
@@ -71,17 +69,17 @@ type Request struct {
// Response represents a JSON-RPC response.
type Response struct {
- JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Result any `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
+ JSONRPC string `json:"jsonrpc"`
}
// Error represents a JSON-RPC error.
type Error struct {
- Code int `json:"code"`
- Message string `json:"message"`
Data any `json:"data,omitempty"`
+ Message string `json:"message"`
+ Code int `json:"code"`
}
// ToolCallParams represents parameters for tools/call method.
@@ -92,9 +90,9 @@ type ToolCallParams struct {
// Tool represents an MCP tool definition.
type Tool struct {
+ InputSchema map[string]any `json:"inputSchema"`
Name string `json:"name"`
Description string `json:"description"`
- InputSchema map[string]any `json:"inputSchema"`
}
// Run starts the MCP server loop.
@@ -489,17 +487,17 @@ func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage
// TimelineParams represents parameters for timeline operations.
type TimelineParams struct {
- AnchorID int64 `json:"anchor_id"`
Query string `json:"query"`
- Before int `json:"before"`
- After int `json:"after"`
Project string `json:"project"`
ObsType string `json:"obs_type"`
Concepts string `json:"concepts"`
Files string `json:"files"`
+ Format string `json:"format"`
+ AnchorID int64 `json:"anchor_id"`
+ Before int `json:"before"`
+ After int `json:"after"`
DateStart int64 `json:"dateStart"`
DateEnd int64 `json:"dateEnd"`
- Format string `json:"format"`
}
// handleTimeline handles timeline requests.
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go
index 469dfc0..0d4521f 100644
--- a/internal/mcp/server_test.go
+++ b/internal/mcp/server_test.go
@@ -34,8 +34,8 @@ func (s *ServerSuite) TestNewServer() {
func TestRequest(t *testing.T) {
tests := []struct {
name string
- req Request
expected string
+ req Request
}{
{
name: "initialize request",
@@ -138,9 +138,9 @@ func TestResponse(t *testing.T) {
// TestError tests Error struct.
func TestError(t *testing.T) {
tests := []struct {
+ expected string
name string
err Error
- expected string
}{
{
name: "parse error",
@@ -365,11 +365,11 @@ func TestHandleRequest(t *testing.T) {
ctx := context.Background()
tests := []struct {
- name string
req *Request
- expectError bool
- errorCode int
+ name string
errorMessage string
+ errorCode int
+ expectError bool
}{
{
name: "initialize method",
@@ -753,13 +753,13 @@ func TestServerStdinStdoutConfig(t *testing.T) {
// TestResponseIDTypes tests that response IDs can be various types.
func TestResponseIDTypes(t *testing.T) {
tests := []struct {
- name string
id any
+ name string
}{
- {"integer id", 1},
- {"string id", "abc-123"},
- {"float id", 1.5},
- {"null id", nil},
+ {name: "integer id", id: 1},
+ {name: "string id", id: "abc-123"},
+ {name: "float id", id: 1.5},
+ {name: "null id", id: nil},
}
for _, tt := range tests {
diff --git a/internal/pattern/detector.go b/internal/pattern/detector.go
index 57997c6..e294a8d 100644
--- a/internal/pattern/detector.go
+++ b/internal/pattern/detector.go
@@ -38,21 +38,15 @@ type PatternSyncFunc func(pattern *models.Pattern)
// Detector detects and tracks recurring patterns across observations.
type Detector struct {
- config DetectorConfig
+ ctx context.Context
patternStore *gorm.PatternStore
observationStore *gorm.ObservationStore
-
- // Vector sync callback
- syncFunc PatternSyncFunc
-
- // Candidate tracking (patterns not yet confirmed)
- candidates map[string]*candidatePattern
- candidatesMu sync.RWMutex
-
- // Background analysis
- ctx context.Context
- cancel context.CancelFunc
- wg sync.WaitGroup
+ syncFunc PatternSyncFunc
+ candidates map[string]*candidatePattern
+ cancel context.CancelFunc
+ config DetectorConfig
+ wg sync.WaitGroup
+ candidatesMu sync.RWMutex
}
// SetSyncFunc sets the callback for syncing patterns to vector store.
@@ -62,11 +56,11 @@ func (d *Detector) SetSyncFunc(fn PatternSyncFunc) {
// candidatePattern tracks a potential pattern before it reaches frequency threshold.
type candidatePattern struct {
+ patternType models.PatternType
+ title string
signature []string
observationIDs []int64
projects []string
- patternType models.PatternType
- title string
lastSeenEpoch int64
}
diff --git a/internal/pattern/detector_test.go b/internal/pattern/detector_test.go
index 099341b..9b9dc8b 100644
--- a/internal/pattern/detector_test.go
+++ b/internal/pattern/detector_test.go
@@ -331,16 +331,16 @@ func TestDefaultConfig(t *testing.T) {
func TestGeneratePatternName(t *testing.T) {
tests := []struct {
patternType models.PatternType
- signature []string
title string
wantPrefix string
+ signature []string
}{
- {models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"},
- {models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"},
- {models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"},
- {models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"},
- {models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"},
- {models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly
+ {patternType: models.PatternTypeBug, title: "", wantPrefix: "Bug Pattern:", signature: []string{"nil", "error"}},
+ {patternType: models.PatternTypeRefactor, title: "", wantPrefix: "Refactor Pattern:", signature: []string{"extract"}},
+ {patternType: models.PatternTypeArchitecture, title: "", wantPrefix: "Architecture Pattern:", signature: []string{"service"}},
+ {patternType: models.PatternTypeAntiPattern, title: "", wantPrefix: "Anti-Pattern:", signature: []string{"god-class"}},
+ {patternType: models.PatternTypeBestPractice, title: "", wantPrefix: "Best Practice:", signature: []string{"testing"}},
+ {patternType: models.PatternTypeBug, title: "Short Title", wantPrefix: "Short Title", signature: []string{}}, // Use title directly
}
for _, tt := range tests {
diff --git a/internal/reranking/service.go b/internal/reranking/service.go
index 38fe387..f41494d 100644
--- a/internal/reranking/service.go
+++ b/internal/reranking/service.go
@@ -30,24 +30,24 @@ const (
// Candidate represents a search result candidate for reranking.
type Candidate struct {
- ID string // Document ID
- Content string // Document text content for scoring
- Score float64 // Original bi-encoder similarity score
- Metadata map[string]any // Preserved metadata
- RerankInfo map[string]float64 // Reranking debug info (optional)
+ Metadata map[string]any
+ RerankInfo map[string]float64
+ ID string
+ Content string
+ Score float64
}
// RerankResult represents a reranked search result.
type RerankResult struct {
- ID string // Document ID
- Content string // Document text content
- OriginalScore float64 // Original bi-encoder score
- RerankScore float64 // Cross-encoder relevance score
- CombinedScore float64 // Weighted combination of scores
- Metadata map[string]any // Preserved metadata
- OriginalRank int // Position before reranking (1-indexed)
- RerankRank int // Position after reranking (1-indexed)
- RankImprovement int // How much the rank improved (positive = moved up)
+ Metadata map[string]any
+ ID string
+ Content string
+ OriginalScore float64
+ RerankScore float64
+ CombinedScore float64
+ OriginalRank int
+ RerankRank int
+ RankImprovement int
}
// Service provides cross-encoder reranking functionality.
diff --git a/internal/scoring/recalculator.go b/internal/scoring/recalculator.go
index 0f7caca..e01527f 100644
--- a/internal/scoring/recalculator.go
+++ b/internal/scoring/recalculator.go
@@ -21,13 +21,13 @@ type ObservationStore interface {
// Recalculator periodically recalculates importance scores for observations.
type Recalculator struct {
+ log zerolog.Logger
store ObservationStore
calculator *Calculator
- log zerolog.Logger
- interval time.Duration
- batchSize int
stopCh chan struct{}
doneCh chan struct{}
+ interval time.Duration
+ batchSize int
mu sync.Mutex
running bool
}
diff --git a/internal/scoring/recalculator_test.go b/internal/scoring/recalculator_test.go
index 6fff695..a1bdbbf 100644
--- a/internal/scoring/recalculator_test.go
+++ b/internal/scoring/recalculator_test.go
@@ -16,14 +16,14 @@ import (
// MockObservationStore is a mock implementation of ObservationStore for testing.
type MockObservationStore struct {
- mu sync.Mutex
- observations []*models.Observation
- scores map[int64]float64
- conceptWeights map[string]float64
updateErr error
getErr error
getConceptsErr error
+ scores map[int64]float64
+ conceptWeights map[string]float64
+ observations []*models.Observation
updateScoresCalls int
+ mu sync.Mutex
}
func NewMockObservationStore() *MockObservationStore {
diff --git a/internal/search/expansion/expander.go b/internal/search/expansion/expander.go
index be1ab6e..ff6f22d 100644
--- a/internal/search/expansion/expander.go
+++ b/internal/search/expansion/expander.go
@@ -30,25 +30,25 @@ const (
// ExpandedQuery represents a query variant with metadata.
type ExpandedQuery struct {
Query string `json:"query"`
- Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0)
- Source string `json:"source"` // Where this expansion came from
- Intent QueryIntent `json:"intent"` // Detected intent
+ Source string `json:"source"`
+ Intent QueryIntent `json:"intent"`
+ Weight float64 `json:"weight"`
}
// Expander provides context-aware query expansion.
type Expander struct {
embedSvc *embedding.Service
- vocabulary []VocabEntry // Known vocabulary from observations
- vocabVectors [][]float32 // Embeddings for vocabulary entries
- vocabMu sync.RWMutex // Protects vocabulary
intentPatterns map[QueryIntent][]*regexp.Regexp
+ vocabulary []VocabEntry
+ vocabVectors [][]float32
+ vocabMu sync.RWMutex
}
// VocabEntry represents a vocabulary term from observations.
type VocabEntry struct {
- Term string // The term itself
- Weight float64 // How common/important this term is (0.0-1.0)
- Source string // Where it came from (title, concept, narrative)
+ Term string
+ Source string
+ Weight float64
}
// Config holds expander configuration.
diff --git a/internal/search/expansion/expander_test.go b/internal/search/expansion/expander_test.go
index e455ad3..1317340 100644
--- a/internal/search/expansion/expander_test.go
+++ b/internal/search/expansion/expander_test.go
@@ -88,16 +88,16 @@ func (s *ExpanderSuite) TestExpand() {
tests := []struct {
name string
query string
+ expectedIntent QueryIntent
minExpansions int
hasOriginal bool
- expectedIntent QueryIntent
}{
- {"question", "how do I implement auth", 1, true, IntentQuestion},
- {"error", "fix the bug in login", 1, true, IntentError},
- {"implementation", "implement user handler", 1, true, IntentImplementation},
- {"architecture", "architecture design", 1, true, IntentArchitecture},
- {"general", "database connection", 1, true, IntentGeneral},
- {"empty", "", 0, false, IntentGeneral},
+ {name: "question", query: "how do I implement auth", expectedIntent: IntentQuestion, minExpansions: 1, hasOriginal: true},
+ {name: "error", query: "fix the bug in login", expectedIntent: IntentError, minExpansions: 1, hasOriginal: true},
+ {name: "implementation", query: "implement user handler", expectedIntent: IntentImplementation, minExpansions: 1, hasOriginal: true},
+ {name: "architecture", query: "architecture design", expectedIntent: IntentArchitecture, minExpansions: 1, hasOriginal: true},
+ {name: "general", query: "database connection", expectedIntent: IntentGeneral, minExpansions: 1, hasOriginal: true},
+ {name: "empty", query: "", expectedIntent: IntentGeneral, minExpansions: 0, hasOriginal: false},
}
for _, tt := range tests {
@@ -392,13 +392,13 @@ func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
- {"short", "hello", 10, "hello"},
- {"exact", "hello", 5, "hello"},
- {"long", "hello world", 5, "hello..."},
- {"empty", "", 10, ""},
+ {name: "short", input: "hello", expected: "hello", maxLen: 10},
+ {name: "exact", input: "hello", expected: "hello", maxLen: 5},
+ {name: "long", input: "hello world", expected: "hello...", maxLen: 5},
+ {name: "empty", input: "", expected: "", maxLen: 10},
}
for _, tt := range tests {
diff --git a/internal/search/integration_test.go b/internal/search/integration_test.go
index a5aa9d7..cf798b3 100644
--- a/internal/search/integration_test.go
+++ b/internal/search/integration_test.go
@@ -516,16 +516,16 @@ func TestTruncate_TableDriven(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
- {"short_string", "hello", 10, "hello"},
- {"exact_length", "hello", 5, "hello"},
- {"long_string", "hello world", 5, "hello..."},
- {"empty_string", "", 10, ""},
- {"whitespace_only", " ", 10, ""},
- {"with_leading_space", " hello ", 10, "hello"},
- {"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."},
+ {name: "short_string", input: "hello", expected: "hello", maxLen: 10},
+ {name: "exact_length", input: "hello", expected: "hello", maxLen: 5},
+ {name: "long_string", input: "hello world", expected: "hello...", maxLen: 5},
+ {name: "empty_string", input: "", expected: "", maxLen: 10},
+ {name: "whitespace_only", input: " ", expected: "", maxLen: 10},
+ {name: "with_leading_space", input: " hello ", expected: "hello", maxLen: 10},
+ {name: "very_long", input: "this is a very long string that should be truncated", expected: "this is a very long ...", maxLen: 20},
}
for _, tt := range tests {
diff --git a/internal/search/manager.go b/internal/search/manager.go
index cc21626..6a59da6 100644
--- a/internal/search/manager.go
+++ b/internal/search/manager.go
@@ -35,41 +35,41 @@ func NewManager(
// SearchParams contains parameters for unified search.
type SearchParams struct {
- Query string
- Type string // "observations", "sessions", "prompts", or empty for all
+ Format string
+ Type string
Project string
- ObsType string // Observation type filter
+ ObsType string
Concepts string
Files string
+ Query string
+ Scope string
+ OrderBy string
DateStart int64
- DateEnd int64
- OrderBy string // "relevance", "date_desc", "date_asc"
- Limit int
Offset int
- Format string // "index" or "full"
- Scope string // "project", "global", or empty for project+global
- IncludeGlobal bool // If true, include global observations along with project-scoped
- ExcludeSuperseded bool // If true, exclude observations that have been superseded
+ Limit int
+ DateEnd int64
+ IncludeGlobal bool
+ ExcludeSuperseded bool
}
// SearchResult represents a unified search result.
type SearchResult struct {
- Type string `json:"type"` // "observation", "session", "prompt"
- ID int64 `json:"id"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ Type string `json:"type"`
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
Project string `json:"project"`
- Scope string `json:"scope,omitempty"` // "project" or "global"
+ Scope string `json:"scope,omitempty"`
+ ID int64 `json:"id"`
CreatedAt int64 `json:"created_at_epoch"`
Score float64 `json:"score,omitempty"`
- Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// UnifiedSearchResult contains the combined search results.
type UnifiedSearchResult struct {
+ Query string `json:"query,omitempty"`
Results []SearchResult `json:"results"`
TotalCount int `json:"total_count"`
- Query string `json:"query,omitempty"`
}
// UnifiedSearch performs a unified search across all document types.
diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go
index bdd077b..f531858 100644
--- a/internal/search/manager_test.go
+++ b/internal/search/manager_test.go
@@ -94,8 +94,8 @@ func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
{
name: "short string no truncation",
@@ -148,8 +148,8 @@ func TestObservationToResult(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
tests := []struct {
- name string
obs *models.Observation
+ name string
format string
expected SearchResult
}{
@@ -240,8 +240,8 @@ func TestSummaryToResult(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
tests := []struct {
- name string
summary *models.SessionSummary
+ name string
format string
expected SearchResult
}{
@@ -322,8 +322,8 @@ func TestPromptToResult(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
tests := []struct {
- name string
prompt *models.UserPromptWithSession
+ name string
format string
expected SearchResult
}{
@@ -406,9 +406,9 @@ func TestPromptToResult(t *testing.T) {
func TestSearchParamsValidation(t *testing.T) {
tests := []struct {
name string
+ expectedOrder string
params SearchParams
expectedLimit int
- expectedOrder string
}{
{
name: "default limit applied",
@@ -731,16 +731,16 @@ func TestPromptToResultFormats(t *testing.T) {
func TestSearchParamsDefaults(t *testing.T) {
tests := []struct {
name string
- initialLimit int
initialOrder string
- expectedLimit int
expectedOrder string
+ initialLimit int
+ expectedLimit int
}{
- {"zero_limit", 0, "", 20, "date_desc"},
- {"negative_limit", -5, "", 20, "date_desc"},
- {"over_100_limit", 150, "", 100, "date_desc"},
- {"valid_limit_50", 50, "relevance", 50, "relevance"},
- {"custom_order", 30, "date_asc", 30, "date_asc"},
+ {name: "zero_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: 0, expectedLimit: 20},
+ {name: "negative_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: -5, expectedLimit: 20},
+ {name: "over_100_limit", initialOrder: "", expectedOrder: "date_desc", initialLimit: 150, expectedLimit: 100},
+ {name: "valid_limit_50", initialOrder: "relevance", expectedOrder: "relevance", initialLimit: 50, expectedLimit: 50},
+ {name: "custom_order", initialOrder: "date_asc", expectedOrder: "date_asc", initialLimit: 30, expectedLimit: 30},
}
for _, tt := range tests {
@@ -774,18 +774,18 @@ func TestTruncateEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
// Unicode strings - uses byte length so ensure maxLen accommodates full string
- {"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"},
- {"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"},
+ {name: "unicode_string_no_truncate", input: "日本語テスト", expected: "日本語テスト", maxLen: 20},
+ {name: "mixed_unicode_no_truncate", input: "Hello世界", expected: "Hello世界", maxLen: 15},
// ASCII truncation
- {"ascii_truncate", "Hello World", 5, "Hello..."},
- {"only_whitespace", " ", 10, ""},
- {"tabs_and_newlines", "\t\n \t", 10, ""},
- {"newlines_with_content", "\n\nhello\n\n", 10, "hello"},
- {"zero_max_len", "hello", 0, "..."},
+ {name: "ascii_truncate", input: "Hello World", expected: "Hello...", maxLen: 5},
+ {name: "only_whitespace", input: " ", expected: "", maxLen: 10},
+ {name: "tabs_and_newlines", input: "\t\n \t", expected: "", maxLen: 10},
+ {name: "newlines_with_content", input: "\n\nhello\n\n", expected: "hello", maxLen: 10},
+ {name: "zero_max_len", input: "hello", expected: "...", maxLen: 0},
}
for _, tt := range tests {
diff --git a/internal/update/update.go b/internal/update/update.go
index c185aa8..096d82a 100644
--- a/internal/update/update.go
+++ b/internal/update/update.go
@@ -33,11 +33,11 @@ const (
// Release represents a GitHub release.
type Release struct {
+ PublishedAt time.Time `json:"published_at"`
TagName string `json:"tag_name"`
Name string `json:"name"`
- PublishedAt time.Time `json:"published_at"`
- Assets []Asset `json:"assets"`
Body string `json:"body"`
+ Assets []Asset `json:"assets"`
}
// Asset represents a release asset.
@@ -49,15 +49,15 @@ type Asset struct {
// UpdateInfo contains information about an available update.
type UpdateInfo struct {
- Available bool `json:"available"`
+ PublishedAt time.Time `json:"published_at,omitempty"`
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
ReleaseNotes string `json:"release_notes,omitempty"`
- PublishedAt time.Time `json:"published_at,omitempty"`
DownloadURL string `json:"download_url,omitempty"`
ChecksumsURL string `json:"checksums_url,omitempty"`
- BundleURL string `json:"bundle_url,omitempty"` // Sigstore bundle (.sigstore.json)
+ BundleURL string `json:"bundle_url,omitempty"`
ManualUpdateCommand string `json:"manual_update_command,omitempty"`
+ Available bool `json:"available"`
}
// InstallScriptURL is the URL to the remote installation script.
@@ -74,23 +74,22 @@ func GetManualUpdateCommand(version string) string {
// UpdateStatus represents the current update status.
type UpdateStatus struct {
- State string `json:"state"` // "idle", "checking", "downloading", "verifying", "applying", "done", "error"
- Progress float64 `json:"progress"`
+ State string `json:"state"`
Message string `json:"message"`
Error string `json:"error,omitempty"`
- ManualUpdateCommand string `json:"manual_update_command,omitempty"` // Shown when update fails
+ ManualUpdateCommand string `json:"manual_update_command,omitempty"`
+ Progress float64 `json:"progress"`
}
// Updater handles self-updates.
type Updater struct {
+ lastCheck time.Time
+ httpClient *http.Client
+ cachedUpdate *UpdateInfo
currentVersion string
installDir string
- httpClient *http.Client
-
- mu sync.RWMutex
- status UpdateStatus
- lastCheck time.Time
- cachedUpdate *UpdateInfo
+ status UpdateStatus
+ mu sync.RWMutex
}
// New creates a new Updater.
diff --git a/internal/vector/hybrid/autotuner.go b/internal/vector/hybrid/autotuner.go
new file mode 100644
index 0000000..78c3760
--- /dev/null
+++ b/internal/vector/hybrid/autotuner.go
@@ -0,0 +1,309 @@
+package hybrid
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
+ "github.com/rs/zerolog/log"
+)
+
+// AutoTuner dynamically adjusts hub threshold based on query performance.
+//
+// Latencies reported through RecordQuery are kept in a bounded window. A
+// background goroutine launched by Start compares their p95 with targetLatency
+// every adjustPeriod and moves client.hubThreshold within
+// [minThreshold, maxThreshold].
+type AutoTuner struct {
+	ctx           context.Context    // cancelled by Stop; ends tuningLoop
+	client        *Client            // hybrid client whose hubThreshold is tuned
+	cancel        context.CancelFunc // cancels ctx
+	latencies     []time.Duration    // recent query latencies (at most 1000); guarded by latenciesMu
+	wg            sync.WaitGroup     // tracks the tuningLoop goroutine so Stop can wait for it
+	queries       int64              // total RecordQuery calls; guarded by latenciesMu
+	targetLatency time.Duration      // p95 latency the tuner steers towards
+	adjustPeriod  time.Duration      // interval between adjustThreshold runs
+	minThreshold  int                // lower clamp applied to the hub threshold
+	maxThreshold  int                // upper clamp applied to the hub threshold
+	adjustments   int                // number of threshold changes applied; guarded by latenciesMu
+	latenciesMu   sync.Mutex         // guards latencies, queries and adjustments
+}
+
+// AutoTunerConfig configures the auto-tuner.
+// The defaults quoted on each field are the values returned by
+// DefaultAutoTunerConfig.
+type AutoTunerConfig struct {
+	TargetLatency time.Duration // Target p95 latency (default: 50ms)
+	MinThreshold  int           // Min hub threshold (default: 2)
+	MaxThreshold  int           // Max hub threshold (default: 20)
+	AdjustPeriod  time.Duration // Adjustment frequency (default: 5min)
+}
+
+// DefaultAutoTunerConfig returns sensible defaults
+func DefaultAutoTunerConfig() AutoTunerConfig {
+ return AutoTunerConfig{
+ TargetLatency: 50 * time.Millisecond,
+ MinThreshold: 2,
+ MaxThreshold: 20,
+ AdjustPeriod: 5 * time.Minute,
+ }
+}
+
+// NewAutoTuner creates a new auto-tuner for the hybrid client. The returned
+// tuner is idle until Start is called.
+//
+// cfg is normalized first: a non-positive TargetLatency, MinThreshold,
+// MaxThreshold or AdjustPeriod is replaced by the DefaultAutoTunerConfig value,
+// and a MaxThreshold below MinThreshold is raised to MinThreshold. This keeps a
+// zero-value AutoTunerConfig usable: without it tuningLoop would panic, because
+// time.NewTicker rejects non-positive intervals.
+func NewAutoTuner(client *Client, cfg AutoTunerConfig) *AutoTuner {
+	defaults := DefaultAutoTunerConfig()
+	if cfg.TargetLatency <= 0 {
+		cfg.TargetLatency = defaults.TargetLatency
+	}
+	if cfg.MinThreshold <= 0 {
+		cfg.MinThreshold = defaults.MinThreshold
+	}
+	if cfg.MaxThreshold <= 0 {
+		cfg.MaxThreshold = defaults.MaxThreshold
+	}
+	if cfg.MaxThreshold < cfg.MinThreshold {
+		cfg.MaxThreshold = cfg.MinThreshold
+	}
+	if cfg.AdjustPeriod <= 0 {
+		cfg.AdjustPeriod = defaults.AdjustPeriod
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	return &AutoTuner{
+		client:        client,
+		targetLatency: cfg.TargetLatency,
+		minThreshold:  cfg.MinThreshold,
+		maxThreshold:  cfg.MaxThreshold,
+		adjustPeriod:  cfg.AdjustPeriod,
+		// Pre-size for the 1000-entry window RecordQuery maintains.
+		latencies: make([]time.Duration, 0, 1000),
+		ctx:       ctx,
+		cancel:    cancel,
+	}
+}
+
+// Start launches the tuning goroutine and then logs the active configuration.
+// It returns immediately; call Stop to terminate the goroutine.
+func (a *AutoTuner) Start() {
+	a.wg.Add(1)
+	go a.tuningLoop()
+
+	started := log.Info()
+	started = started.Dur("target_latency", a.targetLatency)
+	started = started.Int("min_threshold", a.minThreshold)
+	started = started.Int("max_threshold", a.maxThreshold)
+	started = started.Dur("adjust_period", a.adjustPeriod)
+	started.Msg("Auto-tuner started")
+}
+
+// Stop stops the auto-tuner: it cancels the tuner's context, blocks until the
+// tuningLoop goroutine has returned, and then logs that the tuner stopped.
+func (a *AutoTuner) Stop() {
+	a.cancel()
+	// tuningLoop calls wg.Done on exit, so this waits for it to finish.
+	a.wg.Wait()
+	log.Info().Msg("Auto-tuner stopped")
+}
+
+// RecordQuery records a query latency for analysis
+func (a *AutoTuner) RecordQuery(latency time.Duration) {
+ a.latenciesMu.Lock()
+ defer a.latenciesMu.Unlock()
+
+ a.queries++
+ a.latencies = append(a.latencies, latency)
+
+ // Keep only recent queries (last 1000)
+ if len(a.latencies) > 1000 {
+ a.latencies = a.latencies[len(a.latencies)-1000:]
+ }
+}
+
+// tuningLoop runs adjustThreshold once per adjustPeriod until the tuner's
+// context is cancelled. It signals the wait group when it exits.
+func (a *AutoTuner) tuningLoop() {
+	defer a.wg.Done()
+
+	tick := time.NewTicker(a.adjustPeriod)
+	defer tick.Stop()
+
+	done := a.ctx.Done()
+	for {
+		select {
+		case <-tick.C:
+			a.adjustThreshold()
+		case <-done:
+			return
+		}
+	}
+}
+
+// adjustThreshold analyzes recent queries and adjusts hub threshold.
+//
+// It needs at least 10 recorded latencies. When their p95 exceeds
+// targetLatency the threshold is lowered; when the p95 is below 80% of
+// targetLatency it is raised; otherwise nothing changes. The step size comes
+// from calculateAdjustment and the result is clamped to
+// [minThreshold, maxThreshold]. If clamping leaves the threshold unchanged
+// nothing is applied or announced. After an applied change the latency window
+// is reset so the next evaluation only sees queries served with the new
+// threshold.
+//
+// client.hubThreshold is read and written under client.mu because Client.isHub
+// reads it under that same lock.
+func (a *AutoTuner) adjustThreshold() {
+	a.latenciesMu.Lock()
+	defer a.latenciesMu.Unlock()
+
+	if len(a.latencies) < 10 {
+		// Not enough data yet
+		return
+	}
+
+	// Calculate p95 latency
+	p95 := calculateP95(a.latencies)
+
+	a.client.mu.RLock()
+	currentThreshold := a.client.hubThreshold
+	a.client.mu.RUnlock()
+
+	log.Debug().
+		Dur("p95_latency", p95).
+		Dur("target_latency", a.targetLatency).
+		Int("current_threshold", currentThreshold).
+		Int("queries", len(a.latencies)).
+		Msg("Auto-tuner evaluating performance")
+
+	// Determine adjustment direction
+	var (
+		newThreshold int
+		announcement string
+	)
+
+	switch {
+	case p95 > a.targetLatency:
+		// Too slow - lower threshold (more hubs = faster queries)
+		newThreshold = currentThreshold - calculateAdjustment(p95, a.targetLatency)
+		if newThreshold < a.minThreshold {
+			newThreshold = a.minThreshold
+		}
+		announcement = "Auto-tuner: Lowering hub threshold (too slow)"
+
+	case p95 < a.targetLatency*8/10:
+		// Too fast - raise threshold (fewer hubs = more savings)
+		// Only adjust if significantly faster (20% margin)
+		newThreshold = currentThreshold + calculateAdjustment(a.targetLatency, p95)
+		if newThreshold > a.maxThreshold {
+			newThreshold = a.maxThreshold
+		}
+		announcement = "Auto-tuner: Raising hub threshold (room for savings)"
+
+	default:
+		// Within acceptable range, no adjustment needed
+		log.Debug().
+			Dur("p95", p95).
+			Int("threshold", currentThreshold).
+			Msg("Auto-tuner: Performance acceptable, no adjustment")
+		return
+	}
+
+	if newThreshold == currentThreshold {
+		// Already pinned at a bound: nothing to apply, so stay quiet instead of
+		// reporting a change that did not happen.
+		return
+	}
+
+	log.Info().
+		Dur("p95", p95).
+		Int("old_threshold", currentThreshold).
+		Int("new_threshold", newThreshold).
+		Msg(announcement)
+
+	// Apply adjustment
+	a.client.mu.Lock()
+	a.client.hubThreshold = newThreshold
+	a.client.mu.Unlock()
+	a.adjustments++
+
+	// Clear latency history after adjustment
+	a.latencies = make([]time.Duration, 0, 1000)
+
+	log.Info().
+		Int("threshold", newThreshold).
+		Int("total_adjustments", a.adjustments).
+		Msg("Hub threshold adjusted by auto-tuner")
+}
+
+// calculateP95 returns the 95th-percentile value of latencies, or 0 for an
+// empty slice. The input is left untouched; a private copy is sorted and the
+// element at index int(len*0.95), capped at the last index, is returned.
+func calculateP95(latencies []time.Duration) time.Duration {
+	count := len(latencies)
+	if count == 0 {
+		return 0
+	}
+
+	ordered := append([]time.Duration(nil), latencies...)
+
+	// Insertion sort, ascending (the window holds at most 1000 entries).
+	for i := 1; i < count; i++ {
+		current := ordered[i]
+		j := i
+		for j > 0 && ordered[j-1] > current {
+			ordered[j] = ordered[j-1]
+			j--
+		}
+		ordered[j] = current
+	}
+
+	rank := int(float64(count) * 0.95)
+	if rank > count-1 {
+		rank = count - 1
+	}
+
+	return ordered[rank]
+}
+
+// calculateAdjustment maps the relative gap between actual and target onto a
+// threshold step: 3 when the gap exceeds 50% of target, 2 when it exceeds 20%,
+// and 1 otherwise. Only the magnitude of the gap matters, not its sign.
+func calculateAdjustment(actual, target time.Duration) int {
+	gap := float64(actual-target) / float64(target)
+	if gap < 0 {
+		gap = -gap
+	}
+
+	switch {
+	case gap > 0.5:
+		return 3 // Large adjustment
+	case gap > 0.2:
+		return 2 // Medium adjustment
+	default:
+		return 1 // Small adjustment
+	}
+}
+
+// GetStats returns a snapshot of the auto-tuner statistics.
+//
+// P95Latency and AvgLatency are computed over the current latency window and
+// stay zero while the window is empty. The client's hub threshold is read
+// under client.mu, the lock Client.isHub uses for the same field, so the
+// snapshot does not race with a concurrent adjustment.
+func (a *AutoTuner) GetStats() AutoTunerStats {
+	a.latenciesMu.Lock()
+	defer a.latenciesMu.Unlock()
+
+	a.client.mu.RLock()
+	threshold := a.client.hubThreshold
+	a.client.mu.RUnlock()
+
+	stats := AutoTunerStats{
+		CurrentThreshold: threshold,
+		TargetLatency:    a.targetLatency,
+		TotalQueries:     a.queries,
+		TotalAdjustments: a.adjustments,
+		RecentQueries:    len(a.latencies),
+	}
+
+	if len(a.latencies) > 0 {
+		stats.P95Latency = calculateP95(a.latencies)
+
+		// Calculate average
+		var total time.Duration
+		for _, lat := range a.latencies {
+			total += lat
+		}
+		stats.AvgLatency = total / time.Duration(len(a.latencies))
+	}
+
+	return stats
+}
+
+// AutoTunerStats contains auto-tuner statistics as reported by GetStats.
+type AutoTunerStats struct {
+	CurrentThreshold int           // hub threshold of the tuned client at snapshot time
+	TargetLatency    time.Duration // configured p95 target
+	P95Latency       time.Duration // p95 of the current latency window; zero when the window is empty
+	AvgLatency       time.Duration // mean of the current latency window; zero when the window is empty
+	TotalQueries     int64         // number of RecordQuery calls so far
+	TotalAdjustments int           // number of threshold changes applied so far
+	RecentQueries    int           // number of latencies currently in the window
+}
+
+// AutoTunedClient wraps Client with automatic performance tuning.
+// The embedded Client's methods are promoted; Query is overridden to report
+// each call's latency to the tuner.
+type AutoTunedClient struct {
+	*Client
+	tuner *AutoTuner // receives query latencies and adjusts Client's hub threshold
+}
+
+// Query forwards to the embedded Client's Query and reports how long the call
+// took to the tuner. The latency is recorded whether or not the call failed.
+func (a *AutoTunedClient) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+	began := time.Now()
+	results, err := a.Client.Query(ctx, query, limit, where)
+	a.tuner.RecordQuery(time.Since(began))
+
+	return results, err
+}
+
+// WithAutoTuning builds an auto-tuner for client, starts it, and returns the
+// client wrapped so that its queries feed the tuner. Call StopTuning on the
+// result to stop the background goroutine.
+func WithAutoTuning(client *Client, cfg AutoTunerConfig) *AutoTunedClient {
+	wrapped := &AutoTunedClient{
+		Client: client,
+		tuner:  NewAutoTuner(client, cfg),
+	}
+	wrapped.tuner.Start()
+
+	return wrapped
+}
+
+// StopTuning stops the auto-tuner and waits for its background goroutine to
+// exit. The embedded Client is not closed.
+func (a *AutoTunedClient) StopTuning() {
+	a.tuner.Stop()
+}
diff --git a/internal/vector/hybrid/client.go b/internal/vector/hybrid/client.go
new file mode 100644
index 0000000..5a1b99a
--- /dev/null
+++ b/internal/vector/hybrid/client.go
@@ -0,0 +1,515 @@
+// Package hybrid provides LEANN-inspired selective vector storage for claude-mnemonic.
+//
+// This package implements a hybrid storage strategy where frequently-accessed
+// observations ("hubs") have their embeddings stored, while infrequently-accessed
+// observations have their embeddings recomputed on-demand during search.
+//
+// This approach reduces storage by 60-80% with minimal impact on search latency (<50ms).
+package hybrid
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
+ "github.com/rs/zerolog/log"
+)
+
+// VectorStorageStrategy defines how embeddings are stored/computed.
+// The zero value is StorageAlways.
+type VectorStorageStrategy int
+
+const (
+	// StorageAlways stores all embeddings (current behavior, backwards compatible).
+	// AddDocuments and Query go to the wrapped sqlitevec client; Query also
+	// records access counts.
+	StorageAlways VectorStorageStrategy = iota
+	// StorageHub stores only frequently-accessed "hub" embeddings (recommended).
+	// AddDocuments caches the content of every document and stores embeddings
+	// only for documents already at the hub threshold; Query merges stored
+	// results with embeddings recomputed from cached content.
+	StorageHub
+	// StorageOnDemand recomputes all embeddings during search (maximum savings).
+	// AddDocuments only caches content; nothing is written to the wrapped client.
+	StorageOnDemand
+)
+
+// Client wraps sqlitevec.Client with selective storage logic.
+//
+// Two locks are used: mu guards the access-tracking maps and is also taken by
+// isHub when reading hubThreshold; cacheMu guards contentCache.
+type Client struct {
+	base         *sqlitevec.Client     // wrapped vector store
+	db           *sql.DB               // copied from Config.DB; NOTE(review): not read anywhere in client.go
+	embedSvc     *embedding.Service    // embeds queries and cached content for recomputation
+	accessCount  map[string]int        // doc ID -> number of tracked accesses; guarded by mu
+	lastAccess   map[string]time.Time  // doc ID -> time of the latest tracked access; guarded by mu
+	contentCache map[string]string     // doc ID -> content kept for recomputation; guarded by cacheMu
+	strategy     VectorStorageStrategy // selects the AddDocuments/Query code path
+	hubThreshold int                   // minimum accessCount for a document to count as a hub
+	mu           sync.RWMutex
+	cacheMu      sync.RWMutex
+}
+
+// Config carries the dependencies and settings NewClient copies into a Client.
+type Config struct {
+	BaseClient   *sqlitevec.Client     // wrapped vector store that holds stored embeddings
+	DB           *sql.DB               // stored on the Client as-is
+	EmbedSvc     *embedding.Service    // used to embed queries and cached content
+	Strategy     VectorStorageStrategy // zero value is StorageAlways
+	HubThreshold int                   // Default: 5 accesses (applied by NewClient when <= 0)
+}
+
+// NewClient builds a hybrid vector client from cfg, logging the chosen
+// strategy and threshold. A HubThreshold of zero or less is replaced by 5.
+// The access-tracking maps and the content cache start out empty.
+func NewClient(cfg Config) *Client {
+	threshold := cfg.HubThreshold
+	if threshold <= 0 {
+		threshold = 5
+	}
+
+	log.Info().
+		Str("strategy", strategyToString(cfg.Strategy)).
+		Int("hub_threshold", threshold).
+		Msg("Initializing LEANN hybrid vector client")
+
+	client := &Client{
+		base:         cfg.BaseClient,
+		db:           cfg.DB,
+		embedSvc:     cfg.EmbedSvc,
+		strategy:     cfg.Strategy,
+		hubThreshold: threshold,
+	}
+	client.accessCount = map[string]int{}
+	client.lastAccess = map[string]time.Time{}
+	client.contentCache = map[string]string{}
+
+	return client
+}
+
+// AddDocuments ingests docs according to the client's storage strategy:
+// StorageHub stores embeddings selectively, StorageOnDemand only caches
+// content, and StorageAlways (or any unknown strategy) hands everything to the
+// wrapped client. An empty docs slice is a no-op.
+func (c *Client) AddDocuments(ctx context.Context, docs []sqlitevec.Document) error {
+	if len(docs) == 0 {
+		return nil
+	}
+
+	if c.strategy == StorageHub {
+		// Store only hub candidates
+		return c.addDocumentsSelective(ctx, docs)
+	}
+	if c.strategy == StorageOnDemand {
+		// Don't store embeddings, only cache content
+		return c.cacheDocuments(ctx, docs)
+	}
+
+	// StorageAlways and unrecognized strategies: store all embeddings
+	return c.base.AddDocuments(ctx, docs)
+}
+
+// addDocumentsSelective stores embeddings only for hub-qualified documents
+func (c *Client) addDocumentsSelective(ctx context.Context, docs []sqlitevec.Document) error {
+ // Always cache content for potential recomputation
+ if err := c.cacheDocuments(ctx, docs); err != nil {
+ return err
+ }
+
+ // Filter to hub documents
+ hubDocs := make([]sqlitevec.Document, 0, len(docs))
+ for _, doc := range docs {
+ if c.isHub(doc.ID) {
+ hubDocs = append(hubDocs, doc)
+ }
+ }
+
+ // Store only hub embeddings
+ if len(hubDocs) > 0 {
+ log.Debug().
+ Int("total", len(docs)).
+ Int("hubs", len(hubDocs)).
+ Msg("Storing selective embeddings")
+ return c.base.AddDocuments(ctx, hubDocs)
+ }
+
+ log.Debug().Int("total", len(docs)).Msg("All documents cached, no hubs to store")
+ return nil
+}
+
+// cacheDocuments records each document's content in the content cache, keyed
+// by document ID, overwriting any earlier entry. It always returns nil.
+func (c *Client) cacheDocuments(ctx context.Context, docs []sqlitevec.Document) error {
+	c.cacheMu.Lock()
+	defer c.cacheMu.Unlock()
+
+	for i := range docs {
+		c.contentCache[docs[i].ID] = docs[i].Content
+	}
+
+	return nil
+}
+
+// DeleteDocuments deletes the given IDs from the wrapped client and, if that
+// succeeds, forgets their access statistics and cached content. When the
+// wrapped client fails, its error is returned and local state is left as is.
+func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
+	if err := c.base.DeleteDocuments(ctx, ids); err != nil {
+		return err
+	}
+
+	// Access statistics (guarded by mu).
+	c.mu.Lock()
+	for i := range ids {
+		delete(c.accessCount, ids[i])
+		delete(c.lastAccess, ids[i])
+	}
+	c.mu.Unlock()
+
+	// Cached content (guarded by cacheMu).
+	c.cacheMu.Lock()
+	for i := range ids {
+		delete(c.contentCache, ids[i])
+	}
+	c.cacheMu.Unlock()
+
+	return nil
+}
+
+// Query dispatches a search to the implementation matching the storage
+// strategy: hybrid search for StorageHub, fully dynamic search for
+// StorageOnDemand, and the tracked base query for StorageAlways or any
+// unrecognized strategy.
+func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+	if c.strategy == StorageHub {
+		// Search hubs, then expand with recomputation
+		return c.queryHybrid(ctx, query, limit, where)
+	}
+	if c.strategy == StorageOnDemand {
+		// Fully dynamic search
+		return c.queryDynamic(ctx, query, limit, where)
+	}
+
+	// StorageAlways and unrecognized strategies
+	return c.queryAndTrack(ctx, query, limit, where)
+}
+
+// queryAndTrack runs the query against the wrapped client and, on success,
+// records an access for every returned document so hubs can be detected.
+func (c *Client) queryAndTrack(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+	found, err := c.base.Query(ctx, query, limit, where)
+	if err == nil {
+		c.trackAccess(found)
+		return found, nil
+	}
+
+	return nil, err
+}
+
+// queryHybrid searches stored hubs and recomputes non-hubs.
+//
+// It queries the wrapped client for up to limit*2 stored embeddings,
+// recomputes embeddings for up to limit*2 cached non-hub documents, merges
+// both sets, sorts them by similarity and returns the top limit results.
+// A recomputation failure is logged and the stored results are used alone; an
+// error from the wrapped client is returned as is.
+//
+// Access is tracked for the documents that are actually returned. Tracking
+// only the stored hits would never count a recomputed document, so it could
+// never reach the hub threshold.
+func (c *Client) queryHybrid(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+	startTime := time.Now()
+
+	// 1. Query stored hub embeddings (limit * 2 for expansion)
+	hubResults, err := c.base.Query(ctx, query, limit*2, where)
+	if err != nil {
+		return nil, err
+	}
+
+	// 2. Get candidate non-hub IDs (from content cache)
+	candidates := c.getCandidateNonHubs(where, limit*2)
+
+	// 3. Recompute embeddings for candidates if we have any
+	var recomputedResults []sqlitevec.QueryResult
+	if len(candidates) > 0 {
+		recomputedResults, err = c.recomputeAndScore(ctx, query, candidates)
+		if err != nil {
+			// Log but don't fail - use hub results only
+			log.Warn().Err(err).Msg("Failed to recompute embeddings, using hub results only")
+			recomputedResults = nil
+		}
+	}
+
+	// 4. Merge into a fresh slice so the wrapped client's result slice is never
+	// appended to or reordered in place. A stored document can also show up as
+	// a non-hub candidate once the hub threshold has been raised; the stored
+	// hit wins in that case.
+	allResults := make([]sqlitevec.QueryResult, 0, len(hubResults)+len(recomputedResults))
+	allResults = append(allResults, hubResults...)
+	stored := make(map[string]struct{}, len(hubResults))
+	for _, r := range hubResults {
+		stored[r.ID] = struct{}{}
+	}
+	for _, r := range recomputedResults {
+		if _, dup := stored[r.ID]; !dup {
+			allResults = append(allResults, r)
+		}
+	}
+
+	// 5. Rank and keep the top K
+	sortBySimilarity(allResults)
+	if len(allResults) > limit {
+		allResults = allResults[:limit]
+	}
+
+	// 6. Track access for what the caller receives
+	c.trackAccess(allResults)
+
+	duration := time.Since(startTime)
+	log.Debug().
+		Dur("duration_ms", duration).
+		Int("hubs", len(hubResults)).
+		Int("recomputed", len(recomputedResults)).
+		Int("results", len(allResults)).
+		Msg("Hybrid search completed")
+
+	return allResults, nil
+}
+
+// queryDynamic recomputes all embeddings on-the-fly
+func (c *Client) queryDynamic(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+ startTime := time.Now()
+
+ // Get all candidate IDs from content cache
+ candidates := c.getCandidateNonHubs(where, limit*5)
+
+ // Recompute and score all
+ results, err := c.recomputeAndScore(ctx, query, candidates)
+ if err != nil {
+ return nil, err
+ }
+
+ // Track access
+ c.trackAccess(results)
+
+ // Return top K
+ if len(results) > limit {
+ results = results[:limit]
+ }
+
+ duration := time.Since(startTime)
+ log.Debug().
+ Dur("duration_ms", duration).
+ Int("recomputed", len(candidates)).
+ Int("results", len(results)).
+ Msg("Dynamic search completed")
+
+ return results, nil
+}
+
+// recomputeAndScore generates embeddings and computes similarities.
+//
+// The query and the cached content of each candidate are embedded, and one
+// result per candidate is returned carrying its cosine similarity to the query
+// and the distance 1 - similarity. Candidates with no cached content, or with
+// empty content, are skipped. Results are in candidate order, not ranked, and
+// carry empty metadata.
+//
+// It returns (nil, nil) when there is nothing to score, ctx.Err() when ctx is
+// already done, and a wrapped error when embedding fails or when the embedding
+// service returns a different number of vectors than texts.
+func (c *Client) recomputeAndScore(ctx context.Context, query string, candidateIDs []string) ([]sqlitevec.QueryResult, error) {
+	if len(candidateIDs) == 0 {
+		return nil, nil
+	}
+
+	// Embedding is the expensive part; skip it if the caller already gave up.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	// Generate query embedding
+	queryEmb, err := c.embedSvc.Embed(query)
+	if err != nil {
+		return nil, fmt.Errorf("embed query: %w", err)
+	}
+
+	// Get content for candidates
+	c.cacheMu.RLock()
+	texts := make([]string, 0, len(candidateIDs))
+	validIDs := make([]string, 0, len(candidateIDs))
+	for _, id := range candidateIDs {
+		if content, ok := c.contentCache[id]; ok && content != "" {
+			texts = append(texts, content)
+			validIDs = append(validIDs, id)
+		}
+	}
+	c.cacheMu.RUnlock()
+
+	if len(texts) == 0 {
+		return nil, nil
+	}
+
+	// Batch generate embeddings
+	embeddings, err := c.embedSvc.EmbedBatch(texts)
+	if err != nil {
+		return nil, fmt.Errorf("batch embed: %w", err)
+	}
+	// validIDs is indexed by embedding position below, so the counts must agree.
+	if len(embeddings) != len(validIDs) {
+		return nil, fmt.Errorf("batch embed: got %d embeddings for %d texts", len(embeddings), len(validIDs))
+	}
+
+	// Compute similarities
+	results := make([]sqlitevec.QueryResult, len(embeddings))
+	for i, emb := range embeddings {
+		similarity := cosineSimilarity(queryEmb, emb)
+		distance := 1.0 - similarity // Convert to distance
+
+		results[i] = sqlitevec.QueryResult{
+			ID:         validIDs[i],
+			Distance:   float64(distance),
+			Similarity: float64(similarity),
+			Metadata:   make(map[string]any),
+		}
+	}
+
+	return results, nil
+}
+
+// trackAccess bumps the access counter of every document in results and
+// stamps each with one shared timestamp. An empty slice is ignored without
+// taking the lock.
+func (c *Client) trackAccess(results []sqlitevec.QueryResult) {
+	if len(results) == 0 {
+		return
+	}
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	stamp := time.Now()
+	for i := range results {
+		id := results[i].ID
+		c.accessCount[id]++
+		c.lastAccess[id] = stamp
+	}
+}
+
+// isHub reports whether docID has been accessed at least hubThreshold times.
+// Documents that were never tracked have a count of zero.
+func (c *Client) isHub(docID string) bool {
+	c.mu.RLock()
+	qualifies := c.accessCount[docID] >= c.hubThreshold
+	c.mu.RUnlock()
+
+	return qualifies
+}
+
+// getCandidateNonHubs returns up to limit IDs of cached documents whose access
+// count is below the hub threshold. A limit of zero or less yields nil.
+//
+// IDs are taken in map iteration order, so which documents are picked when the
+// cache holds more than limit non-hubs is unspecified. where is currently
+// ignored: the content cache maps IDs to content only and holds no metadata to
+// filter on.
+//
+// cacheMu is acquired before mu; any other code that needs both locks must use
+// the same order.
+func (c *Client) getCandidateNonHubs(where map[string]any, limit int) []string {
+	if limit <= 0 {
+		return nil
+	}
+
+	c.cacheMu.RLock()
+	defer c.cacheMu.RUnlock()
+	// Take mu once for the whole scan instead of once per ID.
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	// Never pre-allocate more than the cache can supply.
+	capacity := limit
+	if cached := len(c.contentCache); cached < capacity {
+		capacity = cached
+	}
+
+	candidates := make([]string, 0, capacity)
+	for id := range c.contentCache {
+		if c.accessCount[id] >= c.hubThreshold {
+			continue // hub: served from stored embeddings
+		}
+		candidates = append(candidates, id)
+		if len(candidates) >= limit {
+			break
+		}
+	}
+
+	return candidates
+}
+
+// IsConnected reports the connection state of the wrapped base client.
+func (c *Client) IsConnected() bool {
+	return c.base.IsConnected()
+}
+
+// Close releases resources by closing the wrapped base client. The in-memory
+// access statistics and content cache are not cleared.
+func (c *Client) Close() error {
+	return c.base.Close()
+}
+
+// Count returns the number of vectors reported by the wrapped base client.
+// Documents that are only held in the content cache are not included.
+func (c *Client) Count(ctx context.Context) (int64, error) {
+	return c.base.Count(ctx)
+}
+
+// ModelVersion returns the current embedding model version of the wrapped
+// base client.
+func (c *Client) ModelVersion() string {
+	return c.base.ModelVersion()
+}
+
+// NeedsRebuild checks if vectors need to be rebuilt due to model version change.
+// The call is delegated to the wrapped base client unchanged.
+func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
+	return c.base.NeedsRebuild(ctx)
+}
+
+// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
+// The call is delegated to the wrapped base client unchanged.
+func (c *Client) GetStaleVectors(ctx context.Context) ([]sqlitevec.StaleVectorInfo, error) {
+	return c.base.GetStaleVectors(ctx)
+}
+
+// DeleteVectorsByDocIDs removes vectors by their doc_ids from the wrapped base
+// client. Unlike DeleteDocuments it leaves the access statistics and the
+// content cache untouched.
+func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
+	return c.base.DeleteVectorsByDocIDs(ctx, docIDs)
+}
+
+// GetStorageStats returns storage efficiency metrics.
+//
+// TotalDocuments is the size of the content cache and HubDocuments the number
+// of cached documents at or above the hub threshold. StoredEmbeddings is the
+// wrapped client's Count under StorageAlways (falling back to the hub count if
+// Count fails), zero under StorageOnDemand, and otherwise the hub count, which
+// is an estimate rather than a measurement of what the wrapped client holds.
+// Sizes assume 384-dimensional float32 embeddings. The returned error is
+// always nil.
+//
+// The in-memory state is snapshotted under cacheMu then mu, the same order
+// getCandidateNonHubs uses, and both locks are released before the wrapped
+// client is asked for its count so no database call runs under them.
+func (c *Client) GetStorageStats(ctx context.Context) (StorageStats, error) {
+	c.cacheMu.RLock()
+	c.mu.RLock()
+	totalDocs := len(c.contentCache)
+	hubCount := 0
+	for id := range c.contentCache {
+		if c.accessCount[id] >= c.hubThreshold {
+			hubCount++
+		}
+	}
+	c.mu.RUnlock()
+	c.cacheMu.RUnlock()
+
+	storedCount := hubCount
+	switch c.strategy {
+	case StorageAlways:
+		// Get actual count from database
+		if count, err := c.base.Count(ctx); err == nil {
+			storedCount = int(count)
+		}
+	case StorageOnDemand:
+		storedCount = 0
+	}
+
+	embeddingSize := 384 * 4 // 384 dims × 4 bytes (float32)
+	storedBytes := storedCount * embeddingSize
+	potentialBytes := totalDocs * embeddingSize
+
+	savingsPercent := 0.0
+	if potentialBytes > 0 {
+		savingsPercent = (1.0 - float64(storedBytes)/float64(potentialBytes)) * 100
+	}
+
+	return StorageStats{
+		TotalDocuments:   totalDocs,
+		HubDocuments:     hubCount,
+		StoredEmbeddings: storedCount,
+		StorageBytes:     storedBytes,
+		SavingsPercent:   savingsPercent,
+		Strategy:         c.strategy,
+	}, nil
+}
+
+// StorageStats contains storage efficiency metrics as reported by
+// Client.GetStorageStats.
+type StorageStats struct {
+	TotalDocuments   int                   // number of documents in the content cache
+	HubDocuments     int                   // cached documents at or above the hub threshold
+	StoredEmbeddings int                   // embeddings counted as stored for the active strategy
+	StorageBytes     int                   // StoredEmbeddings × 384 dims × 4 bytes
+	SavingsPercent   float64               // share of bytes saved versus storing every cached document; 0 when the cache is empty
+	Strategy         VectorStorageStrategy // strategy the client was running with
+}
+
+// Helper functions
+
+// cosineSimilarity returns the cosine of the angle between a and b, in
+// [-1, 1]. It returns 0 when the vectors differ in length (they are not
+// comparable, and indexing the shorter one would panic) or when either vector
+// has zero magnitude, which includes empty input.
+func cosineSimilarity(a, b []float32) float32 {
+	if len(a) != len(b) {
+		return 0
+	}
+
+	// Accumulate in float64 to limit rounding error over long vectors.
+	var dotProduct, normA, normB float64
+	for i := range a {
+		x, y := float64(a[i]), float64(b[i])
+		dotProduct += x * y
+		normA += x * x
+		normB += y * y
+	}
+	if normA == 0 || normB == 0 {
+		return 0
+	}
+	return float32(dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)))
+}
+
+// sortBySimilarity reorders results in place so the highest Similarity comes
+// first. Entries with equal similarity keep their relative order. It is a
+// bubble sort, which is adequate for the few dozen results it is given.
+func sortBySimilarity(results []sqlitevec.QueryResult) {
+	for last := len(results) - 1; last > 0; last-- {
+		swapped := false
+		for k := 0; k < last; k++ {
+			if results[k].Similarity < results[k+1].Similarity {
+				results[k], results[k+1] = results[k+1], results[k]
+				swapped = true
+			}
+		}
+		// A pass without swaps means every later pass would be a no-op too.
+		if !swapped {
+			return
+		}
+	}
+}
+
+// strategyToString returns the configuration name of a strategy ("always",
+// "hub" or "on_demand"), or "unknown" for any other value.
+func strategyToString(s VectorStorageStrategy) string {
+	names := map[VectorStorageStrategy]string{
+		StorageAlways:   "always",
+		StorageHub:      "hub",
+		StorageOnDemand: "on_demand",
+	}
+
+	if name, known := names[s]; known {
+		return name
+	}
+	return "unknown"
+}
+
+// ParseStrategy maps a configuration name to its VectorStorageStrategy.
+// Matching is exact; "hub" and every unrecognized value, including the empty
+// string, yield StorageHub.
+func ParseStrategy(s string) VectorStorageStrategy {
+	if s == "always" {
+		return StorageAlways
+	}
+	if s == "on_demand" {
+		return StorageOnDemand
+	}
+
+	// "hub" and anything unrecognized: default to hub strategy
+	return StorageHub
+}
diff --git a/internal/vector/hybrid/client_test.go b/internal/vector/hybrid/client_test.go
new file mode 100644
index 0000000..b567ea5
--- /dev/null
+++ b/internal/vector/hybrid/client_test.go
@@ -0,0 +1,187 @@
+package hybrid
+
+import (
+ "testing"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
+ _ "github.com/mattn/go-sqlite3" // Import SQLite driver for CGO linking
+ "github.com/stretchr/testify/assert"
+)
+
+// TestParseStrategy checks that each known strategy name parses to its
+// constant and that unknown or empty input falls back to StorageHub.
+func TestParseStrategy(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected VectorStorageStrategy
+	}{
+		{"hub_strategy", "hub", StorageHub},
+		{"on_demand_strategy", "on_demand", StorageOnDemand},
+		{"always_strategy", "always", StorageAlways},
+		{"invalid_defaults_to_hub", "invalid", StorageHub},
+		{"empty_defaults_to_hub", "", StorageHub},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := ParseStrategy(tt.input)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestStrategyToString checks the name returned for each strategy constant and
+// that an out-of-range value is reported as "unknown".
+func TestStrategyToString(t *testing.T) {
+	tests := []struct {
+		name     string
+		expected string
+		input    VectorStorageStrategy
+	}{
+		{"hub_to_string", "hub", StorageHub},
+		{"on_demand_to_string", "on_demand", StorageOnDemand},
+		{"always_to_string", "always", StorageAlways},
+		{"invalid_to_unknown", "unknown", VectorStorageStrategy(99)},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := strategyToString(tt.input)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestCosineSimilarity checks cosineSimilarity on identical, orthogonal,
+// opposite, zero and parallel vectors, allowing a tolerance of 0.001.
+func TestCosineSimilarity(t *testing.T) {
+	tests := []struct {
+		name     string
+		a        []float32
+		b        []float32
+		expected float32
+	}{
+		{
+			name:     "identical_vectors",
+			a:        []float32{1, 0, 0},
+			b:        []float32{1, 0, 0},
+			expected: 1.0,
+		},
+		{
+			name:     "orthogonal_vectors",
+			a:        []float32{1, 0, 0},
+			b:        []float32{0, 1, 0},
+			expected: 0.0,
+		},
+		{
+			name:     "opposite_vectors",
+			a:        []float32{1, 0, 0},
+			b:        []float32{-1, 0, 0},
+			expected: -1.0,
+		},
+		{
+			name:     "zero_vector",
+			a:        []float32{0, 0, 0},
+			b:        []float32{1, 1, 1},
+			expected: 0.0,
+		},
+		{
+			name:     "parallel_vectors",
+			a:        []float32{2, 0, 0},
+			b:        []float32{4, 0, 0},
+			expected: 1.0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := cosineSimilarity(tt.a, tt.b)
+			assert.InDelta(t, tt.expected, result, 0.001)
+		})
+	}
+}
+
+// TestSortBySimilarity checks that sortBySimilarity orders results by
+// descending similarity, keeps the input order for equal similarities, and
+// handles empty and single-element slices.
+func TestSortBySimilarity(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    []sqlitevec.QueryResult
+		expected []string // Expected order of IDs
+	}{
+		{
+			name: "already_sorted",
+			input: []sqlitevec.QueryResult{
+				{ID: "doc1", Similarity: 0.9},
+				{ID: "doc2", Similarity: 0.7},
+				{ID: "doc3", Similarity: 0.5},
+			},
+			expected: []string{"doc1", "doc2", "doc3"},
+		},
+		{
+			name: "reverse_sorted",
+			input: []sqlitevec.QueryResult{
+				{ID: "doc1", Similarity: 0.3},
+				{ID: "doc2", Similarity: 0.7},
+				{ID: "doc3", Similarity: 0.9},
+			},
+			expected: []string{"doc3", "doc2", "doc1"},
+		},
+		{
+			name: "random_order",
+			input: []sqlitevec.QueryResult{
+				{ID: "doc1", Similarity: 0.5},
+				{ID: "doc2", Similarity: 0.9},
+				{ID: "doc3", Similarity: 0.3},
+				{ID: "doc4", Similarity: 0.7},
+			},
+			expected: []string{"doc2", "doc4", "doc1", "doc3"},
+		},
+		{
+			name: "identical_similarities",
+			input: []sqlitevec.QueryResult{
+				{ID: "doc1", Similarity: 0.5},
+				{ID: "doc2", Similarity: 0.5},
+				{ID: "doc3", Similarity: 0.5},
+			},
+			expected: []string{"doc1", "doc2", "doc3"},
+		},
+		{
+			name:     "empty_list",
+			input:    []sqlitevec.QueryResult{},
+			expected: []string{},
+		},
+		{
+			name: "single_element",
+			input: []sqlitevec.QueryResult{
+				{ID: "doc1", Similarity: 0.5},
+			},
+			expected: []string{"doc1"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Sorts tt.input in place; the IDs are then read back in order.
+			sortBySimilarity(tt.input)
+
+			actual := make([]string, len(tt.input))
+			for i, r := range tt.input {
+				actual[i] = r.ID
+			}
+
+			assert.Equal(t, tt.expected, actual)
+		})
+	}
+}
+
+// TestSortBySimilarity_PreserveOtherFields checks that sorting moves whole
+// results: Distance and Metadata must travel with their ID and Similarity.
+func TestSortBySimilarity_PreserveOtherFields(t *testing.T) {
+	input := []sqlitevec.QueryResult{
+		{ID: "doc1", Similarity: 0.3, Distance: 0.7, Metadata: map[string]any{"key": "val1"}},
+		{ID: "doc2", Similarity: 0.9, Distance: 0.1, Metadata: map[string]any{"key": "val2"}},
+	}
+
+	sortBySimilarity(input)
+
+	assert.Equal(t, "doc2", input[0].ID)
+	assert.InDelta(t, 0.9, input[0].Similarity, 0.001)
+	assert.InDelta(t, 0.1, input[0].Distance, 0.001)
+	assert.Equal(t, "val2", input[0].Metadata["key"])
+
+	assert.Equal(t, "doc1", input[1].ID)
+	assert.InDelta(t, 0.3, input[1].Similarity, 0.001)
+	assert.InDelta(t, 0.7, input[1].Distance, 0.001)
+	assert.Equal(t, "val1", input[1].Metadata["key"])
+}
diff --git a/internal/vector/hybrid/config.go b/internal/vector/hybrid/config.go
new file mode 100644
index 0000000..4cac342
--- /dev/null
+++ b/internal/vector/hybrid/config.go
@@ -0,0 +1,62 @@
+package hybrid
+
+import (
+ "os"
+ "strconv"
+
+ "github.com/rs/zerolog/log"
+)
+
+// GetStrategyFromEnv resolves the storage strategy from the
+// CLAUDE_MNEMONIC_VECTOR_STRATEGY environment variable. An unset or empty
+// variable yields StorageHub silently; any other value is parsed with
+// ParseStrategy and the outcome is logged.
+func GetStrategyFromEnv() VectorStorageStrategy {
+	if raw := os.Getenv("CLAUDE_MNEMONIC_VECTOR_STRATEGY"); raw != "" {
+		parsed := ParseStrategy(raw)
+		log.Info().
+			Str("env_value", raw).
+			Str("strategy", strategyToString(parsed)).
+			Msg("Vector storage strategy from environment")
+
+		return parsed
+	}
+
+	// Default to hub strategy for optimal balance
+	return StorageHub
+}
+
+// GetHubThresholdFromEnv reads CLAUDE_MNEMONIC_HUB_THRESHOLD from environment
+func GetHubThresholdFromEnv() int {
+ thresholdStr := os.Getenv("CLAUDE_MNEMONIC_HUB_THRESHOLD")
+ if thresholdStr == "" {
+ return 5 // Default threshold
+ }
+
+ threshold, err := strconv.Atoi(thresholdStr)
+ if err != nil {
+ log.Warn().
+ Err(err).
+ Str("env_value", thresholdStr).
+ Msg("Invalid hub threshold in environment, using default")
+ return 5
+ }
+
+ if threshold < 1 {
+ log.Warn().
+ Int("env_value", threshold).
+ Msg("Hub threshold too low, using minimum of 1")
+ return 1
+ }
+
+ log.Info().
+ Int("threshold", threshold).
+ Msg("Hub threshold from environment")
+
+ return threshold
+}
+
+// IsHybridEnabled reports whether the environment selects a hybrid strategy.
+// It is false only when CLAUDE_MNEMONIC_VECTOR_STRATEGY=always (backwards
+// compat); an unset variable counts as hybrid because the default is hub.
+func IsHybridEnabled() bool {
+	return GetStrategyFromEnv() != StorageAlways
+}
diff --git a/internal/vector/hybrid/graph_search.go b/internal/vector/hybrid/graph_search.go
new file mode 100644
index 0000000..110cfa3
--- /dev/null
+++ b/internal/vector/hybrid/graph_search.go
@@ -0,0 +1,308 @@
+package hybrid
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "time"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/graph"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
+ "github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+ "github.com/rs/zerolog/log"
+)
+
+// GraphConfig holds the tuning knobs for graph-aware search (see DefaultGraphConfig for the defaults)
+type GraphConfig struct {
+	Enabled      bool    // When false, GraphSearchClient.Query delegates to the wrapped Client
+	MaxHops      int     // Maximum graph traversal depth (default: 2)
+	BranchFactor int     // Number of neighbors to expand per node (default: 5)
+	EdgeWeight   float64 // Minimum edge weight to follow (default: 0.3)
+}
+
+// DefaultGraphConfig returns sensible defaults for graph search
+func DefaultGraphConfig() GraphConfig {
+ return GraphConfig{
+ Enabled: true,
+ MaxHops: 2,
+ BranchFactor: 5,
+ EdgeWeight: 0.3,
+ }
+}
+
+// GraphSearchClient wraps hybrid.Client with graph-aware search
+type GraphSearchClient struct {
+	*Client                             // embedded hybrid client; its methods are promoted
+	graph       *graph.ObservationGraph // observation graph; nil makes Query delegate to Client
+	graphConfig GraphConfig             // traversal settings supplied at construction
+}
+
+// NewGraphSearchClient layers graph-aware search over baseClient. A nil
+// observationGraph is accepted; Query then delegates to the wrapped client.
+func NewGraphSearchClient(baseClient *Client, observationGraph *graph.ObservationGraph, cfg GraphConfig) *GraphSearchClient {
+	client := new(GraphSearchClient)
+	client.Client = baseClient
+	client.graph, client.graphConfig = observationGraph, cfg
+	return client
+}
+
+// Query performs graph-aware vector search: it fetches limit*2 seed results
+// from the base store, expands them through the observation graph, merges in
+// what recomputeAndScore returns for the non-hub candidates, and returns at most
+// limit results. It delegates to Client.Query when the graph is disabled, nil, or the seed query fails.
+func (g *GraphSearchClient) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) {
+	if !g.graphConfig.Enabled || g.graph == nil {
+		// Fall back to standard hybrid search
+		return g.Client.Query(ctx, query, limit, where)
+	}
+
+	startTime := time.Now()
+
+	// 1. Generate query embedding (a failure here is returned to the caller, not fallen back from)
+	queryEmb, err := g.embedSvc.Embed(query)
+	if err != nil {
+		return nil, fmt.Errorf("embed query: %w", err)
+	}
+
+	// 2. Search hub nodes (stored embeddings)
+	hubResults, err := g.base.Query(ctx, query, limit*2, where)
+	if err != nil {
+		// Fall back to standard search on error
+		log.Warn().Err(err).Msg("Hub search failed, falling back to hybrid search")
+		return g.Client.Query(ctx, query, limit, where)
+	}
+
+	// 3. Track hub access
+	g.trackAccess(hubResults)
+	// 4. Expand via graph traversal (at most limit*4 candidate IDs)
+	expandedIDs := g.expandFromHubs(hubResults, limit*4)
+
+	// 5. Filter to non-hubs that need recomputation
+	nonHubIDs := make([]string, 0)
+	for _, id := range expandedIDs {
+		if !g.isHub(id) {
+			nonHubIDs = append(nonHubIDs, id)
+		}
+	}
+
+	// 6. Batch recompute non-hub embeddings; on failure continue with hub results only
+	recomputedResults, err := g.recomputeAndScore(ctx, query, nonHubIDs)
+	if err != nil {
+		log.Warn().Err(err).Msg("Recomputation failed, using hub results only")
+		recomputedResults = nil
+	}
+
+	// 7. Apply graph-based ranking boost (queryEmb is passed along, but mergeAndRankWithGraph does not read it)
+	allResults := g.mergeAndRankWithGraph(hubResults, recomputedResults, queryEmb)
+	// 8. Return top K
+	if len(allResults) > limit {
+		allResults = allResults[:limit]
+	}
+
+	duration := time.Since(startTime)
+	log.Debug().
+		Dur("duration_ms", duration).
+		Int("hubs", len(hubResults)).
+		Int("expanded", len(expandedIDs)).
+		Int("recomputed", len(recomputedResults)).
+		Int("results", len(allResults)).
+		Msg("Graph search completed")
+	return allResults, nil
+}
+
+// expandFromHubs traverses graph from hub nodes to find promising candidates
+func (g *GraphSearchClient) expandFromHubs(hubResults []sqlitevec.QueryResult, maxCandidates int) []string {
+ if g.graph == nil {
+ return nil
+ }
+
+ expanded := make(map[string]float64) // doc_id -> relevance score
+ visited := make(map[int64]bool)
+
+ // Start from top hub results
+ for i, result := range hubResults {
+ if i >= g.graphConfig.BranchFactor*2 {
+ break // Limit starting points
+ }
+
+ // Parse observation ID from doc_id
+ obsID := parseObservationID(result.ID)
+ if obsID == 0 {
+ continue
+ }
+
+ // Mark as visited with high relevance (direct match)
+ visited[obsID] = true
+ expanded[result.ID] = result.Similarity
+
+ // Traverse graph from this hub
+ g.traverseGraph(obsID, result.Similarity, 0, expanded, visited)
+ }
+
+ // Convert to sorted list
+ type candidate struct {
+ ID string
+ Relevance float64
+ }
+
+ candidates := make([]candidate, 0, len(expanded))
+ for id, rel := range expanded {
+ candidates = append(candidates, candidate{ID: id, Relevance: rel})
+ }
+
+ // Sort by relevance descending
+ sort.Slice(candidates, func(i, j int) bool {
+ return candidates[i].Relevance > candidates[j].Relevance
+ })
+
+ // Return top candidates
+ if len(candidates) > maxCandidates {
+ candidates = candidates[:maxCandidates]
+ }
+
+ result := make([]string, len(candidates))
+ for i, c := range candidates {
+ result[i] = c.ID
+ }
+
+ return result
+}
+
+// traverseGraph walks the graph depth-first from nodeID and records a
+// relevance score in expanded for every neighbor it reaches; expanded and
+// visited are mutated in place. Per node, neighbors are ranked by edge weight
+// and at most BranchFactor are followed, skipping edges lighter than
+// EdgeWeight and nodes already visited. A neighbor scores
+// baseRelevance * weight * 0.7, so relevance shrinks with every hop. The walk
+// stops when depth reaches MaxHops or GetNeighbors returns an error.
+func (g *GraphSearchClient) traverseGraph(nodeID int64, baseRelevance float64, depth int, expanded map[string]float64, visited map[int64]bool) {
+	const hopDecay = 0.7 // 30% decay per hop
+
+	if depth >= g.graphConfig.MaxHops {
+		return // Max depth reached
+	}
+
+	neighbors, weights, err := g.graph.GetNeighbors(nodeID)
+	if err != nil {
+		return // No neighbors or error
+	}
+
+	// Pair each neighbor with its edge weight. The two slices are expected to
+	// be parallel; pairing only the common prefix avoids an index panic if
+	// they ever disagree in length.
+	pairCount := len(neighbors)
+	if len(weights) < pairCount {
+		pairCount = len(weights)
+	}
+
+	type weightedNeighbor struct {
+		id     int64
+		weight float32
+	}
+	ranked := make([]weightedNeighbor, pairCount)
+	for i := range ranked {
+		ranked[i] = weightedNeighbor{id: neighbors[i], weight: weights[i]}
+	}
+
+	// Heaviest edges first.
+	sort.Slice(ranked, func(i, j int) bool {
+		return ranked[i].weight > ranked[j].weight
+	})
+
+	followed := 0
+	for _, nb := range ranked {
+		if followed >= g.graphConfig.BranchFactor {
+			break
+		}
+		if float64(nb.weight) < g.graphConfig.EdgeWeight {
+			continue // edge too weak to follow
+		}
+		if visited[nb.id] {
+			continue
+		}
+		visited[nb.id] = true
+		propagated := baseRelevance * float64(nb.weight) * hopDecay
+		// Keep the best score seen for this doc_id.
+		docID := formatObservationDocID(nb.id)
+		if best, seen := expanded[docID]; !seen || propagated > best {
+			expanded[docID] = propagated
+		}
+
+		g.traverseGraph(nb.id, propagated, depth+1, expanded, visited)
+		followed++
+	}
+}
+
+// mergeAndRankWithGraph merges hub and recomputed results into a new slice,
+// boosts results whose graph node has a positive degree (up to +10%, reached
+// at degree 20), and orders them with sortBySimilarity. queryEmb is not read.
+func (g *GraphSearchClient) mergeAndRankWithGraph(hubResults, recomputedResults []sqlitevec.QueryResult, queryEmb []float32) []sqlitevec.QueryResult {
+	// append(hubResults, ...) could write into the caller's spare capacity and
+	// let the boost and sort below mutate hubResults, so copy into fresh storage.
+	merged := make([]sqlitevec.QueryResult, 0, len(hubResults)+len(recomputedResults))
+	merged = append(merged, hubResults...)
+	merged = append(merged, recomputedResults...)
+
+	if g.graph != nil {
+		for i := range merged {
+			obsID := parseObservationID(merged[i].ID)
+			if obsID == 0 {
+				continue // doc_id does not name an observation
+			}
+			node, err := g.graph.GetNode(obsID)
+			if err != nil || node.Degree <= 0 {
+				continue // unknown or isolated node: no boost
+			}
+			boost := 1.0 + (0.1 * float64(node.Degree) / 20.0)
+			if boost > 1.1 {
+				boost = 1.1 // cap the boost at 10%
+			}
+			merged[i].Similarity *= boost
+		}
+	}
+
+	sortBySimilarity(merged)
+	return merged
+}
+
+// parseObservationID returns the {id} of an "obs-{id}-{field}" doc_id, or 0 if docID has another shape.
+func parseObservationID(docID string) int64 {
+	var obsID int64
+	if _, err := fmt.Sscanf(docID, "obs-%d-", &obsID); err != nil {
+		return 0 // a partial scan (e.g. "obs-12") may already have written obsID; discard it
+	}
+	return obsID
+}
+
+// formatObservationDocID builds the "obs-{id}-combined" doc_id for obsID.
+func formatObservationDocID(obsID int64) string {
+	return "obs-" + fmt.Sprint(obsID) + "-combined"
+}
+
+// GetGraphStats returns the observation graph's statistics, or a zero GraphStats when no graph is set.
+func (g *GraphSearchClient) GetGraphStats() graph.GraphStats {
+	if current := g.graph; current != nil {
+		return current.Stats()
+	}
+	return graph.GraphStats{}
+}
+
+// RebuildGraph builds a new graph from observations and swaps it in; on a
+// build error the current graph is kept and the wrapped error is returned.
+// Call it periodically or when observations change significantly.
+// NOTE(review): g.graph is replaced without synchronization here — confirm
+// callers do not run this concurrently with Query.
+func (g *GraphSearchClient) RebuildGraph(ctx context.Context, observations []*models.Observation) error {
+	log.Info().Int("observations", len(observations)).Msg("Rebuilding observation graph")
+
+	newGraph, err := graph.BuildFromObservations(ctx, observations)
+	if err != nil {
+		return fmt.Errorf("build graph: %w", err)
+	}
+	g.graph = newGraph
+
+	// Compute the stats once instead of once per logged field.
+	stats := newGraph.Stats()
+	log.Info().Int("nodes", stats.NodeCount).Int("edges", stats.EdgeCount).Msg("Graph rebuilt successfully")
+	return nil
+}
diff --git a/internal/vector/hybrid/interface_test.go b/internal/vector/hybrid/interface_test.go
new file mode 100644
index 0000000..dc890d6
--- /dev/null
+++ b/internal/vector/hybrid/interface_test.go
@@ -0,0 +1,17 @@
+package hybrid
+
+import (
+ "testing"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector"
+ _ "github.com/mattn/go-sqlite3" // Import SQLite driver for CGO linking
+)
+
+// TestInterfaceImplementation pins Client and GraphSearchClient to vector.Client; the checks are resolved at compile time.
+func TestInterfaceImplementation(t *testing.T) {
+	// A missing or mismatched method on either type fails the build right here.
+	var (
+		_ vector.Client = (*Client)(nil)
+		_ vector.Client = (*GraphSearchClient)(nil)
+	)
+}
diff --git a/internal/vector/hybrid/metrics.go b/internal/vector/hybrid/metrics.go
new file mode 100644
index 0000000..2e6ca3c
--- /dev/null
+++ b/internal/vector/hybrid/metrics.go
@@ -0,0 +1,272 @@
+package hybrid
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Metrics tracks performance and usage statistics for hybrid vector storage
+type Metrics struct {
+	startTime         time.Time       // set by NewMetrics; GetSnapshot derives Uptime from it
+	recentLatencies   []time.Duration // most recent query latencies (RecordQuery keeps at most 1000); guarded by latenciesMu
+	latenciesMu       sync.Mutex      // protects recentLatencies
+	totalQueries      atomic.Int64    // every RecordQuery call
+	hubOnlyQueries    atomic.Int64    // RecordQuery calls with queryType "hub_only"
+	hybridQueries     atomic.Int64    // RecordQuery calls with queryType "hybrid"
+	onDemandQueries   atomic.Int64    // RecordQuery calls with queryType "on_demand"
+	graphQueries      atomic.Int64    // RecordQuery calls with queryType "graph"
+	totalLatency      atomic.Int64    // Sum in microseconds
+	hubLatency        atomic.Int64    // sum of RecordHubLatency durations, in microseconds
+	recomputeLatency  atomic.Int64    // sum of RecordRecomputeLatency durations, in microseconds
+	totalDocuments    atomic.Int64    // last value passed to UpdateStorageStats
+	hubDocuments      atomic.Int64    // last value passed to UpdateStorageStats
+	storedEmbeddings  atomic.Int64    // last value passed to UpdateStorageStats
+	recomputedCount   atomic.Int64    // running total of the recomputed argument of RecordQuery
+	cacheHits         atomic.Int64    // RecordCacheHit calls
+	cacheMisses       atomic.Int64    // RecordCacheMiss calls
+	graphTraversals   atomic.Int64    // RecordGraphTraversal calls
+	avgTraversalDepth atomic.Int64    // despite the name, the SUM of traversal depths; GetSnapshot divides by graphTraversals
+}
+
+// NewMetrics returns an empty tracker with its start time set to now and room for 1000 latency samples.
+func NewMetrics() *Metrics {
+	m := new(Metrics)
+	m.recentLatencies = make([]time.Duration, 0, 1000)
+	m.startTime = time.Now()
+	return m
+}
+
+// RecordQuery accounts for one finished query: it bumps the total and the
+// matching per-type counter (an unrecognized queryType counts only toward the
+// total), accumulates latency and recomputed, and keeps the last 1000 latencies.
+func (m *Metrics) RecordQuery(queryType string, latency time.Duration, recomputed int) {
+	m.totalQueries.Add(1)
+	m.totalLatency.Add(latency.Microseconds())
+	if recomputed > 0 {
+		m.recomputedCount.Add(int64(recomputed))
+	}
+
+	switch queryType {
+	case "hub_only":
+		m.hubOnlyQueries.Add(1)
+	case "hybrid":
+		m.hybridQueries.Add(1)
+	case "on_demand":
+		m.onDemandQueries.Add(1)
+	case "graph":
+		m.graphQueries.Add(1)
+	}
+
+	m.latenciesMu.Lock()
+	defer m.latenciesMu.Unlock()
+	m.recentLatencies = append(m.recentLatencies, latency)
+	if overflow := len(m.recentLatencies) - 1000; overflow > 0 {
+		m.recentLatencies = m.recentLatencies[overflow:]
+	}
+}
+
+// RecordHubLatency adds latency, truncated to whole microseconds, to the hub-search time total.
+func (m *Metrics) RecordHubLatency(latency time.Duration) {
+	m.hubLatency.Add(int64(latency / time.Microsecond))
+}
+
+// RecordRecomputeLatency adds latency, truncated to whole microseconds, to the recompute time total.
+func (m *Metrics) RecordRecomputeLatency(latency time.Duration) {
+	m.recomputeLatency.Add(int64(latency / time.Microsecond))
+}
+
+// RecordCacheHit increments the cache-hit counter by one; GetSnapshot uses it for CacheHits and CacheHitRate.
+func (m *Metrics) RecordCacheHit() {
+	m.cacheHits.Add(1)
+}
+
+// RecordCacheMiss increments the cache-miss counter by one; GetSnapshot uses it for CacheMisses and CacheHitRate.
+func (m *Metrics) RecordCacheMiss() {
+	m.cacheMisses.Add(1)
+}
+
+// RecordGraphTraversal counts one graph traversal and adds depth to the running depth total.
+func (m *Metrics) RecordGraphTraversal(depth int) {
+	m.avgTraversalDepth.Add(int64(depth)) // holds the sum; GetSnapshot divides by the traversal count
+	m.graphTraversals.Add(1)
+}
+
+// UpdateStorageStats overwrites the stored document, hub, and embedding counts with the given values.
+func (m *Metrics) UpdateStorageStats(total, hubs, stored int) {
+	m.storedEmbeddings.Store(int64(stored))
+	m.hubDocuments.Store(int64(hubs))
+	m.totalDocuments.Store(int64(total))
+}
+
+// GetSnapshot returns a point-in-time copy of the metrics with the derived
+// values filled in: average latencies, latency percentiles over the recent
+// window, cache hit rate, storage savings, and average traversal depth.
+// Counters are read one at a time, so values recorded while the snapshot is
+// being taken may be reflected in some fields and not in others.
+func (m *Metrics) GetSnapshot() MetricsSnapshot {
+	// Copy the latency window and release the mutex right away so that sorting
+	// it below does not block concurrent RecordQuery calls.
+	m.latenciesMu.Lock()
+	window := make([]time.Duration, len(m.recentLatencies))
+	copy(window, m.recentLatencies)
+	m.latenciesMu.Unlock()
+
+	totalQueries := m.totalQueries.Load()
+
+	snapshot := MetricsSnapshot{
+		// Query counts
+		TotalQueries:    totalQueries,
+		HubOnlyQueries:  m.hubOnlyQueries.Load(),
+		HybridQueries:   m.hybridQueries.Load(),
+		OnDemandQueries: m.onDemandQueries.Load(),
+		GraphQueries:    m.graphQueries.Load(),
+
+		// Storage
+		TotalDocuments:   int(m.totalDocuments.Load()),
+		HubDocuments:     int(m.hubDocuments.Load()),
+		StoredEmbeddings: int(m.storedEmbeddings.Load()),
+		RecomputedTotal:  m.recomputedCount.Load(),
+
+		// Cache
+		CacheHits:   m.cacheHits.Load(),
+		CacheMisses: m.cacheMisses.Load(),
+
+		// Graph
+		GraphTraversals: m.graphTraversals.Load(),
+
+		// Runtime
+		Uptime: time.Since(m.startTime),
+	}
+
+	// Latency sums are kept in microseconds.
+	if totalQueries > 0 {
+		snapshot.AvgLatency = time.Duration(m.totalLatency.Load()/totalQueries) * time.Microsecond
+		snapshot.AvgHubLatency = time.Duration(m.hubLatency.Load()/totalQueries) * time.Microsecond
+	}
+	// Reuse the count captured above so the divisor matches RecomputedTotal.
+	if snapshot.RecomputedTotal > 0 {
+		snapshot.AvgRecomputeLatency = time.Duration(m.recomputeLatency.Load()/snapshot.RecomputedTotal) * time.Microsecond
+	}
+
+	// Percentiles over the recent-latency window
+	if len(window) > 0 {
+		sortDurations(window)
+		snapshot.P50Latency = percentile(window, 0.50)
+		snapshot.P95Latency = percentile(window, 0.95)
+		snapshot.P99Latency = percentile(window, 0.99)
+	}
+
+	// Cache hit rate as a fraction of all cache operations
+	if totalCacheOps := snapshot.CacheHits + snapshot.CacheMisses; totalCacheOps > 0 {
+		snapshot.CacheHitRate = float64(snapshot.CacheHits) / float64(totalCacheOps)
+	}
+
+	// Storage savings: every embedding has the same size, so the saved share
+	// is simply the fraction of documents without a stored embedding.
+	if snapshot.TotalDocuments > 0 {
+		storedShare := float64(snapshot.StoredEmbeddings) / float64(snapshot.TotalDocuments)
+		snapshot.StorageSavingsPercent = (1.0 - storedShare) * 100
+	}
+
+	// avgTraversalDepth accumulates the sum of depths; divide by the count.
+	if snapshot.GraphTraversals > 0 {
+		snapshot.AvgTraversalDepth = float64(m.avgTraversalDepth.Load()) / float64(snapshot.GraphTraversals)
+	}
+
+	return snapshot
+}
+
+// MetricsSnapshot represents a point-in-time metrics snapshot
+type MetricsSnapshot struct {
+	// Query metrics
+	TotalQueries    int64 // all recorded queries, whatever their type
+	HubOnlyQueries  int64 // queries recorded as "hub_only"
+	HybridQueries   int64 // queries recorded as "hybrid"
+	OnDemandQueries int64 // queries recorded as "on_demand"
+	GraphQueries    int64 // queries recorded as "graph"
+
+	// Latency metrics
+	AvgLatency          time.Duration // total query latency / TotalQueries
+	P50Latency          time.Duration // percentiles are taken over the recent-latency window only
+	P95Latency          time.Duration
+	P99Latency          time.Duration
+	AvgHubLatency       time.Duration // total hub-search latency / TotalQueries
+	AvgRecomputeLatency time.Duration // total recompute latency / RecomputedTotal
+
+	// Storage metrics
+	TotalDocuments        int
+	HubDocuments          int
+	StoredEmbeddings      int
+	StorageSavingsPercent float64 // (1 - StoredEmbeddings/TotalDocuments) * 100; 0 when there are no documents
+	RecomputedTotal       int64   // running total of recomputed embeddings
+
+	// Cache metrics
+	CacheHits    int64
+	CacheMisses  int64
+	CacheHitRate float64 // fraction in [0, 1]: hits / (hits + misses); 0 when there were no cache operations
+
+	// Graph metrics
+	GraphTraversals   int64
+	AvgTraversalDepth float64 // sum of recorded depths / GraphTraversals
+
+	// Runtime
+	Uptime time.Duration // time elapsed since the Metrics value was created
+}
+
+// sortDurations sorts a slice of durations in ascending order, in place (insertion sort).
+func sortDurations(durations []time.Duration) {
+	for next := 1; next < len(durations); next++ {
+		pending, slot := durations[next], next
+		for slot > 0 && durations[slot-1] > pending {
+			durations[slot] = durations[slot-1]
+			slot--
+		}
+		durations[slot] = pending
+	}
+}
+
+// percentile returns sorted[int(len*p)], clamping the index so p outside [0, 1] cannot panic; empty input yields 0.
+func percentile(sorted []time.Duration, p float64) time.Duration {
+	if len(sorted) == 0 {
+		return 0
+	}
+	idx := int(float64(len(sorted)) * p)
+	if idx < 0 {
+		idx = 0 // negative p: fall back to the smallest sample
+	} else if idx >= len(sorted) {
+		idx = len(sorted) - 1 // p >= 1: fall back to the largest sample
+	}
+	return sorted[idx]
+}
+
+// String renders the snapshot as a multi-line, human-readable report; the hub share is 0% when there are no documents.
+func (s MetricsSnapshot) String() string {
+	hubPercent := 0.0
+	if s.TotalDocuments > 0 { // avoid 0/0, which would be printed as NaN
+		hubPercent = float64(s.HubDocuments) / float64(s.TotalDocuments) * 100
+	}
+	return fmt.Sprintf(`Hybrid Vector Storage Metrics:
+ Queries:
+ Total: %d (Hub: %d, Hybrid: %d, OnDemand: %d, Graph: %d)
+ Avg Latency: %v (p50: %v, p95: %v, p99: %v)
+ Hub Latency: %v, Recompute Latency: %v
+ Storage:
+ Documents: %d (Hubs: %d, %.1f%%)
+ Stored Embeddings: %d
+ Savings: %.1f%%
+ Total Recomputed: %d
+ Cache:
+ Hits: %d, Misses: %d (Hit Rate: %.1f%%)
+ Graph:
+ Traversals: %d (Avg Depth: %.2f)
+ Runtime: %v`,
+		s.TotalQueries, s.HubOnlyQueries, s.HybridQueries, s.OnDemandQueries, s.GraphQueries,
+		s.AvgLatency, s.P50Latency, s.P95Latency, s.P99Latency, s.AvgHubLatency, s.AvgRecomputeLatency,
+		s.TotalDocuments, s.HubDocuments, hubPercent,
+		s.StoredEmbeddings, s.StorageSavingsPercent, s.RecomputedTotal,
+		s.CacheHits, s.CacheMisses, s.CacheHitRate*100,
+		s.GraphTraversals, s.AvgTraversalDepth, s.Uptime,
+	)
+}
diff --git a/internal/vector/interface.go b/internal/vector/interface.go
new file mode 100644
index 0000000..59d9914
--- /dev/null
+++ b/internal/vector/interface.go
@@ -0,0 +1,42 @@
+// Package vector provides common interfaces for vector storage implementations
+package vector
+
+import (
+ "context"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
+)
+
+// Client defines the interface for vector storage operations.
+// Both sqlitevec.Client and hybrid.Client implement this interface.
+type Client interface {
+	// AddDocuments adds documents with their embeddings to the vector store
+	AddDocuments(ctx context.Context, docs []sqlitevec.Document) error
+
+	// DeleteDocuments removes documents by their IDs
+	DeleteDocuments(ctx context.Context, ids []string) error
+
+	// Query performs a vector similarity search for the query text; limit and where constrain the results
+	Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error)
+
+	// IsConnected checks if the vector store is available
+	IsConnected() bool
+
+	// Close releases resources
+	Close() error
+
+	// Count returns the total number of vectors in the store
+	Count(ctx context.Context) (int64, error)
+
+	// ModelVersion returns the current embedding model version
+	ModelVersion() string
+
+	// NeedsRebuild checks if vectors need to be rebuilt due to model version change (the string presumably explains why; confirm against implementations)
+	NeedsRebuild(ctx context.Context) (bool, string)
+
+	// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions
+	GetStaleVectors(ctx context.Context) ([]sqlitevec.StaleVectorInfo, error)
+
+	// DeleteVectorsByDocIDs removes vectors by their doc_ids
+	DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error
+}
diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go
index df2e836..5dfb24e 100644
--- a/internal/vector/sqlitevec/client.go
+++ b/internal/vector/sqlitevec/client.go
@@ -319,11 +319,11 @@ func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
// StaleVectorInfo contains information about a vector that needs rebuilding.
type StaleVectorInfo struct {
DocID string
- SQLiteID int64
DocType string
FieldType string
Project string
Scope string
+ SQLiteID int64
}
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
diff --git a/internal/vector/sqlitevec/helpers.go b/internal/vector/sqlitevec/helpers.go
index a104b51..363a079 100644
--- a/internal/vector/sqlitevec/helpers.go
+++ b/internal/vector/sqlitevec/helpers.go
@@ -12,17 +12,17 @@ const (
// Document represents a document to store with vector embedding.
type Document struct {
+ Metadata map[string]any
ID string
Content string
- Metadata map[string]any
}
// QueryResult represents a search result from vector search.
type QueryResult struct {
+ Metadata map[string]any
ID string
Distance float64
- Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance)
- Metadata map[string]any
+ Similarity float64
}
// DistanceToSimilarity converts sqlite-vec cosine distance to similarity score.
diff --git a/internal/vector/sqlitevec/helpers_test.go b/internal/vector/sqlitevec/helpers_test.go
index 3624f00..e5b287d 100644
--- a/internal/vector/sqlitevec/helpers_test.go
+++ b/internal/vector/sqlitevec/helpers_test.go
@@ -42,10 +42,10 @@ func TestQueryResult_Fields(t *testing.T) {
func TestBuildWhereFilter(t *testing.T) {
tests := []struct {
+ expected map[string]interface{}
name string
docType DocType
project string
- expected map[string]interface{}
}{
{
name: "empty_filters",
@@ -474,9 +474,9 @@ func TestCopyMetadataMulti(t *testing.T) {
func TestJoinStrings(t *testing.T) {
tests := []struct {
name string
- strs []string
sep string
expected string
+ strs []string
}{
{
name: "empty_slice",
@@ -522,8 +522,8 @@ func TestTruncateString(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
{
name: "shorter_than_max",
@@ -577,10 +577,10 @@ func TestFilterByThreshold(t *testing.T) {
tests := []struct {
name string
results []QueryResult
+ expectedIDs []string
threshold float64
maxResults int
expectedLen int
- expectedIDs []string
}{
{
name: "empty_results",
diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go
index c0bb765..e95794a 100644
--- a/internal/watcher/watcher.go
+++ b/internal/watcher/watcher.go
@@ -16,15 +16,15 @@ import (
// Watcher monitors a file or directory for deletion and calls onDelete when removed.
// It watches the parent directory since fsnotify cannot watch non-existent files.
type Watcher struct {
- targetPath string // The file/directory to watch for deletion
- parentPath string // Parent directory (what we actually watch)
- onDelete func() // Callback when target is deleted
- watcher *fsnotify.Watcher
ctx context.Context
+ onDelete func()
+ watcher *fsnotify.Watcher
cancel context.CancelFunc
+ targetPath string
+ parentPath string
+ debounce time.Duration
mu sync.Mutex
running bool
- debounce time.Duration
}
// New creates a new Watcher for the given target path.
diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go
index 230a56d..40593a5 100644
--- a/internal/worker/handlers.go
+++ b/internal/worker/handlers.go
@@ -158,10 +158,10 @@ type SessionInitRequest struct {
// SessionInitResponse is the response for session initialization.
type SessionInitResponse struct {
+ Reason string `json:"reason,omitempty"`
SessionDBID int64 `json:"sessionDbId"`
PromptNumber int `json:"promptNumber"`
Skipped bool `json:"skipped,omitempty"`
- Reason string `json:"reason,omitempty"`
}
// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions.
@@ -1312,3 +1312,85 @@ func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
}
}()
}
+
+// handleGetGraphStats writes observation graph statistics and the graph-search settings as JSON.
+// Without a graph search client the payload only reports that graph search is not enabled.
+func (s *Service) handleGetGraphStats(w http.ResponseWriter, r *http.Request) {
+	client := s.graphSearchClient
+	if client == nil {
+		disabled := map[string]interface{}{"enabled": false, "message": "Graph search not enabled"}
+		writeJSON(w, disabled)
+		return
+	}
+
+	cfg := s.config
+	settings := map[string]interface{}{
+		"maxHops":            cfg.GraphMaxHops,
+		"branchFactor":       cfg.GraphBranchFactor,
+		"edgeWeight":         cfg.GraphEdgeWeight,
+		"rebuildIntervalMin": cfg.GraphRebuildIntervalMin,
+	}
+
+	stats := client.GetGraphStats()
+	writeJSON(w, map[string]interface{}{
+		"enabled":      cfg.GraphEnabled,
+		"nodeCount":    stats.NodeCount,
+		"edgeCount":    stats.EdgeCount,
+		"avgDegree":    stats.AvgDegree,
+		"maxDegree":    stats.MaxDegree,
+		"minDegree":    stats.MinDegree,
+		"medianDegree": stats.MedianDegree,
+		"edgeTypes":    stats.EdgeTypes,
+		"config":       settings,
+	})
+}
+
+// handleGetVectorMetrics writes a snapshot of the hybrid vector storage metrics as JSON,
+// grouped into queries, latency, storage, cache, and graph sections. Durations are rendered
+// as strings. The "enabled" key is false when no metrics tracker exists and true otherwise.
+func (s *Service) handleGetVectorMetrics(w http.ResponseWriter, r *http.Request) {
+	if s.hybridMetrics == nil {
+		writeJSON(w, map[string]interface{}{
+			"enabled": false,
+			"message": "Vector metrics not available",
+		})
+		return
+	}
+
+	snapshot := s.hybridMetrics.GetSnapshot()
+	writeJSON(w, map[string]interface{}{
+		"enabled": true, // mirrors the "enabled": false reported when metrics are unavailable
+		"queries": map[string]interface{}{
+			"total":    snapshot.TotalQueries,
+			"hubOnly":  snapshot.HubOnlyQueries,
+			"hybrid":   snapshot.HybridQueries,
+			"onDemand": snapshot.OnDemandQueries,
+			"graph":    snapshot.GraphQueries,
+		},
+		"latency": map[string]interface{}{
+			"avg":          snapshot.AvgLatency.String(),
+			"p50":          snapshot.P50Latency.String(),
+			"p95":          snapshot.P95Latency.String(),
+			"p99":          snapshot.P99Latency.String(),
+			"avgHub":       snapshot.AvgHubLatency.String(),
+			"avgRecompute": snapshot.AvgRecomputeLatency.String(),
+		},
+		"storage": map[string]interface{}{
+			"totalDocuments":   snapshot.TotalDocuments,
+			"hubDocuments":     snapshot.HubDocuments,
+			"storedEmbeddings": snapshot.StoredEmbeddings,
+			"savingsPercent":   snapshot.StorageSavingsPercent,
+			"recomputedTotal":  snapshot.RecomputedTotal,
+		},
+		"cache": map[string]interface{}{
+			"hits":    snapshot.CacheHits,
+			"misses":  snapshot.CacheMisses,
+			"hitRate": snapshot.CacheHitRate,
+		},
+		"graph": map[string]interface{}{
+			"traversals": snapshot.GraphTraversals,
+			"avgDepth":   snapshot.AvgTraversalDepth,
+		},
+		"uptime": snapshot.Uptime.String(),
+	})
+}
diff --git a/internal/worker/sdk/parser_test.go b/internal/worker/sdk/parser_test.go
index e89b981..2fce3fe 100644
--- a/internal/worker/sdk/parser_test.go
+++ b/internal/worker/sdk/parser_test.go
@@ -77,10 +77,10 @@ func TestParseObservations_TableDriven(t *testing.T) {
tests := []struct {
name string
input string
- expectedCount int
expectedType models.ObservationType
expectedTitle string
checkConcepts []string
+ expectedCount int
}{
{
name: "valid_bugfix_observation",
@@ -300,9 +300,9 @@ func TestParseSummary_TableDriven(t *testing.T) {
tests := []struct {
name string
input string
+ expectedRequest string
sessionID int64
expectNil bool
- expectedRequest string
}{
{
name: "empty_input",
diff --git a/internal/worker/sdk/processor.go b/internal/worker/sdk/processor.go
index 17e51d8..94a3d86 100644
--- a/internal/worker/sdk/processor.go
+++ b/internal/worker/sdk/processor.go
@@ -31,15 +31,14 @@ type SyncSummaryFunc func(summary *models.SessionSummary)
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
type Processor struct {
- claudePath string
- model string
observationStore *gorm.ObservationStore
summaryStore *gorm.SummaryStore
broadcastFunc BroadcastFunc
syncObservationFunc SyncObservationFunc
syncSummaryFunc SyncSummaryFunc
- // Semaphore to limit concurrent Claude CLI calls (prevents API overload)
- sem chan struct{}
+ sem chan struct{}
+ claudePath string
+ model string
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go
index f18e6c5..811d5f4 100644
--- a/internal/worker/sdk/processor_test.go
+++ b/internal/worker/sdk/processor_test.go
@@ -11,8 +11,8 @@ import (
func TestIsSelfReferentialSummary(t *testing.T) {
tests := []struct {
- name string
summary *models.ParsedSummary
+ name string
expected bool
}{
{
@@ -281,8 +281,8 @@ func TestTruncateForLog(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
{
name: "shorter_than_max",
@@ -719,8 +719,8 @@ func TestShouldSkipTrivialOperation_EdgeCases(t *testing.T) {
// TestIsSelfReferentialSummary_MoreCases tests additional self-referential detection cases.
func TestIsSelfReferentialSummary_MoreCases(t *testing.T) {
tests := []struct {
- name string
summary *models.ParsedSummary
+ name string
expected bool
}{
{
diff --git a/internal/worker/sdk/prompts.go b/internal/worker/sdk/prompts.go
index a200c0d..af5e865 100644
--- a/internal/worker/sdk/prompts.go
+++ b/internal/worker/sdk/prompts.go
@@ -24,12 +24,12 @@ var ObservationConcepts = []string{
// ToolExecution represents a tool execution for observation.
type ToolExecution struct {
- ID int64
ToolName string
ToolInput string
ToolOutput string
- CreatedAtEpoch int64
CWD string
+ ID int64
+ CreatedAtEpoch int64
}
// BuildObservationPrompt builds a prompt for processing a tool observation.
@@ -67,12 +67,12 @@ func BuildObservationPrompt(exec ToolExecution) string {
// SummaryRequest contains data for building a summary prompt.
type SummaryRequest struct {
- SessionDBID int64
SDKSessionID string
Project string
UserPrompt string
LastUserMessage string
LastAssistantMessage string
+ SessionDBID int64
}
// BuildSummaryPrompt builds a prompt requesting a session summary.
diff --git a/internal/worker/sdk/prompts_test.go b/internal/worker/sdk/prompts_test.go
index eecf7d5..054767d 100644
--- a/internal/worker/sdk/prompts_test.go
+++ b/internal/worker/sdk/prompts_test.go
@@ -12,8 +12,8 @@ func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
- maxLen int
expected string
+ maxLen int
}{
{
name: "shorter_than_max",
@@ -60,8 +60,8 @@ func TestBuildObservationPrompt(t *testing.T) {
tests := []struct {
name string
- exec ToolExecution
contains []string
+ exec ToolExecution
}{
{
name: "basic_read_tool",
diff --git a/internal/worker/service.go b/internal/worker/service.go
index a6180a3..8b74bfe 100644
--- a/internal/worker/service.go
+++ b/internal/worker/service.go
@@ -12,6 +12,10 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/golang"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/python"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/typescript"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
@@ -20,6 +24,7 @@ import (
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/vector/hybrid"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
@@ -56,80 +61,53 @@ type RetrievalStats struct {
// Service is the main worker service orchestrator.
type Service struct {
- // Version of the worker binary
- version string
-
- // Configuration
- config *config.Config
-
- // Database
- store *gorm.Store
- sessionStore *gorm.SessionStore
- observationStore *gorm.ObservationStore
- summaryStore *gorm.SummaryStore
- promptStore *gorm.PromptStore
- conflictStore *gorm.ConflictStore
- patternStore *gorm.PatternStore
- relationStore *gorm.RelationStore
-
- // Pattern detection
- patternDetector *pattern.Detector
-
- // Domain services
- sessionManager *session.Manager
- sseBroadcaster *sse.Broadcaster
- processor *sdk.Processor
-
- // Vector database (sqlite-vec with local embeddings)
- embedSvc *embedding.Service
- vectorClient *sqlitevec.Client
- vectorSync *sqlitevec.Sync
-
- // Cross-encoder reranking (for improved search relevance)
- reranker *reranking.Service
-
- // Query expansion (for improved search recall)
- queryExpander *expansion.Expander
-
- // Importance scoring
- scoreCalculator *scoring.Calculator
- recalculator *scoring.Recalculator
-
- // HTTP server
- router *chi.Mux
- server *http.Server
- startTime time.Time
-
- // Retrieval statistics (per-project)
- retrievalStats map[string]*RetrievalStats
- retrievalStatsMu sync.RWMutex
-
- // Lifecycle
- ctx context.Context
- cancel context.CancelFunc
- wg sync.WaitGroup
-
- // Initialization state (for deferred init)
- ready atomic.Bool
- initError error
- initMu sync.RWMutex
-
- // Background verification queue for stale observations
- staleQueue chan staleVerifyRequest
- staleQueueOnce sync.Once
-
- // File watchers for auto-recreation on deletion
- dbWatcher *watcher.Watcher
- configWatcher *watcher.Watcher
-
- // Self-updater
- updater *update.Updater
+ startTime time.Time
+ initError error
+ ctx context.Context
+ patternDetector *pattern.Detector
+ queryExpander *expansion.Expander
+ summaryStore *gorm.SummaryStore
+ promptStore *gorm.PromptStore
+ conflictStore *gorm.ConflictStore
+ patternStore *gorm.PatternStore
+ relationStore *gorm.RelationStore
+ updater *update.Updater
+ sessionManager *session.Manager
+ scoreCalculator *scoring.Calculator
+ processor *sdk.Processor
+ embedSvc *embedding.Service
+ vectorClient *sqlitevec.Client
+ vectorSync *sqlitevec.Sync
+ graphSearchClient *hybrid.GraphSearchClient
+ hybridMetrics *hybrid.Metrics
+ graphRebuildTicker *time.Ticker
+ chunkingManager *chunking.Manager
+ observationStore *gorm.ObservationStore
+ reranker *reranking.Service
+ sseBroadcaster *sse.Broadcaster
+ recalculator *scoring.Recalculator
+ router *chi.Mux
+ server *http.Server
+ sessionStore *gorm.SessionStore
+ retrievalStats map[string]*RetrievalStats
+ configWatcher *watcher.Watcher
+ store *gorm.Store
+ cancel context.CancelFunc
+ dbWatcher *watcher.Watcher
+ staleQueue chan staleVerifyRequest
+ config *config.Config
+ version string
+ wg sync.WaitGroup
+ initMu sync.RWMutex
+ retrievalStatsMu sync.RWMutex
+ staleQueueOnce sync.Once
+ ready atomic.Bool
}
// staleVerifyRequest represents a request to verify a stale observation in background
type staleVerifyRequest struct {
- observationID int64
cwd string
+ observationID int64
}
// NewService creates a new worker service with deferred initialization.
@@ -210,6 +188,9 @@ func (s *Service) initializeAsync() {
var embedSvc *embedding.Service
var vectorClient *sqlitevec.Client
var vectorSync *sqlitevec.Sync
+ var graphSearchClient *hybrid.GraphSearchClient
+ var hybridMetrics *hybrid.Metrics
+ var chunkingManager *chunking.Manager
var reranker *reranking.Service
@@ -218,18 +199,51 @@ func (s *Service) initializeAsync() {
log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
} else {
embedSvc = emb
- // Create sqlite-vec client using the same DB connection
- client, err := sqlitevec.NewClient(sqlitevec.Config{
+ // Create base sqlite-vec client using the same DB connection
+ baseClient, err := sqlitevec.NewClient(sqlitevec.Config{
DB: store.GetRawDB(),
}, embedSvc)
if err != nil {
log.Warn().Err(err).Msg("sqlite-vec client creation failed - vector search disabled")
} else {
- vectorClient = client
- vectorSync = sqlitevec.NewSync(client)
+ vectorClient = baseClient
+
+ // Wrap with LEANN hybrid storage client
+ strategy := hybrid.ParseStrategy(s.config.VectorStorageStrategy)
+ hybridClient := hybrid.NewClient(hybrid.Config{
+ BaseClient: baseClient,
+ DB: store.GetRawDB(),
+ EmbedSvc: embedSvc,
+ Strategy: strategy,
+ HubThreshold: s.config.HubThreshold,
+ })
+
+ // Wrap with graph-aware search client
+ graphConfig := hybrid.GraphConfig{
+ Enabled: s.config.GraphEnabled,
+ MaxHops: s.config.GraphMaxHops,
+ BranchFactor: s.config.GraphBranchFactor,
+ EdgeWeight: s.config.GraphEdgeWeight,
+ }
+ graphSearchClient = hybrid.NewGraphSearchClient(hybridClient, nil, graphConfig)
+ hybridMetrics = hybrid.NewMetrics()
+
+ vectorSync = sqlitevec.NewSync(baseClient)
+
+ // Initialize AST-aware code chunking
+ chunkOpts := chunking.DefaultChunkOptions()
+ chunkers := []chunking.Chunker{
+ golang.NewChunker(chunkOpts),
+ python.NewChunker(chunkOpts),
+ typescript.NewChunker(chunkOpts),
+ }
+ chunkingManager = chunking.NewManager(chunkers, chunkOpts)
+
log.Info().
Str("model", embedSvc.Version()).
- Msg("sqlite-vec vector search enabled")
+ Str("storage_strategy", s.config.VectorStorageStrategy).
+ Bool("graph_enabled", s.config.GraphEnabled).
+ Msg("LEANN hybrid vector storage and graph search enabled")
}
// Create cross-encoder reranking service if enabled
@@ -284,6 +298,9 @@ func (s *Service) initializeAsync() {
s.embedSvc = embedSvc
s.vectorClient = vectorClient
s.vectorSync = vectorSync
+ s.graphSearchClient = graphSearchClient
+ s.hybridMetrics = hybridMetrics
+ s.chunkingManager = chunkingManager
s.reranker = reranker
s.initMu.Unlock()
@@ -411,6 +428,18 @@ func (s *Service) initializeAsync() {
s.ready.Store(true)
log.Info().Msg("Async initialization complete - service ready")
+ // Build initial observation graph if graph search is enabled
+ if graphSearchClient != nil && s.config.GraphEnabled {
+ s.wg.Add(1)
+ go s.buildInitialGraph(observationStore)
+
+ // Start periodic graph rebuild timer
+ if s.config.GraphRebuildIntervalMin > 0 {
+ s.wg.Add(1)
+ go s.startGraphRebuildTimer(observationStore)
+ }
+ }
+
// Start queue processor if SDK processor is available
if processor != nil {
s.wg.Add(1)
@@ -1136,6 +1165,10 @@ func (s *Service) setupRoutes() {
r.Get("/api/observations/{id}/relations", s.handleGetRelations)
r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph)
r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations)
+
+ // LEANN Phase 2: Graph-based search and hybrid vector storage
+ r.Get("/api/graph/stats", s.handleGetGraphStats)
+ r.Get("/api/vector/metrics", s.handleGetVectorMetrics)
})
}
@@ -1346,6 +1379,87 @@ func (s *Service) processAllSessions() {
s.broadcastProcessingStatus()
}
+// buildInitialGraph performs the first build of the observation relationship
+// graph. It defers s.wg.Done(), so it is expected to run in a goroutine that
+// was started after a matching s.wg.Add(1).
+//
+// Every observation is loaded from observationStore and handed to
+// s.graphSearchClient.RebuildGraph. A fetch or build failure is logged and
+// ends the attempt; nothing is reported back to the caller.
+func (s *Service) buildInitialGraph(observationStore *gorm.ObservationStore) {
+	defer s.wg.Done()
+
+	log.Info().Msg("Building initial observation graph...")
+	begin := time.Now()
+
+	// Load the full observation set; an error or an empty set ends the build.
+	all, fetchErr := observationStore.GetAllObservations(s.ctx)
+	switch {
+	case fetchErr != nil:
+		log.Error().Err(fetchErr).Msg("Failed to fetch observations for graph building")
+		return
+	case len(all) == 0:
+		log.Info().Msg("No observations to build graph from")
+		return
+	}
+
+	// Delegate the actual construction to the graph search client.
+	buildErr := s.graphSearchClient.RebuildGraph(s.ctx, all)
+	if buildErr != nil {
+		log.Error().Err(buildErr).Msg("Failed to build observation graph")
+		return
+	}
+
+	// Elapsed time is captured before the stats lookup, as a single summary line.
+	took := time.Since(begin)
+	graphStats := s.graphSearchClient.GetGraphStats()
+
+	summary := log.Info().
+		Int("observations", len(all)).
+		Int("nodes", graphStats.NodeCount).
+		Int("edges", graphStats.EdgeCount).
+		Float64("avg_degree", graphStats.AvgDegree).
+		Int("max_degree", graphStats.MaxDegree).
+		Dur("elapsed", took)
+	summary.Msg("Initial observation graph built successfully")
+}
+
+// startGraphRebuildTimer periodically rebuilds the observation graph until
+// s.ctx is cancelled. The period is s.config.GraphRebuildIntervalMin minutes.
+//
+// The function defers s.wg.Done(), so it is expected to run in a goroutine
+// started after a matching s.wg.Add(1). On each tick it reloads every
+// observation from observationStore and passes them to
+// s.graphSearchClient.RebuildGraph; a failed tick is logged and the loop
+// simply waits for the next one.
+func (s *Service) startGraphRebuildTimer(observationStore *gorm.ObservationStore) {
+	defer s.wg.Done()
+
+	interval := time.Duration(s.config.GraphRebuildIntervalMin) * time.Minute
+	if interval <= 0 {
+		// time.NewTicker panics on a non-positive duration (which also covers
+		// an overflowed multiplication above). Refuse to start rather than
+		// crash the process from this background goroutine.
+		log.Warn().
+			Dur("interval", interval).
+			Msg("Graph rebuild interval is not positive - periodic rebuild disabled")
+		return
+	}
+
+	ticker := time.NewTicker(interval)
+	// Release the ticker on every exit path, not only on context cancellation.
+	defer ticker.Stop()
+	// The field is still populated for any other code that reads it; the loop
+	// below uses the local so it never re-reads the field from this goroutine.
+	s.graphRebuildTicker = ticker
+
+	log.Info().
+		Dur("interval", interval).
+		Msg("Started periodic graph rebuild timer")
+
+	for {
+		select {
+		case <-s.ctx.Done():
+			log.Info().Msg("Stopped graph rebuild timer")
+			return
+
+		case <-ticker.C:
+			log.Info().Msg("Periodic graph rebuild triggered")
+			start := time.Now()
+
+			observations, err := observationStore.GetAllObservations(s.ctx)
+			if err != nil {
+				log.Error().Err(err).Msg("Failed to fetch observations for graph rebuild")
+				continue
+			}
+
+			if err := s.graphSearchClient.RebuildGraph(s.ctx, observations); err != nil {
+				log.Error().Err(err).Msg("Failed to rebuild observation graph")
+				continue
+			}
+
+			stats := s.graphSearchClient.GetGraphStats()
+			log.Info().
+				Int("nodes", stats.NodeCount).
+				Int("edges", stats.EdgeCount).
+				Dur("elapsed", time.Since(start)).
+				Msg("Periodic graph rebuild complete")
+		}
+	}
+}
+
// Shutdown gracefully shuts down the service.
func (s *Service) Shutdown(ctx context.Context) error {
s.cancel()
diff --git a/internal/worker/session/manager.go b/internal/worker/session/manager.go
index afe7973..b2910c3 100644
--- a/internal/worker/session/manager.go
+++ b/internal/worker/session/manager.go
@@ -21,11 +21,11 @@ const (
// ObservationData contains data for a tool observation.
type ObservationData struct {
- ToolName string
ToolInput interface{}
ToolResponse interface{}
- PromptNumber int
+ ToolName string
CWD string
+ PromptNumber int
}
// SummarizeData contains data for a summarize request.
@@ -36,30 +36,28 @@ type SummarizeData struct {
// PendingMessage represents a message queued for SDK processing.
type PendingMessage struct {
- Type MessageType
Observation *ObservationData
Summarize *SummarizeData
+ Type MessageType
}
// ActiveSession represents an in-memory active session being processed.
type ActiveSession struct {
- SessionDBID int64
- ClaudeSessionID string
- SDKSessionID string
+ StartTime time.Time
+ ctx context.Context // concurrency control
+ cancel context.CancelFunc // concurrency control
+ notify chan struct{} // concurrency control
Project string
UserPrompt string
+ SDKSessionID string
+ ClaudeSessionID string
+ pendingMessages []PendingMessage // concurrency control
LastPromptNumber int
- StartTime time.Time
CumulativeInputTokens int64
CumulativeOutputTokens int64
-
- // Concurrency control
- pendingMessages []PendingMessage
- messageMu sync.Mutex
- notify chan struct{}
- ctx context.Context
- cancel context.CancelFunc
- generatorActive atomic.Bool
+ SessionDBID int64
+ messageMu sync.Mutex // concurrency control
+ generatorActive atomic.Bool // concurrency control
}
// SessionTimeout is how long an inactive session can exist before cleanup.
@@ -70,15 +68,14 @@ const CleanupInterval = 5 * time.Minute
// Manager manages active session lifecycles.
type Manager struct {
- sessionStore *gorm.SessionStore
- sessions map[int64]*ActiveSession
- mu sync.RWMutex
- onCreated func(int64)
- onDeleted func(int64)
- ctx context.Context
- cancel context.CancelFunc
- // Global notification channel for immediate processing
+ ctx context.Context
+ sessionStore *gorm.SessionStore
+ sessions map[int64]*ActiveSession
+ onCreated func(int64)
+ onDeleted func(int64)
+ cancel context.CancelFunc
ProcessNotify chan struct{} // Global notification channel for immediate processing
+ mu sync.RWMutex
}
// NewManager creates a new session manager.
diff --git a/internal/worker/session/manager_test.go b/internal/worker/session/manager_test.go
index 6e7f654..1a181e4 100644
--- a/internal/worker/session/manager_test.go
+++ b/internal/worker/session/manager_test.go
@@ -669,16 +669,16 @@ func TestActiveSessionCWD(t *testing.T) {
// TestToolInputResponse tests various tool input/response types.
func TestToolInputResponse(t *testing.T) {
tests := []struct {
- name string
input interface{}
response interface{}
+ name string
}{
- {"nil_values", nil, nil},
- {"string_values", "input string", "response string"},
- {"map_values", map[string]string{"key": "value"}, map[string]interface{}{"result": true}},
- {"slice_values", []string{"a", "b"}, []int{1, 2, 3}},
- {"int_values", 42, 100},
- {"bool_values", true, false},
+ {name: "nil_values", input: nil, response: nil},
+ {name: "string_values", input: "input string", response: "response string"},
+ {name: "map_values", input: map[string]string{"key": "value"}, response: map[string]interface{}{"result": true}},
+ {name: "slice_values", input: []string{"a", "b"}, response: []int{1, 2, 3}},
+ {name: "int_values", input: 42, response: 100},
+ {name: "bool_values", input: true, response: false},
}
for _, tt := range tests {
diff --git a/internal/worker/sse/broadcaster.go b/internal/worker/sse/broadcaster.go
index b6e8d96..380cfba 100644
--- a/internal/worker/sse/broadcaster.go
+++ b/internal/worker/sse/broadcaster.go
@@ -19,10 +19,10 @@ const (
// Client represents a connected SSE client.
type Client struct {
- ID string
Writer http.ResponseWriter
Flusher http.Flusher
Done chan struct{}
+ ID string
}
// Broadcaster manages SSE client connections and message broadcasting.
diff --git a/internal/worker/sse/broadcaster_test.go b/internal/worker/sse/broadcaster_test.go
index 45776f2..e1504d5 100644
--- a/internal/worker/sse/broadcaster_test.go
+++ b/internal/worker/sse/broadcaster_test.go
@@ -256,8 +256,8 @@ func TestHandleSSE(t *testing.T) {
// TestBroadcastJSON tests broadcasting various JSON types.
func TestBroadcastJSON(t *testing.T) {
tests := []struct {
- name string
data interface{}
+ name string
wantErr bool
}{
{
diff --git a/pkg/hooks/response.go b/pkg/hooks/response.go
index cf9ae38..af98b99 100644
--- a/pkg/hooks/response.go
+++ b/pkg/hooks/response.go
@@ -62,11 +62,11 @@ type BaseInput struct {
// HookContext provides common context for hook handlers.
type HookContext struct {
HookName string
- Port int
Project string
SessionID string
CWD string
RawInput []byte
+ Port int
}
// HookHandler is a function that handles hook-specific logic.
diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go
index fe44e78..b1a48c7 100644
--- a/pkg/hooks/worker_test.go
+++ b/pkg/hooks/worker_test.go
@@ -320,11 +320,11 @@ func TestExtractBaseVersion(t *testing.T) {
// TestPOST tests the POST function with a mock server.
func TestPOST(t *testing.T) {
tests := []struct {
- name string
- serverHandler func(w http.ResponseWriter, r *http.Request)
body interface{}
- expectError bool
+ serverHandler func(w http.ResponseWriter, r *http.Request)
expectedResult map[string]interface{}
+ name string
+ expectError bool
}{
{
name: "successful POST with JSON response",
@@ -393,10 +393,10 @@ func TestPOST(t *testing.T) {
// TestGET tests the GET function with a mock server.
func TestGET(t *testing.T) {
tests := []struct {
- name string
serverHandler func(w http.ResponseWriter, r *http.Request)
- expectError bool
expectedResult map[string]interface{}
+ name string
+ expectError bool
}{
{
name: "successful GET with JSON response",
@@ -532,8 +532,8 @@ func TestExitCodes(t *testing.T) {
func TestHookResponse(t *testing.T) {
tests := []struct {
name string
- response HookResponse
expected string
+ response HookResponse
}{
{
name: "continue true",
@@ -597,8 +597,8 @@ func TestHookContext(t *testing.T) {
// TestIsWorkerRunning_WithServer tests IsWorkerRunning with actual server.
func TestIsWorkerRunning_WithServer(t *testing.T) {
tests := []struct {
- name string
serverHandler func(w http.ResponseWriter, r *http.Request)
+ name string
expectedResult bool
}{
{
@@ -828,8 +828,8 @@ func TestBaseInput_PartialFields(t *testing.T) {
func TestHookResponse_Marshal(t *testing.T) {
tests := []struct {
name string
- response HookResponse
contains []string
+ response HookResponse
}{
{
name: "continue true",
diff --git a/pkg/models/conflict.go b/pkg/models/conflict.go
index c3ed126..c6fef8e 100644
--- a/pkg/models/conflict.go
+++ b/pkg/models/conflict.go
@@ -33,25 +33,25 @@ const (
// ObservationConflict tracks conflicting observations.
type ObservationConflict struct {
- ID int64 `db:"id" json:"id"`
- NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"`
- OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"`
+ ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"`
ConflictType ConflictType `db:"conflict_type" json:"conflict_type"`
Resolution ConflictResolution `db:"resolution" json:"resolution"`
Reason string `db:"reason" json:"reason"`
DetectedAt string `db:"detected_at" json:"detected_at"`
+ ID int64 `db:"id" json:"id"`
+ NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"`
+ OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"`
DetectedAtEpoch int64 `db:"detected_at_epoch" json:"detected_at_epoch"`
Resolved bool `db:"resolved" json:"resolved"`
- ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"`
}
// ConflictDetectionResult contains the result of conflict detection.
type ConflictDetectionResult struct {
- HasConflict bool
Type ConflictType
Resolution ConflictResolution
Reason string
- OlderObsIDs []int64 // IDs of observations that conflict with the new one
+ OlderObsIDs []int64 // IDs of observations that conflict with the new one
+ HasConflict bool
}
// NewObservationConflict creates a new conflict record.
diff --git a/pkg/models/conflict_test.go b/pkg/models/conflict_test.go
index ed14d21..1d0ba78 100644
--- a/pkg/models/conflict_test.go
+++ b/pkg/models/conflict_test.go
@@ -51,8 +51,8 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
tests := []struct {
name string
text string
- expectMatch bool
expectPattern string
+ expectMatch bool
}{
{
name: "actually that was wrong",
@@ -128,9 +128,9 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() {
// TestDetectOpposingFileChanges_TableDriven tests opposing file change detection.
func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
tests := []struct {
- name string
newerObs *Observation
olderObs *Observation
+ name string
expectConflict bool
}{
{
@@ -202,9 +202,9 @@ func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() {
// TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection.
func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() {
tests := []struct {
- name string
newerObs *Observation
olderObs *Observation
+ name string
expectConflict bool
}{
{
diff --git a/pkg/models/observation.go b/pkg/models/observation.go
index abf1ff9..7417fc0 100644
--- a/pkg/models/observation.go
+++ b/pkg/models/observation.go
@@ -121,48 +121,44 @@ func (j JSONInt64Map) Value() (driver.Value, error) {
// Observation represents a learning extracted from a Claude Code session.
type Observation struct {
- ID int64 `db:"id" json:"id"`
+ FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"`
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
Project string `db:"project" json:"project"`
Scope ObservationScope `db:"scope" json:"scope"`
Type ObservationType `db:"type" json:"type"`
- Title sql.NullString `db:"title" json:"title,omitempty"`
+ CreatedAt string `db:"created_at" json:"created_at"`
Subtitle sql.NullString `db:"subtitle" json:"subtitle,omitempty"`
- Facts JSONStringArray `db:"facts" json:"facts,omitempty"`
+ Title sql.NullString `db:"title" json:"title,omitempty"`
Narrative sql.NullString `db:"narrative" json:"narrative,omitempty"`
Concepts JSONStringArray `db:"concepts" json:"concepts,omitempty"`
FilesRead JSONStringArray `db:"files_read" json:"files_read,omitempty"`
FilesModified JSONStringArray `db:"files_modified" json:"files_modified,omitempty"`
- FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"`
+ Facts JSONStringArray `db:"facts" json:"facts,omitempty"`
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
+ LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"` // importance scoring field
+ ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"` // importance scoring field
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
- CreatedAt string `db:"created_at" json:"created_at"`
+ ID int64 `db:"id" json:"id"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
+ ImportanceScore float64 `db:"importance_score" json:"importance_score"` // importance scoring field
+ UserFeedback int `db:"user_feedback" json:"user_feedback"` // importance scoring field
+ RetrievalCount int `db:"retrieval_count" json:"retrieval_count"` // importance scoring field
IsStale bool `db:"-" json:"is_stale,omitempty"`
-
- // Importance scoring fields
- ImportanceScore float64 `db:"importance_score" json:"importance_score"`
- UserFeedback int `db:"user_feedback" json:"user_feedback"`
- RetrievalCount int `db:"retrieval_count" json:"retrieval_count"`
- LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"`
- ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"`
-
- // Conflict detection fields
- IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"`
+ IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"` // conflict detection field
}
// ParsedObservation represents an observation parsed from SDK response XML.
type ParsedObservation struct {
+ FileMtimes map[string]int64 // File path -> mtime epoch ms
Type ObservationType
Title string
Subtitle string
- Facts []string
Narrative string
+ Scope ObservationScope // Optional: if empty, will be auto-determined
+ Facts []string
Concepts []string
FilesRead []string
FilesModified []string
- FileMtimes map[string]int64 // File path -> mtime epoch ms
- Scope ObservationScope // Optional: if empty, will be auto-determined
}
// ToStoredObservation converts a ParsedObservation to the stored Observation format.
@@ -197,34 +193,30 @@ func DetermineScope(concepts []string) ObservationScope {
// ObservationJSON is a JSON-friendly representation of Observation.
// It converts sql.NullString to plain strings for clean JSON output.
type ObservationJSON struct {
- ID int64 `json:"id"`
+ FileMtimes map[string]int64 `json:"file_mtimes,omitempty"`
+ Subtitle string `json:"subtitle,omitempty"`
SDKSessionID string `json:"sdk_session_id"`
- Project string `json:"project"`
Scope ObservationScope `json:"scope"`
Type ObservationType `json:"type"`
Title string `json:"title,omitempty"`
- Subtitle string `json:"subtitle,omitempty"`
- Facts []string `json:"facts,omitempty"`
+ CreatedAt string `json:"created_at"`
Narrative string `json:"narrative,omitempty"`
+ Project string `json:"project"`
Concepts []string `json:"concepts,omitempty"`
+ Facts []string `json:"facts,omitempty"`
FilesRead []string `json:"files_read,omitempty"`
FilesModified []string `json:"files_modified,omitempty"`
- FileMtimes map[string]int64 `json:"file_mtimes,omitempty"`
- PromptNumber int64 `json:"prompt_number,omitempty"`
- DiscoveryTokens int64 `json:"discovery_tokens"`
- CreatedAt string `json:"created_at"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
+ DiscoveryTokens int64 `json:"discovery_tokens"`
+ ID int64 `json:"id"`
+ PromptNumber int64 `json:"prompt_number,omitempty"`
+ ImportanceScore float64 `json:"importance_score"` // importance scoring field
+ UserFeedback int `json:"user_feedback"` // importance scoring field
+ RetrievalCount int `json:"retrieval_count"` // importance scoring field
+ LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"` // importance scoring field
+ ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"` // importance scoring field
IsStale bool `json:"is_stale,omitempty"`
-
- // Importance scoring fields
- ImportanceScore float64 `json:"importance_score"`
- UserFeedback int `json:"user_feedback"`
- RetrievalCount int `json:"retrieval_count"`
- LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"`
- ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"`
-
- // Conflict detection fields
- IsSuperseded bool `json:"is_superseded,omitempty"`
+ IsSuperseded bool `json:"is_superseded,omitempty"` // conflict detection field
}
// MarshalJSON implements json.Marshaler for Observation.
diff --git a/pkg/models/observation_test.go b/pkg/models/observation_test.go
index 50ce357..8c4f11b 100644
--- a/pkg/models/observation_test.go
+++ b/pkg/models/observation_test.go
@@ -50,8 +50,8 @@ func (s *ObservationSuite) TestGlobalizableConcepts() {
func (s *ObservationSuite) TestDetermineScope_TableDriven() {
tests := []struct {
name string
- concepts []string
expected ObservationScope
+ concepts []string
}{
{
name: "empty concepts - project scope",
@@ -121,9 +121,9 @@ func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() {
// TestObservation_CheckStaleness_TableDriven tests staleness checking.
func (s *ObservationSuite) TestObservation_CheckStaleness_TableDriven() {
tests := []struct {
- name string
storedMtimes map[string]int64
currentMtimes map[string]int64
+ name string
expectedStale bool
}{
{
@@ -300,10 +300,10 @@ func TestParsedObservation_ToStoredObservation(t *testing.T) {
// TestJSONStringArray tests JSONStringArray scanning.
func TestJSONStringArray(t *testing.T) {
tests := []struct {
- name string
input interface{}
- wantErr bool
+ name string
expected JSONStringArray
+ wantErr bool
}{
{
name: "nil input",
@@ -348,10 +348,10 @@ func TestJSONStringArray(t *testing.T) {
// TestJSONInt64Map tests JSONInt64Map scanning.
func TestJSONInt64Map(t *testing.T) {
tests := []struct {
- name string
input interface{}
- wantErr bool
expected JSONInt64Map
+ name string
+ wantErr bool
}{
{
name: "nil input",
diff --git a/pkg/models/pattern.go b/pkg/models/pattern.go
index 03d3d6f..748f299 100644
--- a/pkg/models/pattern.go
+++ b/pkg/models/pattern.go
@@ -39,21 +39,21 @@ const (
// Pattern represents a recurring pattern detected across observations.
// This enables Claude to reference historical insights: "I've encountered this pattern 12 times."
type Pattern struct {
- ID int64 `db:"id" json:"id"`
- Name string `db:"name" json:"name"` // e.g., "State Management Anti-Pattern"
- Type PatternType `db:"type" json:"type"` // bug, refactor, architecture, etc.
- Description sql.NullString `db:"description" json:"description"` // Detailed description
- Signature JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection
- Recommendation sql.NullString `db:"recommendation" json:"recommendation"` // What works for this pattern
- Frequency int `db:"frequency" json:"frequency"` // How many times encountered
- Projects JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen
- ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` // Source observation IDs
- Status PatternStatus `db:"status" json:"status"` // active, deprecated, merged
- MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"`
- Confidence float64 `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0)
- LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
- LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
+ Status PatternStatus `db:"status" json:"status"` // active, deprecated, merged
+ Name string `db:"name" json:"name"` // e.g., "State Management Anti-Pattern"
+ Type PatternType `db:"type" json:"type"` // bug, refactor, architecture, etc.
CreatedAt string `db:"created_at" json:"created_at"`
+ LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected
+ Signature JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection
+ Projects JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen
+ ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` // Source observation IDs
+ Recommendation sql.NullString `db:"recommendation" json:"recommendation"` // What works for this pattern
+ Description sql.NullString `db:"description" json:"description"` // Detailed description
+ MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"`
+ Frequency int `db:"frequency" json:"frequency"` // How many times encountered
+ Confidence float64 `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0)
+ ID int64 `db:"id" json:"id"`
+ LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
@@ -95,21 +95,21 @@ func (j JSONInt64Array) Value() (driver.Value, error) {
// PatternJSON is a JSON-friendly representation of Pattern.
type PatternJSON struct {
- ID int64 `json:"id"`
+ Status PatternStatus `json:"status"`
Name string `json:"name"`
Type PatternType `json:"type"`
Description string `json:"description,omitempty"`
- Signature []string `json:"signature,omitempty"`
+ CreatedAt string `json:"created_at"`
Recommendation string `json:"recommendation,omitempty"`
- Frequency int `json:"frequency"`
- Projects []string `json:"projects,omitempty"`
+ LastSeenAt string `json:"last_seen_at"`
+ Signature []string `json:"signature,omitempty"`
ObservationIDs []int64 `json:"observation_ids,omitempty"`
- Status PatternStatus `json:"status"`
+ Projects []string `json:"projects,omitempty"`
MergedIntoID int64 `json:"merged_into_id,omitempty"`
Confidence float64 `json:"confidence"`
- LastSeenAt string `json:"last_seen_at"`
+ Frequency int `json:"frequency"`
LastSeenEpoch int64 `json:"last_seen_at_epoch"`
- CreatedAt string `json:"created_at"`
+ ID int64 `json:"id"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
}
@@ -214,11 +214,11 @@ func (p *Pattern) updateConfidence() {
// PatternMatch represents a match between an observation and a potential pattern.
type PatternMatch struct {
- PatternID int64 `json:"pattern_id"`
- Score float64 `json:"score"` // Match score (0.0-1.0)
- MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
- IsNew bool `json:"is_new"` // Whether this would create a new pattern
+ MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.)
SuggestedName string `json:"suggested_name,omitempty"`
+ PatternID int64 `json:"pattern_id"`
+ Score float64 `json:"score"` // Match score (0.0-1.0)
+ IsNew bool `json:"is_new"` // Whether this would create a new pattern
}
// PatternSignatureKeywords are common keywords used in pattern detection.
diff --git a/pkg/models/pattern_test.go b/pkg/models/pattern_test.go
index 2614f07..e6d16e0 100644
--- a/pkg/models/pattern_test.go
+++ b/pkg/models/pattern_test.go
@@ -116,18 +116,18 @@ func TestPattern_ConfidenceCalculation(t *testing.T) {
func TestPatternType_Detection(t *testing.T) {
tests := []struct {
- concepts []string
title string
narrative string
expected PatternType
+ concepts []string
}{
- {[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern},
- {[]string{"best-practice"}, "", "", PatternTypeBestPractice},
- {[]string{"architecture"}, "", "", PatternTypeArchitecture},
- {[]string{"refactor"}, "", "", PatternTypeRefactor},
- {[]string{}, "nil pointer bug", "", PatternTypeBug},
- {[]string{}, "Deadlock in concurrent code", "", PatternTypeBug},
- {[]string{}, "Extract interface", "", PatternTypeRefactor},
+ {title: "", narrative: "", expected: PatternTypeAntiPattern, concepts: []string{"anti-pattern"}},
+ {title: "", narrative: "", expected: PatternTypeBestPractice, concepts: []string{"best-practice"}},
+ {title: "", narrative: "", expected: PatternTypeArchitecture, concepts: []string{"architecture"}},
+ {title: "", narrative: "", expected: PatternTypeRefactor, concepts: []string{"refactor"}},
+ {title: "nil pointer bug", narrative: "", expected: PatternTypeBug, concepts: []string{}},
+ {title: "Deadlock in concurrent code", narrative: "", expected: PatternTypeBug, concepts: []string{}},
+ {title: "Extract interface", narrative: "", expected: PatternTypeRefactor, concepts: []string{}},
}
for _, tt := range tests {
diff --git a/pkg/models/prompt.go b/pkg/models/prompt.go
index f0fce3a..af01b8d 100644
--- a/pkg/models/prompt.go
+++ b/pkg/models/prompt.go
@@ -3,18 +3,18 @@ package models
// UserPrompt represents a user prompt captured during a session.
type UserPrompt struct {
- ID int64 `db:"id" json:"id"`
ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`
- PromptNumber int `db:"prompt_number" json:"prompt_number"`
PromptText string `db:"prompt_text" json:"prompt_text"`
- MatchedObservations int `db:"matched_observations" json:"matched_observations"`
CreatedAt string `db:"created_at" json:"created_at"`
+ ID int64 `db:"id" json:"id"`
+ PromptNumber int `db:"prompt_number" json:"prompt_number"`
+ MatchedObservations int `db:"matched_observations" json:"matched_observations"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
// UserPromptWithSession includes session context for search results.
type UserPromptWithSession struct {
- UserPrompt
Project string `db:"project" json:"project"`
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
+ UserPrompt
}
diff --git a/pkg/models/relation.go b/pkg/models/relation.go
index a2cadc5..c21e982 100644
--- a/pkg/models/relation.go
+++ b/pkg/models/relation.go
@@ -60,14 +60,14 @@ const (
// ObservationRelation represents a directed relationship between two observations.
type ObservationRelation struct {
- ID int64 `db:"id" json:"id"`
- SourceID int64 `db:"source_id" json:"source_id"`
- TargetID int64 `db:"target_id" json:"target_id"`
RelationType RelationType `db:"relation_type" json:"relation_type"`
- Confidence float64 `db:"confidence" json:"confidence"`
DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"`
Reason string `db:"reason" json:"reason,omitempty"`
CreatedAt string `db:"created_at" json:"created_at"`
+ ID int64 `db:"id" json:"id"`
+ SourceID int64 `db:"source_id" json:"source_id"`
+ TargetID int64 `db:"target_id" json:"target_id"`
+ Confidence float64 `db:"confidence" json:"confidence"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
@@ -88,12 +88,12 @@ func NewObservationRelation(sourceID, targetID int64, relType RelationType, conf
// RelationDetectionResult contains the result of relation detection.
type RelationDetectionResult struct {
- SourceID int64
- TargetID int64
RelationType RelationType
- Confidence float64
DetectionSource RelationDetectionSource
Reason string
+ SourceID int64
+ TargetID int64
+ Confidence float64
}
// DetectFileOverlapRelation checks if observations share file references and determines relationship type.
@@ -484,6 +484,6 @@ type RelationWithDetails struct {
// RelationGraph represents a graph of related observations.
type RelationGraph struct {
- CenterID int64 `json:"center_id"`
Relations []*RelationWithDetails `json:"relations"`
+ CenterID int64 `json:"center_id"`
}
diff --git a/pkg/models/relation_test.go b/pkg/models/relation_test.go
index 519973f..aee8e00 100644
--- a/pkg/models/relation_test.go
+++ b/pkg/models/relation_test.go
@@ -8,12 +8,12 @@ import (
func TestDetectFileOverlapRelation(t *testing.T) {
tests := []struct {
- name string
newer *Observation
older *Observation
- wantRelation bool
+ name string
wantRelType RelationType
wantMinConfid float64
+ wantRelation bool
}{
{
name: "no file overlap",
@@ -105,11 +105,11 @@ func TestDetectFileOverlapRelation(t *testing.T) {
func TestDetectConceptOverlapRelation(t *testing.T) {
tests := []struct {
- name string
newer *Observation
older *Observation
- wantRelation bool
+ name string
wantMinConfid float64
+ wantRelation bool
}{
{
name: "no concept overlap",
@@ -179,8 +179,8 @@ func TestDetectTypeProgressionRelation(t *testing.T) {
name string
newerType ObservationType
olderType ObservationType
- wantRelation bool
wantRelType RelationType
+ wantRelation bool
}{
{
name: "bugfix fixes discovery",
@@ -314,8 +314,8 @@ func TestDetectNarrativeMentionRelation(t *testing.T) {
tests := []struct {
name string
narrative string
- wantRelation bool
wantRelType RelationType
+ wantRelation bool
}{
{
name: "fixes language",
diff --git a/pkg/models/scoring.go b/pkg/models/scoring.go
index 28b9af2..c2ab5b3 100644
--- a/pkg/models/scoring.go
+++ b/pkg/models/scoring.go
@@ -4,8 +4,8 @@ package models
// ConceptWeight represents a configurable weight for a concept.
type ConceptWeight struct {
Concept string `db:"concept" json:"concept"`
- Weight float64 `db:"weight" json:"weight"`
UpdatedAt string `db:"updated_at" json:"updated_at"`
+ Weight float64 `db:"weight" json:"weight"`
}
// UserFeedbackType represents the type of user feedback.
@@ -62,28 +62,12 @@ var TypeBaseScores = map[ObservationType]float64{
// ScoringConfig contains all scoring weights and parameters.
type ScoringConfig struct {
- // RecencyHalfLifeDays is the number of days for the importance score to halve.
- // With 7 days, a 7-day old observation has 50% of a new observation's recency score.
- RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
-
- // FeedbackWeight scales the user feedback contribution to final score.
- // With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30.
- FeedbackWeight float64 `json:"feedback_weight"`
-
- // ConceptWeight scales the concept boost contribution.
- // The sum of matching concept weights is multiplied by this.
- ConceptWeight float64 `json:"concept_weight"`
-
- // RetrievalWeight scales the retrieval boost contribution.
- // Popular observations get a logarithmic bonus.
- RetrievalWeight float64 `json:"retrieval_weight"`
-
- // ConceptWeights maps concept names to their importance weights.
- ConceptWeights map[string]float64 `json:"concept_weights"`
-
- // MinScore is the minimum allowed importance score.
- // Prevents observations from completely disappearing.
- MinScore float64 `json:"min_score"`
+ ConceptWeights map[string]float64 `json:"concept_weights"`
+ RecencyHalfLifeDays float64 `json:"recency_half_life_days"`
+ FeedbackWeight float64 `json:"feedback_weight"`
+ ConceptWeight float64 `json:"concept_weight"`
+ RetrievalWeight float64 `json:"retrieval_weight"`
+ MinScore float64 `json:"min_score"`
}
// DefaultScoringConfig returns the default scoring configuration.
diff --git a/pkg/models/session.go b/pkg/models/session.go
index ee58f4b..5321dca 100644
--- a/pkg/models/session.go
+++ b/pkg/models/session.go
@@ -17,29 +17,29 @@ const (
// SDKSession represents a Claude Code session tracked by the memory system.
type SDKSession struct {
- ID int64 `db:"id" json:"id"`
ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`
- SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"`
Project string `db:"project" json:"project"`
- UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"`
- WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"`
- PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"`
Status SessionStatus `db:"status" json:"status"`
StartedAt string `db:"started_at" json:"started_at"`
- StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"`
+ SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"`
+ UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"`
CompletedAt sql.NullString `db:"completed_at" json:"completed_at,omitempty"`
+ WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"`
CompletedAtEpoch sql.NullInt64 `db:"completed_at_epoch" json:"completed_at_epoch,omitempty"`
+ ID int64 `db:"id" json:"id"`
+ PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"`
+ StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"`
}
// ActiveSession represents an in-memory active session being processed.
type ActiveSession struct {
- SessionDBID int64
+ StartTime time.Time
ClaudeSessionID string
SDKSessionID string
Project string
UserPrompt string
+ SessionDBID int64
LastPromptNumber int
- StartTime time.Time
CumulativeInputTokens int64
CumulativeOutputTokens int64
}
diff --git a/pkg/models/summary.go b/pkg/models/summary.go
index 81990c5..6909123 100644
--- a/pkg/models/summary.go
+++ b/pkg/models/summary.go
@@ -9,18 +9,18 @@ import (
// SessionSummary represents a summary of a Claude Code session.
type SessionSummary struct {
- ID int64 `db:"id" json:"id"`
+ CreatedAt string `db:"created_at" json:"created_at"`
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
Project string `db:"project" json:"project"`
- Request sql.NullString `db:"request" json:"request,omitempty"`
+ Completed sql.NullString `db:"completed" json:"completed,omitempty"`
Investigated sql.NullString `db:"investigated" json:"investigated,omitempty"`
Learned sql.NullString `db:"learned" json:"learned,omitempty"`
- Completed sql.NullString `db:"completed" json:"completed,omitempty"`
NextSteps sql.NullString `db:"next_steps" json:"next_steps,omitempty"`
Notes sql.NullString `db:"notes" json:"notes,omitempty"`
+ Request sql.NullString `db:"request" json:"request,omitempty"`
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
+ ID int64 `db:"id" json:"id"`
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
- CreatedAt string `db:"created_at" json:"created_at"`
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
}
@@ -56,18 +56,18 @@ func NewSessionSummary(sdkSessionID, project string, parsed *ParsedSummary, prom
// SessionSummaryJSON is a JSON-friendly representation of SessionSummary.
// It converts sql.NullString to plain strings for clean JSON output.
type SessionSummaryJSON struct {
- ID int64 `json:"id"`
+ Completed string `json:"completed,omitempty"`
SDKSessionID string `json:"sdk_session_id"`
Project string `json:"project"`
Request string `json:"request,omitempty"`
Investigated string `json:"investigated,omitempty"`
Learned string `json:"learned,omitempty"`
- Completed string `json:"completed,omitempty"`
NextSteps string `json:"next_steps,omitempty"`
Notes string `json:"notes,omitempty"`
+ CreatedAt string `json:"created_at"`
+ ID int64 `json:"id"`
PromptNumber int64 `json:"prompt_number,omitempty"`
DiscoveryTokens int64 `json:"discovery_tokens"`
- CreatedAt string `json:"created_at"`
CreatedAtEpoch int64 `json:"created_at_epoch"`
}
diff --git a/pkg/similarity/clustering_test.go b/pkg/similarity/clustering_test.go
index d843dfe..6aa7f06 100644
--- a/pkg/similarity/clustering_test.go
+++ b/pkg/similarity/clustering_test.go
@@ -12,9 +12,9 @@ import (
func TestJaccardSimilarity(t *testing.T) {
tests := []struct {
- name string
set1 map[string]bool
set2 map[string]bool
+ name string
expected float64
}{
{
diff --git a/ui/package-lock.json b/ui/package-lock.json
index 6a20f38..9207c57 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "claude-mnemonic-dashboard",
- "version": "8fe9ea5-dirty",
+ "version": "v0.10.5-1-g7ab4b07-dirty",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "claude-mnemonic-dashboard",
- "version": "8fe9ea5-dirty",
+ "version": "v0.10.5-1-g7ab4b07-dirty",
"dependencies": {
"vis-data": "^7.1.9",
"vis-network": "^9.1.9",
diff --git a/ui/package.json b/ui/package.json
index 09687eb..21b0206 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -1,6 +1,6 @@
{
"name": "claude-mnemonic-dashboard",
- "version": "8fe9ea5-dirty",
+ "version": "v0.10.5-1-g7ab4b07-dirty",
"private": true,
"type": "module",
"scripts": {
diff --git a/ui/src/components/Sidebar.vue b/ui/src/components/Sidebar.vue
index 9e8ed96..401c33e 100644
--- a/ui/src/components/Sidebar.vue
+++ b/ui/src/components/Sidebar.vue
@@ -2,6 +2,7 @@
import { ref, computed } from 'vue'
import type { Stats, SelfCheckResponse } from '@/types'
import ProjectFilter from './ProjectFilter.vue'
+import { useGraphMetrics } from '@/composables'
const props = defineProps<{
stats: Stats | null
@@ -18,12 +19,21 @@ defineEmits<{
// Collapse state - persisted in localStorage
const isCollapsed = ref(localStorage.getItem('sidebar-collapsed') === 'true')
+const metricsExpanded = ref(localStorage.getItem('metrics-expanded') === 'true')
+
+// Graph metrics composable
+const { graphStats, vectorMetrics, loading: metricsLoading, refresh: refreshMetrics } = useGraphMetrics()
function toggleCollapse() {
isCollapsed.value = !isCollapsed.value
localStorage.setItem('sidebar-collapsed', String(isCollapsed.value))
}
+function toggleMetrics() {
+ metricsExpanded.value = !metricsExpanded.value
+ localStorage.setItem('metrics-expanded', String(metricsExpanded.value))
+}
+
function formatNumber(n: number): string {
if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M'
if (n >= 1000) return (n / 1000).toFixed(1) + 'K'
@@ -205,6 +215,99 @@ function getStatusColor(status: string): string {
+
+