diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..d74bf73 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,35 @@ +linters-settings: + govet: + enable: + - fieldalignment + errcheck: + # Ignore error checks in test files for common test helpers + exclude-functions: + - (io.Closer).Close + - (*encoding/json.Encoder).Encode + - (io.Writer).Write + +linters: + enable: + - errcheck + - gosec + - govet + - gofmt + - staticcheck + - unused + - ineffassign + - typecheck + +issues: + exclude-dirs: + - vendor + # Exclude some linters from running on test files + exclude-rules: + - path: _test\.go + linters: + - errcheck + - gosec + +run: + timeout: 5m + tests: true diff --git a/go.mod b/go.mod index b8f4cf1..489ba7f 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/schollz/progressbar/v2 v2.15.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect diff --git a/go.sum b/go.sum index a1cd6c1..ddb22b9 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEM github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4= github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= diff --git a/internal/config/config.go b/internal/config/config.go index c07686b..4524f11 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,6 +47,7 @@ type Config struct { RerankingMinImprovement float64 `json:"reranking_min_improvement"` RerankingCandidates int `json:"reranking_candidates"` RerankingAlpha float64 `json:"reranking_alpha"` + GraphEdgeWeight float64 `json:"graph_edge_weight"` WorkerPort int `json:"worker_port"` ContextMaxPromptResults int `json:"context_max_prompt_results"` ContextObservations int `json:"context_observations"` @@ -55,11 +56,15 @@ type Config struct { ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` MaxConns int `json:"max_conns"` RerankingResults int `json:"reranking_results"` + GraphMaxHops int `json:"graph_max_hops"` + GraphBranchFactor int `json:"graph_branch_factor"` + GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"` ContextShowLastSummary bool `json:"context_show_last_summary"` RerankingEnabled bool `json:"reranking_enabled"` ContextShowWorkTokens bool `json:"context_show_work_tokens"` ContextShowReadTokens bool `json:"context_show_read_tokens"` RerankingPureMode bool `json:"reranking_pure_mode"` + GraphEnabled bool `json:"graph_enabled"` } var ( @@ -137,6 +142,11 @@ func Default() *Config { RerankingResults: 10, // Return top 10 after reranking RerankingAlpha: 0.7, // Favor cross-encoder score RerankingMinImprovement: 0, // Always apply reranking + GraphEnabled: true, // Enable graph-aware search by default + GraphMaxHops: 2, // Two-hop traversal + GraphBranchFactor: 5, // Expand top 5 neighbors per node + GraphEdgeWeight: 0.3, // Minimum edge weight to follow + GraphRebuildIntervalMin: 60, // Rebuild graph every 60 minutes ContextObservations: 100, ContextFullCount: 25, ContextSessionCount: 10, @@ -222,6 +232,22 @@ func Load() (*Config, error) { if v, ok := 
settings["CLAUDE_MNEMONIC_CONTEXT_MAX_PROMPT_RESULTS"].(float64); ok && v >= 0 { cfg.ContextMaxPromptResults = int(v) } + // Graph settings + if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_ENABLED"].(bool); ok { + cfg.GraphEnabled = v + } + if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_MAX_HOPS"].(float64); ok && v > 0 { + cfg.GraphMaxHops = int(v) + } + if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_BRANCH_FACTOR"].(float64); ok && v > 0 { + cfg.GraphBranchFactor = int(v) + } + if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_EDGE_WEIGHT"].(float64); ok && v >= 0 && v <= 1 { + cfg.GraphEdgeWeight = v + } + if v, ok := settings["CLAUDE_MNEMONIC_GRAPH_REBUILD_INTERVAL_MIN"].(float64); ok && v > 0 { + cfg.GraphRebuildIntervalMin = int(v) + } return cfg, nil } diff --git a/internal/graph/edge_detector.go b/internal/graph/edge_detector.go new file mode 100644 index 0000000..0770010 --- /dev/null +++ b/internal/graph/edge_detector.go @@ -0,0 +1,417 @@ +package graph + +import ( + "context" + "fmt" + "math" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +const ( + // SemanticSimilarityThreshold for creating semantic edges + SemanticSimilarityThreshold = 0.85 + + // MinFileOverlapForEdge minimum file overlap ratio to create edge + MinFileOverlapForEdge = 0.3 + + // MaxEdgesPerNode prevents creating too many edges + MaxEdgesPerNode = 20 +) + +// DetectEdges identifies relationships between observations +func DetectEdges(ctx context.Context, observations []*models.Observation) ([]Edge, error) { + if len(observations) < 2 { + return nil, nil + } + + edges := make([]Edge, 0) + + // Build lookup maps for efficient detection + sessionMap := buildSessionMap(observations) + conceptMap := buildConceptMap(observations) + fileMap := buildFileMap(observations) + + log.Info(). + Int("observations", len(observations)). + Int("sessions", len(sessionMap)). + Int("concepts", len(conceptMap)). 
+ Msg("Starting edge detection") + + // Detect temporal edges (same session) + temporalEdges := detectTemporalEdges(sessionMap) + edges = append(edges, temporalEdges...) + + // Detect concept edges (shared tags) + conceptEdges := detectConceptEdges(conceptMap) + edges = append(edges, conceptEdges...) + + // Detect file overlap edges + fileEdges := detectFileOverlapEdges(fileMap, observations) + edges = append(edges, fileEdges...) + + // Prune excessive edges per node + edges = pruneEdges(edges, MaxEdgesPerNode) + + log.Info(). + Int("temporal_edges", len(temporalEdges)). + Int("concept_edges", len(conceptEdges)). + Int("file_edges", len(fileEdges)). + Int("total_edges", len(edges)). + Msg("Edge detection complete") + + return edges, nil +} + +// buildSessionMap groups observations by SDK session +func buildSessionMap(observations []*models.Observation) map[string][]int64 { + sessionMap := make(map[string][]int64) + + for _, obs := range observations { + if obs.SDKSessionID != "" { + sessionMap[obs.SDKSessionID] = append(sessionMap[obs.SDKSessionID], obs.ID) + } + } + + return sessionMap +} + +// buildConceptMap groups observations by concept tags +func buildConceptMap(observations []*models.Observation) map[string][]int64 { + conceptMap := make(map[string][]int64) + + for _, obs := range observations { + for _, concept := range obs.Concepts { + conceptMap[concept] = append(conceptMap[concept], obs.ID) + } + } + + return conceptMap +} + +// buildFileMap maps files to observations (from both FilesRead and FilesModified) +func buildFileMap(observations []*models.Observation) map[string][]int64 { + fileMap := make(map[string][]int64) + + for _, obs := range observations { + // Add files from FilesRead + for _, file := range obs.FilesRead { + fileMap[file] = append(fileMap[file], obs.ID) + } + // Add files from FilesModified + for _, file := range obs.FilesModified { + fileMap[file] = append(fileMap[file], obs.ID) + } + } + + return fileMap +} + +// detectTemporalEdges 
creates edges between observations in the same session +func detectTemporalEdges(sessionMap map[string][]int64) []Edge { + edges := make([]Edge, 0) + + for _, obsIDs := range sessionMap { + if len(obsIDs) < 2 { + continue + } + + // Create edges between consecutive observations in session + for i := 0; i < len(obsIDs)-1; i++ { + edges = append(edges, Edge{ + FromID: obsIDs[i], + ToID: obsIDs[i+1], + Relation: RelationTemporal, + Weight: 0.8, // High weight for temporal proximity + }) + } + } + + return edges +} + +// detectConceptEdges creates edges between observations sharing concepts +func detectConceptEdges(conceptMap map[string][]int64) []Edge { + edges := make([]Edge, 0) + seen := make(map[string]bool) + + for concept, obsIDs := range conceptMap { + if len(obsIDs) < 2 { + continue + } + + // Create edges between all observations sharing this concept + for i := 0; i < len(obsIDs); i++ { + for j := i + 1; j < len(obsIDs); j++ { + // Use sorted pair as key to avoid duplicates + pairKey := edgeKey(obsIDs[i], obsIDs[j]) + if seen[pairKey] { + continue + } + seen[pairKey] = true + + // Weight based on concept specificity (longer = more specific) + weight := float32(0.5 + 0.3*math.Min(1.0, float64(len(concept))/20.0)) + + edges = append(edges, Edge{ + FromID: obsIDs[i], + ToID: obsIDs[j], + Relation: RelationConcept, + Weight: weight, + }) + } + } + } + + return edges +} + +// detectFileOverlapEdges creates edges based on file references +func detectFileOverlapEdges(fileMap map[string][]int64, observations []*models.Observation) []Edge { + edges := make([]Edge, 0) + seen := make(map[string]bool) + + // Build observation ID to observation map for quick lookup + obsMap := make(map[int64]*models.Observation) + for _, obs := range observations { + obsMap[obs.ID] = obs + } + + for _, obsIDs := range fileMap { + if len(obsIDs) < 2 { + continue + } + + // Create edges between observations referencing same files + for i := 0; i < len(obsIDs); i++ { + for j := i + 1; j < 
len(obsIDs); j++ { + pairKey := edgeKey(obsIDs[i], obsIDs[j]) + if seen[pairKey] { + continue + } + seen[pairKey] = true + + // Calculate file overlap ratio + obs1, ok1 := obsMap[obsIDs[i]] + obs2, ok2 := obsMap[obsIDs[j]] + + if !ok1 || !ok2 { + continue + } + + // Merge FilesRead and FilesModified for both observations + files1 := append([]string{}, obs1.FilesRead...) + files1 = append(files1, obs1.FilesModified...) + files2 := append([]string{}, obs2.FilesRead...) + files2 = append(files2, obs2.FilesModified...) + + overlap := calculateFileOverlap(files1, files2) + if overlap < MinFileOverlapForEdge { + continue + } + + edges = append(edges, Edge{ + FromID: obsIDs[i], + ToID: obsIDs[j], + Relation: RelationFileOverlap, + Weight: overlap, + }) + } + } + } + + return edges +} + +// calculateFileOverlap computes Jaccard similarity of file sets +func calculateFileOverlap(files1, files2 []string) float32 { + if len(files1) == 0 || len(files2) == 0 { + return 0.0 + } + + // Convert to sets + set1 := make(map[string]bool) + for _, f := range files1 { + set1[f] = true + } + + set2 := make(map[string]bool) + for _, f := range files2 { + set2[f] = true + } + + // Count intersection + intersection := 0 + for f := range set1 { + if set2[f] { + intersection++ + } + } + + // Jaccard similarity = intersection / union + union := len(set1) + len(set2) - intersection + if union == 0 { + return 0.0 + } + + return float32(intersection) / float32(union) +} + +// pruneEdges limits edges per node to prevent graph explosion +func pruneEdges(edges []Edge, maxPerNode int) []Edge { + if maxPerNode <= 0 { + return edges + } + + // Count edges per node + outEdges := make(map[int64][]Edge) + inEdges := make(map[int64][]Edge) + + for _, edge := range edges { + outEdges[edge.FromID] = append(outEdges[edge.FromID], edge) + inEdges[edge.ToID] = append(inEdges[edge.ToID], edge) + } + + // Prune low-weight edges if node has too many + pruned := make([]Edge, 0, len(edges)) + processed := 
make(map[string]bool) + + for _, edge := range edges { + pairKey := edgeKey(edge.FromID, edge.ToID) + if processed[pairKey] { + continue + } + processed[pairKey] = true + + // Check if either node has too many edges + fromCount := len(outEdges[edge.FromID]) + toCount := len(inEdges[edge.ToID]) + + if fromCount <= maxPerNode && toCount <= maxPerNode { + pruned = append(pruned, edge) + continue + } + + // Keep edge if it's high-weight (top edges for this node) + if shouldKeepEdge(edge, outEdges[edge.FromID], maxPerNode) { + pruned = append(pruned, edge) + } + } + + if len(pruned) < len(edges) { + log.Debug(). + Int("original", len(edges)). + Int("pruned", len(pruned)). + Int("removed", len(edges)-len(pruned)). + Msg("Pruned excessive edges") + } + + return pruned +} + +// shouldKeepEdge determines if edge should be kept during pruning +func shouldKeepEdge(edge Edge, nodeEdges []Edge, maxPerNode int) bool { + // Sort node's edges by weight descending + sortedEdges := make([]Edge, len(nodeEdges)) + copy(sortedEdges, nodeEdges) + + sortEdgesByWeight(sortedEdges) + + // Keep edge if it's in top maxPerNode + for i := 0; i < maxPerNode && i < len(sortedEdges); i++ { + if sortedEdges[i].FromID == edge.FromID && sortedEdges[i].ToID == edge.ToID { + return true + } + } + + return false +} + +// sortEdgesByWeight sorts edges by weight descending +func sortEdgesByWeight(edges []Edge) { + // Simple bubble sort (edges are typically small per node) + n := len(edges) + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if edges[j].Weight < edges[j+1].Weight { + edges[j], edges[j+1] = edges[j+1], edges[j] + } + } + } +} + +// edgeKey creates a unique key for an edge pair (sorted) +func edgeKey(id1, id2 int64) string { + if id1 < id2 { + return fmt.Sprintf("%d-%d", id1, id2) + } + return fmt.Sprintf("%d-%d", id2, id1) +} + +// DetectSemanticEdges creates edges based on semantic similarity +// This requires embeddings and is called separately when available +func 
DetectSemanticEdges(ctx context.Context, observations []*models.Observation, embeddings map[int64][]float32) []Edge { + edges := make([]Edge, 0) + seen := make(map[string]bool) + + // Compare all pairs (expensive, but necessary for semantic similarity) + for i := 0; i < len(observations); i++ { + emb1, ok1 := embeddings[observations[i].ID] + if !ok1 { + continue + } + + for j := i + 1; j < len(observations); j++ { + emb2, ok2 := embeddings[observations[j].ID] + if !ok2 { + continue + } + + similarity := cosineSimilarity(emb1, emb2) + if similarity < SemanticSimilarityThreshold { + continue + } + + pairKey := edgeKey(observations[i].ID, observations[j].ID) + if seen[pairKey] { + continue + } + seen[pairKey] = true + + edges = append(edges, Edge{ + FromID: observations[i].ID, + ToID: observations[j].ID, + Relation: RelationSemantic, + Weight: similarity, + }) + } + } + + log.Info(). + Int("semantic_edges", len(edges)). + Float32("threshold", SemanticSimilarityThreshold). + Msg("Detected semantic edges") + + return edges +} + +// cosineSimilarity computes cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := range a { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / float32(math.Sqrt(float64(normA))*math.Sqrt(float64(normB))) +} diff --git a/internal/graph/observation_graph.go b/internal/graph/observation_graph.go new file mode 100644 index 0000000..c86b6fa --- /dev/null +++ b/internal/graph/observation_graph.go @@ -0,0 +1,423 @@ +// Package graph provides observation relationship graphs for LEANN Phase 2. +// +// This package implements graph-based selective recomputation where observation +// relationships (file overlap, semantic similarity, temporal proximity) form a +// graph structure. 
Hub nodes (high-degree observations) store embeddings, while +// leaf nodes recompute on-demand. +package graph + +import ( + "context" + "fmt" + "math" + "sort" + "sync" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// RelationType defines the type of relationship between observations +type RelationType int + +const ( + // RelationFileOverlap indicates observations reference overlapping files + RelationFileOverlap RelationType = iota + // RelationSemantic indicates high semantic similarity (cosine > 0.85) + RelationSemantic + // RelationTemporal indicates observations from same session + RelationTemporal + // RelationConcept indicates shared concept tags + RelationConcept +) + +// Edge represents a relationship between two observations +type Edge struct { + FromID int64 + ToID int64 + Relation RelationType + Weight float32 // 0.0-1.0, higher = stronger relationship +} + +// Node represents an observation in the graph +type Node struct { + Metadata NodeMetadata + LastAccess time.Time + StoredEmb []float32 // Nil if recomputed on-demand + ID int64 + Degree int // Number of edges (hub detection) + AccessCount int +} + +// NodeMetadata contains observation metadata +type NodeMetadata struct { + CreatedAt time.Time + Project string + Type string + Title string + IsSuperseded bool +} + +// CSRGraph represents a graph in Compressed Sparse Row format for memory efficiency +type CSRGraph struct { + RowPtr []int32 // Node adjacency list pointers + ColIdx []int32 // Edge destination IDs + Weights []float32 // Edge weights + mu sync.RWMutex +} + +// ObservationGraph manages the observation relationship graph +type ObservationGraph struct { + nodes map[int64]*Node + csr *CSRGraph + edges []Edge + nodesMu sync.RWMutex + edgesMu sync.RWMutex +} + +// NewObservationGraph creates a new empty observation graph +func NewObservationGraph() *ObservationGraph { + return &ObservationGraph{ + nodes: make(map[int64]*Node), + edges: 
make([]Edge, 0), + csr: &CSRGraph{}, + } +} + +// AddNode adds or updates a node in the graph +func (g *ObservationGraph) AddNode(node *Node) { + g.nodesMu.Lock() + defer g.nodesMu.Unlock() + + g.nodes[node.ID] = node +} + +// AddEdge adds an edge to the graph +func (g *ObservationGraph) AddEdge(edge Edge) { + g.edgesMu.Lock() + defer g.edgesMu.Unlock() + + g.edges = append(g.edges, edge) + + // Update degree counts + g.nodesMu.Lock() + if fromNode, ok := g.nodes[edge.FromID]; ok { + fromNode.Degree++ + } + if toNode, ok := g.nodes[edge.ToID]; ok { + toNode.Degree++ + } + g.nodesMu.Unlock() +} + +// BuildCSR converts edge list to CSR format for efficient traversal +func (g *ObservationGraph) BuildCSR() error { + g.edgesMu.RLock() + g.nodesMu.RLock() + defer g.edgesMu.RUnlock() + defer g.nodesMu.RUnlock() + + if len(g.nodes) == 0 { + return fmt.Errorf("no nodes in graph") + } + + // Create node ID to index mapping + nodeIDs := make([]int64, 0, len(g.nodes)) + for id := range g.nodes { + nodeIDs = append(nodeIDs, id) + } + sort.Slice(nodeIDs, func(i, j int) bool { + return nodeIDs[i] < nodeIDs[j] + }) + + idToIdx := make(map[int64]int32) + for idx, id := range nodeIDs { + // #nosec G115 - observation count will never exceed int32 max (2.1B) in practice + idToIdx[id] = int32(idx) + } + + // Count edges per node + edgeCounts := make([]int, len(nodeIDs)) + for _, edge := range g.edges { + if fromIdx, ok := idToIdx[edge.FromID]; ok { + edgeCounts[fromIdx]++ + } + } + + // Build row pointers + rowPtr := make([]int32, len(nodeIDs)+1) + rowPtr[0] = 0 + for i := 0; i < len(nodeIDs); i++ { + // #nosec G115 - edge counts per node will not exceed int32 max + rowPtr[i+1] = rowPtr[i] + int32(edgeCounts[i]) + } + + // Build column indices and weights + totalEdges := rowPtr[len(nodeIDs)] + colIdx := make([]int32, totalEdges) + weights := make([]float32, totalEdges) + + // Temporary counter for filling CSR + currentPos := make([]int32, len(nodeIDs)) + copy(currentPos, 
rowPtr[:len(nodeIDs)]) + + for _, edge := range g.edges { + fromIdx, fromOk := idToIdx[edge.FromID] + toIdx, toOk := idToIdx[edge.ToID] + + if fromOk && toOk { + pos := currentPos[fromIdx] + colIdx[pos] = toIdx + weights[pos] = edge.Weight + currentPos[fromIdx]++ + } + } + + g.csr.mu.Lock() + g.csr.RowPtr = rowPtr + g.csr.ColIdx = colIdx + g.csr.Weights = weights + g.csr.mu.Unlock() + + log.Info(). + Int("nodes", len(nodeIDs)). + Int("edges", int(totalEdges)). + Msg("Built CSR graph representation") + + return nil +} + +// GetNeighbors returns neighboring nodes and their edge weights +func (g *ObservationGraph) GetNeighbors(nodeID int64) ([]int64, []float32, error) { + g.csr.mu.RLock() + defer g.csr.mu.RUnlock() + + // Find node index in CSR + g.nodesMu.RLock() + nodeIDs := make([]int64, 0, len(g.nodes)) + for id := range g.nodes { + nodeIDs = append(nodeIDs, id) + } + g.nodesMu.RUnlock() + + sort.Slice(nodeIDs, func(i, j int) bool { + return nodeIDs[i] < nodeIDs[j] + }) + + nodeIdx := sort.Search(len(nodeIDs), func(i int) bool { + return nodeIDs[i] >= nodeID + }) + + if nodeIdx >= len(nodeIDs) || nodeIDs[nodeIdx] != nodeID { + return nil, nil, fmt.Errorf("node %d not found", nodeID) + } + + // Extract neighbors from CSR + startIdx := g.csr.RowPtr[nodeIdx] + endIdx := g.csr.RowPtr[nodeIdx+1] + + neighborCount := endIdx - startIdx + neighbors := make([]int64, neighborCount) + weights := make([]float32, neighborCount) + + for i := int32(0); i < neighborCount; i++ { + neighborIdx := g.csr.ColIdx[startIdx+i] + neighbors[i] = nodeIDs[neighborIdx] + weights[i] = g.csr.Weights[startIdx+i] + } + + return neighbors, weights, nil +} + +// GetNode retrieves a node by ID +func (g *ObservationGraph) GetNode(nodeID int64) (*Node, error) { + g.nodesMu.RLock() + defer g.nodesMu.RUnlock() + + node, ok := g.nodes[nodeID] + if !ok { + return nil, fmt.Errorf("node %d not found", nodeID) + } + + return node, nil +} + +// FindHubs identifies hub nodes (high degree) in the graph +func (g 
*ObservationGraph) FindHubs(percentile float64) []int64 { + g.nodesMu.RLock() + defer g.nodesMu.RUnlock() + + if len(g.nodes) == 0 { + return nil + } + + // Collect all degrees + degrees := make([]int, 0, len(g.nodes)) + nodeIDs := make([]int64, 0, len(g.nodes)) + + for id, node := range g.nodes { + degrees = append(degrees, node.Degree) + nodeIDs = append(nodeIDs, id) + } + + // Sort by degree + type nodeDegree struct { + ID int64 + Degree int + } + + nodeDegrees := make([]nodeDegree, len(nodeIDs)) + for i := range nodeIDs { + nodeDegrees[i] = nodeDegree{ + ID: nodeIDs[i], + Degree: degrees[i], + } + } + + sort.Slice(nodeDegrees, func(i, j int) bool { + return nodeDegrees[i].Degree > nodeDegrees[j].Degree + }) + + // Return top percentile + cutoff := int(math.Ceil(float64(len(nodeDegrees)) * (1.0 - percentile))) + if cutoff > len(nodeDegrees) { + cutoff = len(nodeDegrees) + } + + hubs := make([]int64, cutoff) + for i := 0; i < cutoff; i++ { + hubs[i] = nodeDegrees[i].ID + } + + log.Info(). + Int("total_nodes", len(g.nodes)). + Int("hubs", len(hubs)). + Float64("percentile", percentile). 
+ Msg("Identified hub nodes") + + return hubs +} + +// Stats returns graph statistics +func (g *ObservationGraph) Stats() GraphStats { + g.nodesMu.RLock() + g.edgesMu.RLock() + defer g.nodesMu.RUnlock() + defer g.edgesMu.RUnlock() + + stats := GraphStats{ + NodeCount: len(g.nodes), + EdgeCount: len(g.edges), + } + + if len(g.nodes) > 0 { + degrees := make([]int, 0, len(g.nodes)) + for _, node := range g.nodes { + degrees = append(degrees, node.Degree) + } + + sort.Ints(degrees) + stats.AvgDegree = float64(sum(degrees)) / float64(len(degrees)) + stats.MaxDegree = degrees[len(degrees)-1] + stats.MinDegree = degrees[0] + + // Median + mid := len(degrees) / 2 + if len(degrees)%2 == 0 { + stats.MedianDegree = float64(degrees[mid-1]+degrees[mid]) / 2.0 + } else { + stats.MedianDegree = float64(degrees[mid]) + } + } + + // Count edge types + stats.EdgeTypes = make(map[RelationType]int) + for _, edge := range g.edges { + stats.EdgeTypes[edge.Relation]++ + } + + return stats +} + +// GraphStats contains graph statistics +type GraphStats struct { + EdgeTypes map[RelationType]int + AvgDegree float64 + MedianDegree float64 + NodeCount int + EdgeCount int + MaxDegree int + MinDegree int +} + +// BuildFromObservations constructs a graph from a list of observations +func BuildFromObservations(ctx context.Context, observations []*models.Observation) (*ObservationGraph, error) { + graph := NewObservationGraph() + + // Add nodes + for _, obs := range observations { + // Extract title from sql.NullString + title := "" + if obs.Title.Valid { + title = obs.Title.String + } + + node := &Node{ + ID: obs.ID, + Degree: 0, + Metadata: NodeMetadata{ + Project: obs.Project, + Type: string(obs.Type), + Title: title, + CreatedAt: time.UnixMilli(obs.CreatedAtEpoch), + IsSuperseded: obs.IsSuperseded, + }, + LastAccess: time.Now(), + AccessCount: 0, + } + graph.AddNode(node) + } + + // Detect edges (will be implemented in edge_detector.go) + edges, err := DetectEdges(ctx, observations) + if err != 
nil { + return nil, fmt.Errorf("detect edges: %w", err) + } + + for _, edge := range edges { + graph.AddEdge(edge) + } + + // Build CSR representation + if err := graph.BuildCSR(); err != nil { + return nil, fmt.Errorf("build CSR: %w", err) + } + + return graph, nil +} + +// Helper function to sum integers +func sum(values []int) int { + total := 0 + for _, v := range values { + total += v + } + return total +} + +// String returns a human-readable representation of RelationType +func (r RelationType) String() string { + switch r { + case RelationFileOverlap: + return "file_overlap" + case RelationSemantic: + return "semantic" + case RelationTemporal: + return "temporal" + case RelationConcept: + return "concept" + default: + return "unknown" + } +} diff --git a/internal/vector/hybrid/autotuner.go b/internal/vector/hybrid/autotuner.go new file mode 100644 index 0000000..78c3760 --- /dev/null +++ b/internal/vector/hybrid/autotuner.go @@ -0,0 +1,309 @@ +package hybrid + +import ( + "context" + "sync" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/rs/zerolog/log" +) + +// AutoTuner dynamically adjusts hub threshold based on query performance +type AutoTuner struct { + ctx context.Context + client *Client + cancel context.CancelFunc + latencies []time.Duration + wg sync.WaitGroup + queries int64 + targetLatency time.Duration + adjustPeriod time.Duration + minThreshold int + maxThreshold int + adjustments int + latenciesMu sync.Mutex +} + +// AutoTunerConfig configures the auto-tuner +type AutoTunerConfig struct { + TargetLatency time.Duration // Target p95 latency (default: 50ms) + MinThreshold int // Min hub threshold (default: 2) + MaxThreshold int // Max hub threshold (default: 20) + AdjustPeriod time.Duration // Adjustment frequency (default: 5min) +} + +// DefaultAutoTunerConfig returns sensible defaults +func DefaultAutoTunerConfig() AutoTunerConfig { + return AutoTunerConfig{ + TargetLatency: 50 * 
time.Millisecond, + MinThreshold: 2, + MaxThreshold: 20, + AdjustPeriod: 5 * time.Minute, + } +} + +// NewAutoTuner creates a new auto-tuner for the hybrid client +func NewAutoTuner(client *Client, cfg AutoTunerConfig) *AutoTuner { + ctx, cancel := context.WithCancel(context.Background()) + + tuner := &AutoTuner{ + client: client, + targetLatency: cfg.TargetLatency, + minThreshold: cfg.MinThreshold, + maxThreshold: cfg.MaxThreshold, + adjustPeriod: cfg.AdjustPeriod, + latencies: make([]time.Duration, 0, 1000), + ctx: ctx, + cancel: cancel, + } + + return tuner +} + +// Start begins auto-tuning in the background +func (a *AutoTuner) Start() { + a.wg.Add(1) + go a.tuningLoop() + + log.Info(). + Dur("target_latency", a.targetLatency). + Int("min_threshold", a.minThreshold). + Int("max_threshold", a.maxThreshold). + Dur("adjust_period", a.adjustPeriod). + Msg("Auto-tuner started") +} + +// Stop stops the auto-tuner +func (a *AutoTuner) Stop() { + a.cancel() + a.wg.Wait() + log.Info().Msg("Auto-tuner stopped") +} + +// RecordQuery records a query latency for analysis +func (a *AutoTuner) RecordQuery(latency time.Duration) { + a.latenciesMu.Lock() + defer a.latenciesMu.Unlock() + + a.queries++ + a.latencies = append(a.latencies, latency) + + // Keep only recent queries (last 1000) + if len(a.latencies) > 1000 { + a.latencies = a.latencies[len(a.latencies)-1000:] + } +} + +// tuningLoop periodically adjusts hub threshold +func (a *AutoTuner) tuningLoop() { + defer a.wg.Done() + + ticker := time.NewTicker(a.adjustPeriod) + defer ticker.Stop() + + for { + select { + case <-a.ctx.Done(): + return + + case <-ticker.C: + a.adjustThreshold() + } + } +} + +// adjustThreshold analyzes recent queries and adjusts hub threshold +func (a *AutoTuner) adjustThreshold() { + a.latenciesMu.Lock() + defer a.latenciesMu.Unlock() + + if len(a.latencies) < 10 { + // Not enough data yet + return + } + + // Calculate p95 latency + p95 := calculateP95(a.latencies) + + currentThreshold := 
a.client.hubThreshold + + log.Debug(). + Dur("p95_latency", p95). + Dur("target_latency", a.targetLatency). + Int("current_threshold", currentThreshold). + Int("queries", len(a.latencies)). + Msg("Auto-tuner evaluating performance") + + // Determine adjustment direction + var newThreshold int + + if p95 > a.targetLatency { + // Too slow - lower threshold (more hubs = faster queries) + adjustment := calculateAdjustment(p95, a.targetLatency) + newThreshold = currentThreshold - adjustment + + if newThreshold < a.minThreshold { + newThreshold = a.minThreshold + } + + log.Info(). + Dur("p95", p95). + Int("old_threshold", currentThreshold). + Int("new_threshold", newThreshold). + Msg("Auto-tuner: Lowering hub threshold (too slow)") + + } else if p95 < a.targetLatency*8/10 { + // Too fast - raise threshold (fewer hubs = more savings) + // Only adjust if significantly faster (20% margin) + adjustment := calculateAdjustment(a.targetLatency, p95) + newThreshold = currentThreshold + adjustment + + if newThreshold > a.maxThreshold { + newThreshold = a.maxThreshold + } + + log.Info(). + Dur("p95", p95). + Int("old_threshold", currentThreshold). + Int("new_threshold", newThreshold). + Msg("Auto-tuner: Raising hub threshold (room for savings)") + + } else { + // Within acceptable range, no adjustment needed + log.Debug(). + Dur("p95", p95). + Int("threshold", currentThreshold). + Msg("Auto-tuner: Performance acceptable, no adjustment") + return + } + + // Apply adjustment + if newThreshold != currentThreshold { + a.client.hubThreshold = newThreshold + a.adjustments++ + + // Clear latency history after adjustment + a.latencies = make([]time.Duration, 0, 1000) + + log.Info(). + Int("threshold", newThreshold). + Int("total_adjustments", a.adjustments). 
+ Msg("Hub threshold adjusted by auto-tuner") + } +} + +// calculateP95 computes the 95th percentile latency +func calculateP95(latencies []time.Duration) time.Duration { + if len(latencies) == 0 { + return 0 + } + + // Sort latencies + sorted := make([]time.Duration, len(latencies)) + copy(sorted, latencies) + + // Simple bubble sort (small dataset) + n := len(sorted) + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if sorted[j] > sorted[j+1] { + sorted[j], sorted[j+1] = sorted[j+1], sorted[j] + } + } + } + + // Return 95th percentile + idx := int(float64(len(sorted)) * 0.95) + if idx >= len(sorted) { + idx = len(sorted) - 1 + } + + return sorted[idx] +} + +// calculateAdjustment determines how much to adjust threshold +func calculateAdjustment(actual, target time.Duration) int { + // Calculate percentage difference + diff := float64(actual-target) / float64(target) + + // Adjust more aggressively for larger differences + if diff > 0.5 || diff < -0.5 { + return 3 // Large adjustment + } else if diff > 0.2 || diff < -0.2 { + return 2 // Medium adjustment + } + + return 1 // Small adjustment +} + +// GetStats returns auto-tuner statistics +func (a *AutoTuner) GetStats() AutoTunerStats { + a.latenciesMu.Lock() + defer a.latenciesMu.Unlock() + + stats := AutoTunerStats{ + CurrentThreshold: a.client.hubThreshold, + TargetLatency: a.targetLatency, + TotalQueries: a.queries, + TotalAdjustments: a.adjustments, + RecentQueries: len(a.latencies), + } + + if len(a.latencies) > 0 { + stats.P95Latency = calculateP95(a.latencies) + + // Calculate average + var total time.Duration + for _, lat := range a.latencies { + total += lat + } + stats.AvgLatency = total / time.Duration(len(a.latencies)) + } + + return stats +} + +// AutoTunerStats contains auto-tuner statistics +type AutoTunerStats struct { + CurrentThreshold int + TargetLatency time.Duration + P95Latency time.Duration + AvgLatency time.Duration + TotalQueries int64 + TotalAdjustments int + RecentQueries int 
+} + +// AutoTunedClient wraps Client with automatic performance tuning +type AutoTunedClient struct { + *Client + tuner *AutoTuner +} + +// Query wraps the underlying Query call with latency tracking +func (a *AutoTunedClient) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + start := time.Now() + results, err := a.Client.Query(ctx, query, limit, where) + latency := time.Since(start) + + a.tuner.RecordQuery(latency) + + return results, err +} + +// WithAutoTuning wraps a hybrid client with auto-tuning enabled +func WithAutoTuning(client *Client, cfg AutoTunerConfig) *AutoTunedClient { + tuner := NewAutoTuner(client, cfg) + tuner.Start() + + return &AutoTunedClient{ + Client: client, + tuner: tuner, + } +} + +// Stop stops the auto-tuner +func (a *AutoTunedClient) StopTuning() { + a.tuner.Stop() +} diff --git a/internal/vector/hybrid/client.go b/internal/vector/hybrid/client.go new file mode 100644 index 0000000..5a1b99a --- /dev/null +++ b/internal/vector/hybrid/client.go @@ -0,0 +1,515 @@ +// Package hybrid provides LEANN-inspired selective vector storage for claude-mnemonic. +// +// This package implements a hybrid storage strategy where frequently-accessed +// observations ("hubs") have their embeddings stored, while infrequently-accessed +// observations have their embeddings recomputed on-demand during search. +// +// This approach reduces storage by 60-80% with minimal impact on search latency (<50ms). 
+package hybrid + +import ( + "context" + "database/sql" + "fmt" + "math" + "sync" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/rs/zerolog/log" +) + +// VectorStorageStrategy defines how embeddings are stored/computed +type VectorStorageStrategy int + +const ( + // StorageAlways stores all embeddings (current behavior, backwards compatible) + StorageAlways VectorStorageStrategy = iota + // StorageHub stores only frequently-accessed "hub" embeddings (recommended) + StorageHub + // StorageOnDemand recomputes all embeddings during search (maximum savings) + StorageOnDemand +) + +// Client wraps sqlitevec.Client with selective storage logic +type Client struct { + base *sqlitevec.Client + db *sql.DB + embedSvc *embedding.Service + accessCount map[string]int + lastAccess map[string]time.Time + contentCache map[string]string + strategy VectorStorageStrategy + hubThreshold int + mu sync.RWMutex + cacheMu sync.RWMutex +} + +// Config for hybrid client +type Config struct { + BaseClient *sqlitevec.Client + DB *sql.DB + EmbedSvc *embedding.Service + Strategy VectorStorageStrategy + HubThreshold int // Default: 5 accesses +} + +// NewClient creates a new hybrid vector client +func NewClient(cfg Config) *Client { + if cfg.HubThreshold <= 0 { + cfg.HubThreshold = 5 + } + + log.Info(). + Str("strategy", strategyToString(cfg.Strategy)). + Int("hub_threshold", cfg.HubThreshold). 
+ Msg("Initializing LEANN hybrid vector client") + + return &Client{ + base: cfg.BaseClient, + db: cfg.DB, + embedSvc: cfg.EmbedSvc, + strategy: cfg.Strategy, + hubThreshold: cfg.HubThreshold, + accessCount: make(map[string]int), + lastAccess: make(map[string]time.Time), + contentCache: make(map[string]string), + } +} + +// AddDocuments implements selective storage based on strategy +func (c *Client) AddDocuments(ctx context.Context, docs []sqlitevec.Document) error { + if len(docs) == 0 { + return nil + } + + switch c.strategy { + case StorageAlways: + // Use existing implementation - store all embeddings + return c.base.AddDocuments(ctx, docs) + + case StorageHub: + // Store only hub candidates + return c.addDocumentsSelective(ctx, docs) + + case StorageOnDemand: + // Don't store embeddings, only cache content + return c.cacheDocuments(ctx, docs) + + default: + return c.base.AddDocuments(ctx, docs) + } +} + +// addDocumentsSelective stores embeddings only for hub-qualified documents +func (c *Client) addDocumentsSelective(ctx context.Context, docs []sqlitevec.Document) error { + // Always cache content for potential recomputation + if err := c.cacheDocuments(ctx, docs); err != nil { + return err + } + + // Filter to hub documents + hubDocs := make([]sqlitevec.Document, 0, len(docs)) + for _, doc := range docs { + if c.isHub(doc.ID) { + hubDocs = append(hubDocs, doc) + } + } + + // Store only hub embeddings + if len(hubDocs) > 0 { + log.Debug(). + Int("total", len(docs)). + Int("hubs", len(hubDocs)). 
+ Msg("Storing selective embeddings") + return c.base.AddDocuments(ctx, hubDocs) + } + + log.Debug().Int("total", len(docs)).Msg("All documents cached, no hubs to store") + return nil +} + +// cacheDocuments stores content for later recomputation +func (c *Client) cacheDocuments(ctx context.Context, docs []sqlitevec.Document) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + for _, doc := range docs { + c.contentCache[doc.ID] = doc.Content + } + + return nil +} + +// DeleteDocuments removes documents by their IDs +func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error { + // Remove from base storage + if err := c.base.DeleteDocuments(ctx, ids); err != nil { + return err + } + + // Clean up caches + c.mu.Lock() + for _, id := range ids { + delete(c.accessCount, id) + delete(c.lastAccess, id) + } + c.mu.Unlock() + + c.cacheMu.Lock() + for _, id := range ids { + delete(c.contentCache, id) + } + c.cacheMu.Unlock() + + return nil +} + +// Query performs search with dynamic recomputation +func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + switch c.strategy { + case StorageAlways: + // Use existing implementation + return c.queryAndTrack(ctx, query, limit, where) + + case StorageHub: + // Search hubs, then expand with recomputation + return c.queryHybrid(ctx, query, limit, where) + + case StorageOnDemand: + // Fully dynamic search + return c.queryDynamic(ctx, query, limit, where) + + default: + return c.queryAndTrack(ctx, query, limit, where) + } +} + +// queryAndTrack wraps base Query with access tracking +func (c *Client) queryAndTrack(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + results, err := c.base.Query(ctx, query, limit, where) + if err != nil { + return nil, err + } + + // Track access for hub detection + c.trackAccess(results) + + return results, nil +} + +// queryHybrid searches stored hubs and 
recomputes non-hubs +func (c *Client) queryHybrid(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + startTime := time.Now() + + // 1. Query stored hub embeddings (limit * 2 for expansion) + hubResults, err := c.base.Query(ctx, query, limit*2, where) + if err != nil { + return nil, err + } + + // 2. Track access + c.trackAccess(hubResults) + + // 3. Get candidate non-hub IDs (from content cache) + candidates := c.getCandidateNonHubs(where, limit*2) + + // 4. Recompute embeddings for candidates if we have any + var recomputedResults []sqlitevec.QueryResult + if len(candidates) > 0 { + recomputedResults, err = c.recomputeAndScore(ctx, query, candidates) + if err != nil { + // Log but don't fail - use hub results only + log.Warn().Err(err).Msg("Failed to recompute embeddings, using hub results only") + recomputedResults = nil + } + } + + // 5. Merge and rank + allResults := append(hubResults, recomputedResults...) + sortBySimilarity(allResults) + + // 6. Return top K + if len(allResults) > limit { + allResults = allResults[:limit] + } + + duration := time.Since(startTime) + log.Debug(). + Dur("duration_ms", duration). + Int("hubs", len(hubResults)). + Int("recomputed", len(recomputedResults)). + Int("results", len(allResults)). + Msg("Hybrid search completed") + + return allResults, nil +} + +// queryDynamic recomputes all embeddings on-the-fly +func (c *Client) queryDynamic(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + startTime := time.Now() + + // Get all candidate IDs from content cache + candidates := c.getCandidateNonHubs(where, limit*5) + + // Recompute and score all + results, err := c.recomputeAndScore(ctx, query, candidates) + if err != nil { + return nil, err + } + + // Track access + c.trackAccess(results) + + // Return top K + if len(results) > limit { + results = results[:limit] + } + + duration := time.Since(startTime) + log.Debug(). 
+ Dur("duration_ms", duration). + Int("recomputed", len(candidates)). + Int("results", len(results)). + Msg("Dynamic search completed") + + return results, nil +} + +// recomputeAndScore generates embeddings and computes similarities +func (c *Client) recomputeAndScore(ctx context.Context, query string, candidateIDs []string) ([]sqlitevec.QueryResult, error) { + if len(candidateIDs) == 0 { + return nil, nil + } + + // Generate query embedding + queryEmb, err := c.embedSvc.Embed(query) + if err != nil { + return nil, fmt.Errorf("embed query: %w", err) + } + + // Get content for candidates + c.cacheMu.RLock() + texts := make([]string, 0, len(candidateIDs)) + validIDs := make([]string, 0, len(candidateIDs)) + for _, id := range candidateIDs { + if content, ok := c.contentCache[id]; ok && content != "" { + texts = append(texts, content) + validIDs = append(validIDs, id) + } + } + c.cacheMu.RUnlock() + + if len(texts) == 0 { + return nil, nil + } + + // Batch generate embeddings + embeddings, err := c.embedSvc.EmbedBatch(texts) + if err != nil { + return nil, fmt.Errorf("batch embed: %w", err) + } + + // Compute similarities + results := make([]sqlitevec.QueryResult, len(embeddings)) + for i, emb := range embeddings { + similarity := cosineSimilarity(queryEmb, emb) + distance := 1.0 - similarity // Convert to distance + + results[i] = sqlitevec.QueryResult{ + ID: validIDs[i], + Distance: float64(distance), + Similarity: float64(similarity), + Metadata: make(map[string]any), + } + } + + return results, nil +} + +// trackAccess records document access for hub detection +func (c *Client) trackAccess(results []sqlitevec.QueryResult) { + if len(results) == 0 { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for _, r := range results { + c.accessCount[r.ID]++ + c.lastAccess[r.ID] = now + } +} + +// isHub checks if a document qualifies as a hub +func (c *Client) isHub(docID string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + count := 
c.accessCount[docID] + return count >= c.hubThreshold +} + +// getCandidateNonHubs returns IDs of non-hub documents matching filter +func (c *Client) getCandidateNonHubs(where map[string]any, limit int) []string { + c.cacheMu.RLock() + defer c.cacheMu.RUnlock() + + candidates := make([]string, 0, limit) + for id := range c.contentCache { + if !c.isHub(id) { + candidates = append(candidates, id) + if len(candidates) >= limit { + break + } + } + } + + return candidates +} + +// IsConnected always returns true (wraps base client) +func (c *Client) IsConnected() bool { + return c.base.IsConnected() +} + +// Close releases resources +func (c *Client) Close() error { + return c.base.Close() +} + +// Count returns the total number of vectors in the store +func (c *Client) Count(ctx context.Context) (int64, error) { + return c.base.Count(ctx) +} + +// ModelVersion returns the current embedding model version +func (c *Client) ModelVersion() string { + return c.base.ModelVersion() +} + +// NeedsRebuild checks if vectors need to be rebuilt due to model version change +func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { + return c.base.NeedsRebuild(ctx) +} + +// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions +func (c *Client) GetStaleVectors(ctx context.Context) ([]sqlitevec.StaleVectorInfo, error) { + return c.base.GetStaleVectors(ctx) +} + +// DeleteVectorsByDocIDs removes vectors by their doc_ids +func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error { + return c.base.DeleteVectorsByDocIDs(ctx, docIDs) +} + +// GetStorageStats returns storage efficiency metrics +func (c *Client) GetStorageStats(ctx context.Context) (StorageStats, error) { + c.mu.RLock() + c.cacheMu.RLock() + defer c.mu.RUnlock() + defer c.cacheMu.RUnlock() + + totalDocs := len(c.contentCache) + hubCount := 0 + for id := range c.contentCache { + if c.accessCount[id] >= c.hubThreshold { + hubCount++ + } + } + + storedCount := 
hubCount + if c.strategy == StorageAlways { + // Get actual count from database + if count, err := c.base.Count(ctx); err == nil { + storedCount = int(count) + } + } else if c.strategy == StorageOnDemand { + storedCount = 0 + } + + embeddingSize := 384 * 4 // 384 dims × 4 bytes (float32) + storedBytes := storedCount * embeddingSize + potentialBytes := totalDocs * embeddingSize + + savingsPercent := 0.0 + if potentialBytes > 0 { + savingsPercent = (1.0 - float64(storedBytes)/float64(potentialBytes)) * 100 + } + + return StorageStats{ + TotalDocuments: totalDocs, + HubDocuments: hubCount, + StoredEmbeddings: storedCount, + StorageBytes: storedBytes, + SavingsPercent: savingsPercent, + Strategy: c.strategy, + }, nil +} + +// StorageStats contains storage efficiency metrics +type StorageStats struct { + TotalDocuments int + HubDocuments int + StoredEmbeddings int + StorageBytes int + SavingsPercent float64 + Strategy VectorStorageStrategy +} + +// Helper functions + +func cosineSimilarity(a, b []float32) float32 { + var dotProduct, normA, normB float32 + for i := range a { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + if normA == 0 || normB == 0 { + return 0 + } + return dotProduct / float32(math.Sqrt(float64(normA))*math.Sqrt(float64(normB))) +} + +func sortBySimilarity(results []sqlitevec.QueryResult) { + // Use a simple but efficient sorting algorithm + n := len(results) + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if results[j].Similarity < results[j+1].Similarity { + results[j], results[j+1] = results[j+1], results[j] + } + } + } +} + +func strategyToString(s VectorStorageStrategy) string { + switch s { + case StorageAlways: + return "always" + case StorageHub: + return "hub" + case StorageOnDemand: + return "on_demand" + default: + return "unknown" + } +} + +// ParseStrategy converts a string to VectorStorageStrategy +func ParseStrategy(s string) VectorStorageStrategy { + switch s { + case "hub": + return 
StorageHub + case "on_demand": + return StorageOnDemand + case "always": + return StorageAlways + default: + return StorageHub // Default to hub strategy + } +} diff --git a/internal/vector/hybrid/client_test.go b/internal/vector/hybrid/client_test.go new file mode 100644 index 0000000..f784177 --- /dev/null +++ b/internal/vector/hybrid/client_test.go @@ -0,0 +1,186 @@ +package hybrid + +import ( + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/stretchr/testify/assert" +) + +func TestParseStrategy(t *testing.T) { + tests := []struct { + name string + input string + expected VectorStorageStrategy + }{ + {"hub_strategy", "hub", StorageHub}, + {"on_demand_strategy", "on_demand", StorageOnDemand}, + {"always_strategy", "always", StorageAlways}, + {"invalid_defaults_to_hub", "invalid", StorageHub}, + {"empty_defaults_to_hub", "", StorageHub}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseStrategy(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestStrategyToString(t *testing.T) { + tests := []struct { + name string + expected string + input VectorStorageStrategy + }{ + {"hub_to_string", "hub", StorageHub}, + {"on_demand_to_string", "on_demand", StorageOnDemand}, + {"always_to_string", "always", StorageAlways}, + {"invalid_to_unknown", "unknown", VectorStorageStrategy(99)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := strategyToString(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCosineSimilarity(t *testing.T) { + tests := []struct { + name string + a []float32 + b []float32 + expected float32 + }{ + { + name: "identical_vectors", + a: []float32{1, 0, 0}, + b: []float32{1, 0, 0}, + expected: 1.0, + }, + { + name: "orthogonal_vectors", + a: []float32{1, 0, 0}, + b: []float32{0, 1, 0}, + expected: 0.0, + }, + { + name: "opposite_vectors", + a: []float32{1, 0, 0}, + b: []float32{-1, 0, 0}, + 
expected: -1.0, + }, + { + name: "zero_vector", + a: []float32{0, 0, 0}, + b: []float32{1, 1, 1}, + expected: 0.0, + }, + { + name: "parallel_vectors", + a: []float32{2, 0, 0}, + b: []float32{4, 0, 0}, + expected: 1.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cosineSimilarity(tt.a, tt.b) + assert.InDelta(t, tt.expected, result, 0.001) + }) + } +} + +func TestSortBySimilarity(t *testing.T) { + tests := []struct { + name string + input []sqlitevec.QueryResult + expected []string // Expected order of IDs + }{ + { + name: "already_sorted", + input: []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.9}, + {ID: "doc2", Similarity: 0.7}, + {ID: "doc3", Similarity: 0.5}, + }, + expected: []string{"doc1", "doc2", "doc3"}, + }, + { + name: "reverse_sorted", + input: []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.3}, + {ID: "doc2", Similarity: 0.7}, + {ID: "doc3", Similarity: 0.9}, + }, + expected: []string{"doc3", "doc2", "doc1"}, + }, + { + name: "random_order", + input: []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.5}, + {ID: "doc2", Similarity: 0.9}, + {ID: "doc3", Similarity: 0.3}, + {ID: "doc4", Similarity: 0.7}, + }, + expected: []string{"doc2", "doc4", "doc1", "doc3"}, + }, + { + name: "identical_similarities", + input: []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.5}, + {ID: "doc2", Similarity: 0.5}, + {ID: "doc3", Similarity: 0.5}, + }, + expected: []string{"doc1", "doc2", "doc3"}, + }, + { + name: "empty_list", + input: []sqlitevec.QueryResult{}, + expected: []string{}, + }, + { + name: "single_element", + input: []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.5}, + }, + expected: []string{"doc1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sortBySimilarity(tt.input) + + actual := make([]string, len(tt.input)) + for i, r := range tt.input { + actual[i] = r.ID + } + + assert.Equal(t, tt.expected, actual) + }) + } +} + +func 
TestSortBySimilarity_PreserveOtherFields(t *testing.T) { + input := []sqlitevec.QueryResult{ + {ID: "doc1", Similarity: 0.3, Distance: 0.7, Metadata: map[string]any{"key": "val1"}}, + {ID: "doc2", Similarity: 0.9, Distance: 0.1, Metadata: map[string]any{"key": "val2"}}, + } + + sortBySimilarity(input) + + assert.Equal(t, "doc2", input[0].ID) + assert.InDelta(t, 0.9, input[0].Similarity, 0.001) + assert.InDelta(t, 0.1, input[0].Distance, 0.001) + assert.Equal(t, "val2", input[0].Metadata["key"]) + + assert.Equal(t, "doc1", input[1].ID) + assert.InDelta(t, 0.3, input[1].Similarity, 0.001) + assert.InDelta(t, 0.7, input[1].Distance, 0.001) + assert.Equal(t, "val1", input[1].Metadata["key"]) +} diff --git a/internal/vector/hybrid/config.go b/internal/vector/hybrid/config.go new file mode 100644 index 0000000..4cac342 --- /dev/null +++ b/internal/vector/hybrid/config.go @@ -0,0 +1,62 @@ +package hybrid + +import ( + "os" + "strconv" + + "github.com/rs/zerolog/log" +) + +// GetStrategyFromEnv reads CLAUDE_MNEMONIC_VECTOR_STRATEGY from environment +func GetStrategyFromEnv() VectorStorageStrategy { + strategyStr := os.Getenv("CLAUDE_MNEMONIC_VECTOR_STRATEGY") + if strategyStr == "" { + // Default to hub strategy for optimal balance + return StorageHub + } + + strategy := ParseStrategy(strategyStr) + log.Info(). + Str("env_value", strategyStr). + Str("strategy", strategyToString(strategy)). + Msg("Vector storage strategy from environment") + + return strategy +} + +// GetHubThresholdFromEnv reads CLAUDE_MNEMONIC_HUB_THRESHOLD from environment +func GetHubThresholdFromEnv() int { + thresholdStr := os.Getenv("CLAUDE_MNEMONIC_HUB_THRESHOLD") + if thresholdStr == "" { + return 5 // Default threshold + } + + threshold, err := strconv.Atoi(thresholdStr) + if err != nil { + log.Warn(). + Err(err). + Str("env_value", thresholdStr). + Msg("Invalid hub threshold in environment, using default") + return 5 + } + + if threshold < 1 { + log.Warn(). + Int("env_value", threshold). 
+ Msg("Hub threshold too low, using minimum of 1") + return 1 + } + + log.Info(). + Int("threshold", threshold). + Msg("Hub threshold from environment") + + return threshold +} + +// IsHybridEnabled checks if hybrid storage should be used +// Returns false if CLAUDE_MNEMONIC_VECTOR_STRATEGY=always (backwards compat) +func IsHybridEnabled() bool { + strategy := GetStrategyFromEnv() + return strategy != StorageAlways +} diff --git a/internal/vector/hybrid/graph_search.go b/internal/vector/hybrid/graph_search.go new file mode 100644 index 0000000..110cfa3 --- /dev/null +++ b/internal/vector/hybrid/graph_search.go @@ -0,0 +1,308 @@ +package hybrid + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/graph" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// GraphConfig configures graph-aware search +type GraphConfig struct { + Enabled bool + MaxHops int // Maximum graph traversal depth (default: 2) + BranchFactor int // Number of neighbors to expand per node (default: 5) + EdgeWeight float64 // Minimum edge weight to follow (default: 0.3) +} + +// DefaultGraphConfig returns sensible defaults for graph search +func DefaultGraphConfig() GraphConfig { + return GraphConfig{ + Enabled: true, + MaxHops: 2, + BranchFactor: 5, + EdgeWeight: 0.3, + } +} + +// GraphSearchClient wraps hybrid.Client with graph-aware search +type GraphSearchClient struct { + *Client + graph *graph.ObservationGraph + graphConfig GraphConfig +} + +// NewGraphSearchClient creates a graph-enhanced hybrid client +func NewGraphSearchClient(baseClient *Client, observationGraph *graph.ObservationGraph, cfg GraphConfig) *GraphSearchClient { + return &GraphSearchClient{ + Client: baseClient, + graph: observationGraph, + graphConfig: cfg, + } +} + +// Query performs graph-aware vector search with two-level traversal +func (g *GraphSearchClient) 
Query(ctx context.Context, query string, limit int, where map[string]any) ([]sqlitevec.QueryResult, error) { + if !g.graphConfig.Enabled || g.graph == nil { + // Fall back to standard hybrid search + return g.Client.Query(ctx, query, limit, where) + } + + startTime := time.Now() + + // 1. Generate query embedding + queryEmb, err := g.embedSvc.Embed(query) + if err != nil { + return nil, fmt.Errorf("embed query: %w", err) + } + + // 2. Search hub nodes (stored embeddings) + hubResults, err := g.base.Query(ctx, query, limit*2, where) + if err != nil { + // Fall back to standard search on error + log.Warn().Err(err).Msg("Hub search failed, falling back to hybrid search") + return g.Client.Query(ctx, query, limit, where) + } + + // 3. Track hub access + g.trackAccess(hubResults) + + // 4. Expand via graph traversal + expandedIDs := g.expandFromHubs(hubResults, limit*4) + + // 5. Filter to non-hubs that need recomputation + nonHubIDs := make([]string, 0) + for _, id := range expandedIDs { + if !g.isHub(id) { + nonHubIDs = append(nonHubIDs, id) + } + } + + // 6. Batch recompute non-hub embeddings + recomputedResults, err := g.recomputeAndScore(ctx, query, nonHubIDs) + if err != nil { + log.Warn().Err(err).Msg("Recomputation failed, using hub results only") + recomputedResults = nil + } + + // 7. Apply graph-based ranking boost + allResults := g.mergeAndRankWithGraph(hubResults, recomputedResults, queryEmb) + + // 8. Return top K + if len(allResults) > limit { + allResults = allResults[:limit] + } + + duration := time.Since(startTime) + log.Debug(). + Dur("duration_ms", duration). + Int("hubs", len(hubResults)). + Int("expanded", len(expandedIDs)). + Int("recomputed", len(recomputedResults)). + Int("results", len(allResults)). 
+ Msg("Graph search completed") + + return allResults, nil +} + +// expandFromHubs traverses graph from hub nodes to find promising candidates +func (g *GraphSearchClient) expandFromHubs(hubResults []sqlitevec.QueryResult, maxCandidates int) []string { + if g.graph == nil { + return nil + } + + expanded := make(map[string]float64) // doc_id -> relevance score + visited := make(map[int64]bool) + + // Start from top hub results + for i, result := range hubResults { + if i >= g.graphConfig.BranchFactor*2 { + break // Limit starting points + } + + // Parse observation ID from doc_id + obsID := parseObservationID(result.ID) + if obsID == 0 { + continue + } + + // Mark as visited with high relevance (direct match) + visited[obsID] = true + expanded[result.ID] = result.Similarity + + // Traverse graph from this hub + g.traverseGraph(obsID, result.Similarity, 0, expanded, visited) + } + + // Convert to sorted list + type candidate struct { + ID string + Relevance float64 + } + + candidates := make([]candidate, 0, len(expanded)) + for id, rel := range expanded { + candidates = append(candidates, candidate{ID: id, Relevance: rel}) + } + + // Sort by relevance descending + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Relevance > candidates[j].Relevance + }) + + // Return top candidates + if len(candidates) > maxCandidates { + candidates = candidates[:maxCandidates] + } + + result := make([]string, len(candidates)) + for i, c := range candidates { + result[i] = c.ID + } + + return result +} + +// traverseGraph performs depth-limited graph traversal +func (g *GraphSearchClient) traverseGraph(nodeID int64, baseRelevance float64, depth int, expanded map[string]float64, visited map[int64]bool) { + if depth >= g.graphConfig.MaxHops { + return // Max depth reached + } + + // Get neighbors from graph + neighbors, weights, err := g.graph.GetNeighbors(nodeID) + if err != nil { + return // No neighbors or error + } + + // Traverse top neighbors by weight + type 
neighborWeight struct { + ID int64 + Weight float32 + } + + neighborList := make([]neighborWeight, len(neighbors)) + for i := range neighbors { + neighborList[i] = neighborWeight{ + ID: neighbors[i], + Weight: weights[i], + } + } + + // Sort by weight descending + sort.Slice(neighborList, func(i, j int) bool { + return neighborList[i].Weight > neighborList[j].Weight + }) + + // Expand top branch_factor neighbors + expanded_count := 0 + for _, nw := range neighborList { + if expanded_count >= g.graphConfig.BranchFactor { + break + } + + // Skip if edge weight too low + if float64(nw.Weight) < g.graphConfig.EdgeWeight { + continue + } + + // Skip if already visited + if visited[nw.ID] { + continue + } + visited[nw.ID] = true + + // Calculate propagated relevance (decays with distance) + decay := 0.7 // 30% decay per hop + propagatedRelevance := baseRelevance * float64(nw.Weight) * decay + + // Add to expanded set + docID := formatObservationDocID(nw.ID) + if existing, ok := expanded[docID]; !ok || propagatedRelevance > existing { + expanded[docID] = propagatedRelevance + } + + // Recursively traverse + g.traverseGraph(nw.ID, propagatedRelevance, depth+1, expanded, visited) + expanded_count++ + } +} + +// mergeAndRankWithGraph combines hub and recomputed results with graph-based ranking +func (g *GraphSearchClient) mergeAndRankWithGraph(hubResults, recomputedResults []sqlitevec.QueryResult, queryEmb []float32) []sqlitevec.QueryResult { + // Merge results + allResults := append(hubResults, recomputedResults...) 
+ + // Apply graph-based re-ranking + if g.graph != nil { + for i := range allResults { + obsID := parseObservationID(allResults[i].ID) + if obsID == 0 { + continue + } + + // Boost score based on node degree (hubs are more important) + node, err := g.graph.GetNode(obsID) + if err == nil && node.Degree > 0 { + // Degree boost: up to 10% increase for high-degree nodes + degreeBoost := 1.0 + (0.1 * float64(node.Degree) / 20.0) + if degreeBoost > 1.1 { + degreeBoost = 1.1 + } + allResults[i].Similarity *= degreeBoost + } + } + } + + // Sort by adjusted similarity + sortBySimilarity(allResults) + + return allResults +} + +// parseObservationID extracts observation ID from doc_id +// Format: "obs-{id}-{field}" +func parseObservationID(docID string) int64 { + var obsID int64 + // Ignore error - returns 0 on parse failure, which callers handle + _, _ = fmt.Sscanf(docID, "obs-%d-", &obsID) + return obsID +} + +// formatObservationDocID creates a doc_id for an observation +func formatObservationDocID(obsID int64) string { + return fmt.Sprintf("obs-%d-combined", obsID) +} + +// GetGraphStats returns statistics about the observation graph +func (g *GraphSearchClient) GetGraphStats() graph.GraphStats { + if g.graph == nil { + return graph.GraphStats{} + } + return g.graph.Stats() +} + +// RebuildGraph rebuilds the observation graph from current observations +// This should be called periodically or when observations change significantly +func (g *GraphSearchClient) RebuildGraph(ctx context.Context, observations []*models.Observation) error { + log.Info().Int("observations", len(observations)).Msg("Rebuilding observation graph") + + newGraph, err := graph.BuildFromObservations(ctx, observations) + if err != nil { + return fmt.Errorf("build graph: %w", err) + } + + g.graph = newGraph + + log.Info(). + Int("nodes", newGraph.Stats().NodeCount). + Int("edges", newGraph.Stats().EdgeCount). 
// maxRecentLatencies bounds the sliding window of per-query latencies that
// is kept for percentile calculation.
const maxRecentLatencies = 1000

// Metrics tracks performance and usage statistics for hybrid vector storage.
//
// Counters are atomics and may be updated from any goroutine. The sliding
// window of recent latencies is the only state guarded by latenciesMu.
// A Metrics value must not be copied after first use.
type Metrics struct {
	startTime         time.Time
	recentLatencies   []time.Duration
	latenciesMu       sync.Mutex
	totalQueries      atomic.Int64
	hubOnlyQueries    atomic.Int64
	hybridQueries     atomic.Int64
	onDemandQueries   atomic.Int64
	graphQueries      atomic.Int64
	totalLatency      atomic.Int64 // Sum in microseconds
	hubLatency        atomic.Int64 // Sum in microseconds
	recomputeLatency  atomic.Int64 // Sum in microseconds
	totalDocuments    atomic.Int64
	hubDocuments      atomic.Int64
	storedEmbeddings  atomic.Int64
	recomputedCount   atomic.Int64
	cacheHits         atomic.Int64
	cacheMisses       atomic.Int64
	graphTraversals   atomic.Int64
	avgTraversalDepth atomic.Int64 // Sum of depths; divided by graphTraversals on read
}

// NewMetrics creates a new metrics tracker whose uptime starts now.
func NewMetrics() *Metrics {
	return &Metrics{
		recentLatencies: make([]time.Duration, 0, maxRecentLatencies),
		startTime:       time.Now(),
	}
}

// RecordQuery records a query execution.
//
// queryType selects the per-type counter ("hub_only", "hybrid", "on_demand"
// or "graph"); any other value only counts towards the total. latency is
// added to the running total (truncated to microseconds) and appended to
// the sliding window used for percentiles. recomputed is the number of
// embeddings recomputed for this query; non-positive values are ignored.
func (m *Metrics) RecordQuery(queryType string, latency time.Duration, recomputed int) {
	m.totalQueries.Add(1)
	m.totalLatency.Add(latency.Microseconds())

	switch queryType {
	case "hub_only":
		m.hubOnlyQueries.Add(1)
	case "hybrid":
		m.hybridQueries.Add(1)
	case "on_demand":
		m.onDemandQueries.Add(1)
	case "graph":
		m.graphQueries.Add(1)
	}

	if recomputed > 0 {
		m.recomputedCount.Add(int64(recomputed))
	}

	// Track recent latencies, keeping only the newest maxRecentLatencies.
	m.latenciesMu.Lock()
	m.recentLatencies = append(m.recentLatencies, latency)
	if len(m.recentLatencies) > maxRecentLatencies {
		m.recentLatencies = m.recentLatencies[len(m.recentLatencies)-maxRecentLatencies:]
	}
	m.latenciesMu.Unlock()
}

// RecordHubLatency records time spent in hub search.
func (m *Metrics) RecordHubLatency(latency time.Duration) {
	m.hubLatency.Add(latency.Microseconds())
}

// RecordRecomputeLatency records time spent recomputing embeddings.
func (m *Metrics) RecordRecomputeLatency(latency time.Duration) {
	m.recomputeLatency.Add(latency.Microseconds())
}

// RecordCacheHit records a content cache hit.
func (m *Metrics) RecordCacheHit() {
	m.cacheHits.Add(1)
}

// RecordCacheMiss records a content cache miss.
func (m *Metrics) RecordCacheMiss() {
	m.cacheMisses.Add(1)
}

// RecordGraphTraversal records a graph traversal operation of the given depth.
func (m *Metrics) RecordGraphTraversal(depth int) {
	m.graphTraversals.Add(1)
	m.avgTraversalDepth.Add(int64(depth))
}

// UpdateStorageStats replaces the current storage statistics.
func (m *Metrics) UpdateStorageStats(total, hubs, stored int) {
	m.totalDocuments.Store(int64(total))
	m.hubDocuments.Store(int64(hubs))
	m.storedEmbeddings.Store(int64(stored))
}

// GetSnapshot returns the current metrics snapshot.
//
// Counters are read individually, so a snapshot taken while queries are
// being recorded may combine values from slightly different instants.
func (m *Metrics) GetSnapshot() MetricsSnapshot {
	// Copy the latency window under the lock and release it immediately:
	// sorting and the remaining arithmetic must not block RecordQuery.
	m.latenciesMu.Lock()
	sorted := make([]time.Duration, len(m.recentLatencies))
	copy(sorted, m.recentLatencies)
	m.latenciesMu.Unlock()

	totalQueries := m.totalQueries.Load()
	recomputed := m.recomputedCount.Load()

	snapshot := MetricsSnapshot{
		// Query counts
		TotalQueries:    totalQueries,
		HubOnlyQueries:  m.hubOnlyQueries.Load(),
		HybridQueries:   m.hybridQueries.Load(),
		OnDemandQueries: m.onDemandQueries.Load(),
		GraphQueries:    m.graphQueries.Load(),

		// Storage
		TotalDocuments:   int(m.totalDocuments.Load()),
		HubDocuments:     int(m.hubDocuments.Load()),
		StoredEmbeddings: int(m.storedEmbeddings.Load()),
		RecomputedTotal:  recomputed,

		// Cache
		CacheHits:   m.cacheHits.Load(),
		CacheMisses: m.cacheMisses.Load(),

		// Graph
		GraphTraversals: m.graphTraversals.Load(),

		// Runtime
		Uptime: time.Since(m.startTime),
	}

	// Average latencies (sums are kept in microseconds).
	if totalQueries > 0 {
		snapshot.AvgLatency = time.Duration(m.totalLatency.Load()/totalQueries) * time.Microsecond
		snapshot.AvgHubLatency = time.Duration(m.hubLatency.Load()/totalQueries) * time.Microsecond
	}
	if recomputed > 0 {
		snapshot.AvgRecomputeLatency = time.Duration(m.recomputeLatency.Load()/recomputed) * time.Microsecond
	}

	// Percentiles over the sliding window.
	if len(sorted) > 0 {
		sortDurations(sorted)
		snapshot.P50Latency = percentile(sorted, 0.50)
		snapshot.P95Latency = percentile(sorted, 0.95)
		snapshot.P99Latency = percentile(sorted, 0.99)
	}

	// Cache hit rate as a 0..1 fraction.
	if totalCacheOps := snapshot.CacheHits + snapshot.CacheMisses; totalCacheOps > 0 {
		snapshot.CacheHitRate = float64(snapshot.CacheHits) / float64(totalCacheOps)
	}

	// Storage savings: share of documents without a stored embedding. The
	// per-embedding byte size cancels out of the ratio, so the result does
	// not depend on the embedding dimension.
	if snapshot.TotalDocuments > 0 {
		stored := float64(snapshot.StoredEmbeddings)
		total := float64(snapshot.TotalDocuments)
		snapshot.StorageSavingsPercent = (1.0 - stored/total) * 100
	}

	// Average traversal depth.
	if snapshot.GraphTraversals > 0 {
		snapshot.AvgTraversalDepth = float64(m.avgTraversalDepth.Load()) / float64(snapshot.GraphTraversals)
	}

	return snapshot
}

// MetricsSnapshot represents a point-in-time metrics snapshot.
type MetricsSnapshot struct {
	// Query metrics
	TotalQueries    int64
	HubOnlyQueries  int64
	HybridQueries   int64
	OnDemandQueries int64
	GraphQueries    int64

	// Latency metrics
	AvgLatency          time.Duration
	P50Latency          time.Duration
	P95Latency          time.Duration
	P99Latency          time.Duration
	AvgHubLatency       time.Duration
	AvgRecomputeLatency time.Duration

	// Storage metrics
	TotalDocuments        int
	HubDocuments          int
	StoredEmbeddings      int
	StorageSavingsPercent float64
	RecomputedTotal       int64

	// Cache metrics
	CacheHits    int64
	CacheMisses  int64
	CacheHitRate float64 // fraction in [0, 1]

	// Graph metrics
	GraphTraversals   int64
	AvgTraversalDepth float64

	// Runtime
	Uptime time.Duration
}

// sortDurations sorts a slice of durations in place, in ascending order.
//
// It is a bubble sort with an early exit once a pass performs no swap;
// callers only pass the bounded latency window.
func sortDurations(durations []time.Duration) {
	for end := len(durations) - 1; end > 0; end-- {
		swapped := false
		for j := 0; j < end; j++ {
			if durations[j] > durations[j+1] {
				durations[j], durations[j+1] = durations[j+1], durations[j]
				swapped = true
			}
		}
		if !swapped {
			return
		}
	}
}

// percentile returns the element at position floor(len*p) of an ascending
// slice, clamped to the last element. It returns 0 for an empty slice.
func percentile(sorted []time.Duration, p float64) time.Duration {
	if len(sorted) == 0 {
		return 0
	}

	idx := int(float64(len(sorted)) * p)
	if idx >= len(sorted) {
		idx = len(sorted) - 1
	}

	return sorted[idx]
}

// String returns a human-readable representation of metrics.
func (s MetricsSnapshot) String() string {
	// Guard the hub share: with no documents the division would yield NaN.
	hubPercent := 0.0
	if s.TotalDocuments > 0 {
		hubPercent = float64(s.HubDocuments) / float64(s.TotalDocuments) * 100
	}

	return fmt.Sprintf(`Hybrid Vector Storage Metrics:
  Queries:
    Total: %d (Hub: %d, Hybrid: %d, OnDemand: %d, Graph: %d)
    Avg Latency: %v (p50: %v, p95: %v, p99: %v)
    Hub Latency: %v, Recompute Latency: %v
  Storage:
    Documents: %d (Hubs: %d, %.1f%%)
    Stored Embeddings: %d
    Savings: %.1f%%
    Total Recomputed: %d
  Cache:
    Hits: %d, Misses: %d (Hit Rate: %.1f%%)
  Graph:
    Traversals: %d (Avg Depth: %.2f)
  Runtime: %v`,
		s.TotalQueries, s.HubOnlyQueries, s.HybridQueries, s.OnDemandQueries, s.GraphQueries,
		s.AvgLatency, s.P50Latency, s.P95Latency, s.P99Latency,
		s.AvgHubLatency, s.AvgRecomputeLatency,
		s.TotalDocuments, s.HubDocuments, hubPercent,
		s.StoredEmbeddings,
		s.StorageSavingsPercent,
		s.RecomputedTotal,
		s.CacheHits, s.CacheMisses, s.CacheHitRate*100,
		s.GraphTraversals, s.AvgTraversalDepth,
		s.Uptime,
	)
}
their doc_ids + DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error +} diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index b50053a..b56f65f 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -1312,3 +1312,85 @@ func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) { } }() } + +// handleGetGraphStats returns observation graph statistics. +func (s *Service) handleGetGraphStats(w http.ResponseWriter, r *http.Request) { + if s.graphSearchClient == nil { + writeJSON(w, map[string]interface{}{ + "enabled": false, + "message": "Graph search not enabled", + }) + return + } + + stats := s.graphSearchClient.GetGraphStats() + + response := map[string]interface{}{ + "enabled": s.config.GraphEnabled, + "nodeCount": stats.NodeCount, + "edgeCount": stats.EdgeCount, + "avgDegree": stats.AvgDegree, + "maxDegree": stats.MaxDegree, + "minDegree": stats.MinDegree, + "medianDegree": stats.MedianDegree, + "edgeTypes": stats.EdgeTypes, + "config": map[string]interface{}{ + "maxHops": s.config.GraphMaxHops, + "branchFactor": s.config.GraphBranchFactor, + "edgeWeight": s.config.GraphEdgeWeight, + "rebuildIntervalMin": s.config.GraphRebuildIntervalMin, + }, + } + + writeJSON(w, response) +} + +// handleGetVectorMetrics returns hybrid vector storage metrics. 
+func (s *Service) handleGetVectorMetrics(w http.ResponseWriter, r *http.Request) { + if s.hybridMetrics == nil { + writeJSON(w, map[string]interface{}{ + "enabled": false, + "message": "Vector metrics not available", + }) + return + } + + snapshot := s.hybridMetrics.GetSnapshot() + + response := map[string]interface{}{ + "queries": map[string]interface{}{ + "total": snapshot.TotalQueries, + "hubOnly": snapshot.HubOnlyQueries, + "hybrid": snapshot.HybridQueries, + "onDemand": snapshot.OnDemandQueries, + "graph": snapshot.GraphQueries, + }, + "latency": map[string]interface{}{ + "avg": snapshot.AvgLatency.String(), + "p50": snapshot.P50Latency.String(), + "p95": snapshot.P95Latency.String(), + "p99": snapshot.P99Latency.String(), + "avgHub": snapshot.AvgHubLatency.String(), + "avgRecompute": snapshot.AvgRecomputeLatency.String(), + }, + "storage": map[string]interface{}{ + "totalDocuments": snapshot.TotalDocuments, + "hubDocuments": snapshot.HubDocuments, + "storedEmbeddings": snapshot.StoredEmbeddings, + "savingsPercent": snapshot.StorageSavingsPercent, + "recomputedTotal": snapshot.RecomputedTotal, + }, + "cache": map[string]interface{}{ + "hits": snapshot.CacheHits, + "misses": snapshot.CacheMisses, + "hitRate": snapshot.CacheHitRate, + }, + "graph": map[string]interface{}{ + "traversals": snapshot.GraphTraversals, + "avgDepth": snapshot.AvgTraversalDepth, + }, + "uptime": snapshot.Uptime.String(), + } + + writeJSON(w, response) +} diff --git a/internal/worker/service.go b/internal/worker/service.go index 9853e1e..668fe11 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -24,6 +24,8 @@ import ( "github.com/lukaszraczylo/claude-mnemonic/internal/scoring" "github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion" "github.com/lukaszraczylo/claude-mnemonic/internal/update" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/hybrid" 
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" "github.com/lukaszraczylo/claude-mnemonic/internal/watcher" "github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk" @@ -60,43 +62,46 @@ type RetrievalStats struct { // Service is the main worker service orchestrator. type Service struct { - startTime time.Time - initError error - ctx context.Context - queryExpander *expansion.Expander - recalculator *scoring.Recalculator - summaryStore *sqlite.SummaryStore - promptStore *sqlite.PromptStore - conflictStore *sqlite.ConflictStore - patternStore *sqlite.PatternStore - relationStore *sqlite.RelationStore - patternDetector *pattern.Detector - sessionManager *session.Manager - sseBroadcaster *sse.Broadcaster - router *chi.Mux - embedSvc *embedding.Service - vectorClient *sqlitevec.Client - vectorSync *sqlitevec.Sync - reranker *reranking.Service - updater *update.Updater - observationStore *sqlite.ObservationStore - scoreCalculator *scoring.Calculator - processor *sdk.Processor - server *http.Server - sessionStore *sqlite.SessionStore - retrievalStats map[string]*RetrievalStats - configWatcher *watcher.Watcher - store *sqlite.Store - cancel context.CancelFunc - dbWatcher *watcher.Watcher - staleQueue chan staleVerifyRequest - config *config.Config - version string - wg sync.WaitGroup - initMu sync.RWMutex - retrievalStatsMu sync.RWMutex - staleQueueOnce sync.Once - ready atomic.Bool + startTime time.Time + initError error + vectorClient vector.Client + ctx context.Context + sseBroadcaster *sse.Broadcaster + server *http.Server + graphRebuildTicker *time.Ticker + hybridMetrics *hybrid.Metrics + graphSearchClient *hybrid.GraphSearchClient + retrievalStats map[string]*RetrievalStats + staleQueue chan staleVerifyRequest + queryExpander *expansion.Expander + recalculator *scoring.Recalculator + summaryStore *sqlite.SummaryStore + promptStore *sqlite.PromptStore + conflictStore *sqlite.ConflictStore + patternStore *sqlite.PatternStore + relationStore 
*sqlite.RelationStore + patternDetector *pattern.Detector + sessionManager *session.Manager + router *chi.Mux + config *config.Config + store *sqlite.Store + vectorSync *sqlitevec.Sync + reranker *reranking.Service + updater *update.Updater + observationStore *sqlite.ObservationStore + scoreCalculator *scoring.Calculator + processor *sdk.Processor + dbWatcher *watcher.Watcher + sessionStore *sqlite.SessionStore + configWatcher *watcher.Watcher + embedSvc *embedding.Service + cancel context.CancelFunc + version string + wg sync.WaitGroup + initMu sync.RWMutex + retrievalStatsMu sync.RWMutex + staleQueueOnce sync.Once + ready atomic.Bool } // staleVerifyRequest represents a request to verify a stale observation in background @@ -185,7 +190,7 @@ func (s *Service) initializeAsync() { // Create embedding service and sqlite-vec client for vector search (optional) var embedSvc *embedding.Service - var vectorClient *sqlitevec.Client + var vectorClient vector.Client var vectorSync *sqlitevec.Sync var reranker *reranking.Service @@ -196,14 +201,35 @@ func (s *Service) initializeAsync() { } else { embedSvc = emb // Create sqlite-vec client using the same DB connection - client, clientErr := sqlitevec.NewClient(sqlitevec.Config{ + baseClient, clientErr := sqlitevec.NewClient(sqlitevec.Config{ DB: store.DB(), }, embedSvc) if clientErr != nil { log.Warn().Err(clientErr).Msg("sqlite-vec client creation failed - vector search disabled") } else { - vectorClient = client - vectorSync = sqlitevec.NewSync(client) + // Wrap with LEANN hybrid storage client + strategy := hybrid.GetStrategyFromEnv() + hybridClient := hybrid.NewClient(hybrid.Config{ + BaseClient: baseClient, + DB: store.DB(), + EmbedSvc: embedSvc, + Strategy: strategy, + HubThreshold: hybrid.GetHubThresholdFromEnv(), + }) + + // Wrap with graph-aware search client + graphConfig := hybrid.GraphConfig{ + Enabled: s.config.GraphEnabled, + MaxHops: s.config.GraphMaxHops, + BranchFactor: s.config.GraphBranchFactor, + 
EdgeWeight: s.config.GraphEdgeWeight, + } + graphClient := hybrid.NewGraphSearchClient(hybridClient, nil, graphConfig) + vectorClient = graphClient + s.graphSearchClient = graphClient + s.hybridMetrics = hybrid.NewMetrics() + + vectorSync = sqlitevec.NewSync(baseClient) // Initialize AST-aware code chunking chunkOpts := chunking.DefaultChunkOptions() @@ -215,10 +241,28 @@ func (s *Service) initializeAsync() { chunkingManager := chunking.NewManager(chunkers, chunkOpts) vectorSync.SetChunkingManager(chunkingManager) + strategyName := "hub" // default + switch strategy { + case hybrid.StorageAlways: + strategyName = "always" + case hybrid.StorageOnDemand: + strategyName = "on_demand" + } + log.Info(). Str("model", embedSvc.Version()). + Str("vector_strategy", strategyName). + Bool("graph_enabled", s.config.GraphEnabled). Strs("chunkers", []string{"go", "python", "typescript"}). Msg("sqlite-vec vector search with AST-aware code chunking enabled") + + if s.config.GraphEnabled { + log.Info(). + Int("max_hops", s.config.GraphMaxHops). + Int("branch_factor", s.config.GraphBranchFactor). + Float64("edge_weight", s.config.GraphEdgeWeight). 
+ Msg("Graph-aware search configured (graph will be built after initialization)") + } } // Create cross-encoder reranking service if enabled @@ -409,6 +453,12 @@ func (s *Service) initializeAsync() { // Start file watchers for auto-recreation on deletion s.startWatchers() + // Build initial observation graph in background if graph search is enabled + if s.config.GraphEnabled && s.graphSearchClient != nil { + s.wg.Add(1) + go s.buildInitialGraph(observationStore) + } + // Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild if vectorClient != nil && vectorSync != nil { needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx) @@ -876,7 +926,7 @@ func (s *Service) rebuildStaleVectors( observationStore *sqlite.ObservationStore, summaryStore *sqlite.SummaryStore, promptStore *sqlite.PromptStore, - vectorClient *sqlitevec.Client, + vectorClient vector.Client, vectorSync *sqlitevec.Sync, ) { defer s.wg.Done() @@ -1041,6 +1091,113 @@ func (s *Service) verifyStaleObservation(req staleVerifyRequest) { } } +// buildInitialGraph builds the observation graph from all observations in background. +func (s *Service) buildInitialGraph(observationStore *sqlite.ObservationStore) { + defer s.wg.Done() + + log.Info().Msg("Building initial observation graph...") + start := time.Now() + + // Fetch all observations + observations, err := observationStore.GetAllObservations(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch observations for graph building") + return + } + + if len(observations) == 0 { + log.Info().Msg("No observations to build graph from") + return + } + + // Build graph using RebuildGraph method + if err := s.graphSearchClient.RebuildGraph(s.ctx, observations); err != nil { + log.Error().Err(err).Msg("Failed to build observation graph") + return + } + + elapsed := time.Since(start) + stats := s.graphSearchClient.GetGraphStats() + + log.Info(). + Int("observations", len(observations)). 
+ Int("nodes", stats.NodeCount). + Int("edges", stats.EdgeCount). + Float64("avg_degree", stats.AvgDegree). + Int("max_degree", stats.MaxDegree). + Dur("elapsed", elapsed). + Msg("Initial observation graph built successfully") + + // Start periodic graph rebuild if configured + if s.config.GraphRebuildIntervalMin > 0 { + s.startGraphRebuildTimer(observationStore) + } +} + +// startGraphRebuildTimer starts a periodic ticker to rebuild the observation graph. +func (s *Service) startGraphRebuildTimer(observationStore *sqlite.ObservationStore) { + interval := time.Duration(s.config.GraphRebuildIntervalMin) * time.Minute + s.graphRebuildTicker = time.NewTicker(interval) + + log.Info(). + Dur("interval", interval). + Msg("Started periodic graph rebuild timer") + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer s.graphRebuildTicker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-s.graphRebuildTicker.C: + log.Info().Msg("Periodic graph rebuild triggered") + s.rebuildGraph(observationStore) + } + } + }() +} + +// rebuildGraph rebuilds the observation graph from current observations. +func (s *Service) rebuildGraph(observationStore *sqlite.ObservationStore) { + if s.graphSearchClient == nil { + return + } + + start := time.Now() + + // Fetch all observations + observations, err := observationStore.GetAllObservations(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch observations for graph rebuild") + return + } + + if len(observations) == 0 { + log.Debug().Msg("No observations to rebuild graph from") + return + } + + // Rebuild graph + if err := s.graphSearchClient.RebuildGraph(s.ctx, observations); err != nil { + log.Error().Err(err).Msg("Failed to rebuild observation graph") + return + } + + elapsed := time.Since(start) + stats := s.graphSearchClient.GetGraphStats() + + log.Info(). + Int("observations", len(observations)). + Int("nodes", stats.NodeCount). + Int("edges", stats.EdgeCount). 
+ Float64("avg_degree", stats.AvgDegree). + Dur("elapsed", elapsed). + Msg("Observation graph rebuilt successfully") +} + // setupMiddleware configures HTTP middleware. func (s *Service) setupMiddleware() { s.router.Use(middleware.Logger) @@ -1106,6 +1263,10 @@ func (s *Service) setupRoutes() { r.Get("/api/types", s.handleGetTypes) r.Get("/api/models", s.handleGetModels) + // Graph and vector metrics routes + r.Get("/api/graph/stats", s.handleGetGraphStats) + r.Get("/api/vector/metrics", s.handleGetVectorMetrics) + // Observation scoring and feedback routes r.Post("/api/observations/{id}/feedback", s.handleObservationFeedback) r.Get("/api/observations/{id}/score", s.handleExplainScore) @@ -1372,6 +1533,11 @@ func (s *Service) Shutdown(ctx context.Context) error { s.patternDetector.Stop() } + // Stop graph rebuild ticker + if s.graphRebuildTicker != nil { + s.graphRebuildTicker.Stop() + } + // Shutdown all sessions s.sessionManager.ShutdownAll(ctx) diff --git a/ui/package-lock.json b/ui/package-lock.json index b60fbb8..4745a9d 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "claude-mnemonic-dashboard", - "version": "40a44a7-dirty", + "version": "4f4b4ac-dirty", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "claude-mnemonic-dashboard", - "version": "40a44a7-dirty", + "version": "4f4b4ac-dirty", "dependencies": { "vis-data": "^7.1.9", "vis-network": "^9.1.9", diff --git a/ui/package.json b/ui/package.json index cc4ac6f..edd73f6 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "claude-mnemonic-dashboard", - "version": "40a44a7-dirty", + "version": "4f4b4ac-dirty", "private": true, "type": "module", "scripts": { diff --git a/ui/src/components/Sidebar.vue b/ui/src/components/Sidebar.vue index 9e8ed96..401c33e 100644 --- a/ui/src/components/Sidebar.vue +++ b/ui/src/components/Sidebar.vue @@ -2,6 +2,7 @@ import { ref, computed } from 'vue' import type { Stats, SelfCheckResponse } 
from '@/types' import ProjectFilter from './ProjectFilter.vue' +import { useGraphMetrics } from '@/composables' const props = defineProps<{ stats: Stats | null @@ -18,12 +19,21 @@ defineEmits<{ // Collapse state - persisted in localStorage const isCollapsed = ref(localStorage.getItem('sidebar-collapsed') === 'true') +const metricsExpanded = ref(localStorage.getItem('metrics-expanded') === 'true') + +// Graph metrics composable +const { graphStats, vectorMetrics, loading: metricsLoading, refresh: refreshMetrics } = useGraphMetrics() function toggleCollapse() { isCollapsed.value = !isCollapsed.value localStorage.setItem('sidebar-collapsed', String(isCollapsed.value)) } +function toggleMetrics() { + metricsExpanded.value = !metricsExpanded.value + localStorage.setItem('metrics-expanded', String(metricsExpanded.value)) +} + function formatNumber(n: number): string { if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M' if (n >= 1000) return (n / 1000).toFixed(1) + 'K' @@ -205,6 +215,99 @@ function getStatusColor(status: string): string { + +
Loading metrics...
+