mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
@@ -5,19 +5,105 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// embeddingCacheEntry stores a cached embedding with its timestamp.
|
||||
type embeddingCacheEntry struct {
|
||||
embedding []float32
|
||||
timestamp int64 // Unix nano
|
||||
}
|
||||
|
||||
// resultCacheEntry stores cached query results.
|
||||
type resultCacheEntry struct {
|
||||
queryHash string
|
||||
results []QueryResult
|
||||
timestamp int64
|
||||
}
|
||||
|
||||
// Client provides vector operations via sqlite-vec.
|
||||
type Client struct {
|
||||
db *sql.DB
|
||||
embedSvc *embedding.Service
|
||||
mu sync.Mutex
|
||||
embeddingGroup singleflight.Group
|
||||
resultCache map[string]resultCacheEntry
|
||||
db *sql.DB
|
||||
embedSvc *embedding.Service
|
||||
queryCache map[string]embeddingCacheEntry
|
||||
stopCleanup chan struct{}
|
||||
stats CacheStats
|
||||
cleanupWg sync.WaitGroup
|
||||
resultCacheTTLNano int64
|
||||
cacheTTLNano int64
|
||||
resultCacheMaxSize int
|
||||
cacheMaxSize int
|
||||
resultCacheMu sync.RWMutex
|
||||
queryCacheMu sync.RWMutex
|
||||
readMu sync.RWMutex
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
// CacheStats tracks cache performance metrics using atomic counters for lock-free updates.
|
||||
type CacheStats struct {
|
||||
embeddingHits atomic.Int64
|
||||
embeddingMisses atomic.Int64
|
||||
resultHits atomic.Int64
|
||||
resultMisses atomic.Int64
|
||||
embeddingEvictions atomic.Int64
|
||||
resultEvictions atomic.Int64
|
||||
}
|
||||
|
||||
// CacheStatsSnapshot is the exported version of CacheStats for JSON marshaling.
|
||||
type CacheStatsSnapshot struct {
|
||||
EmbeddingHits int64 `json:"embedding_hits"`
|
||||
EmbeddingMisses int64 `json:"embedding_misses"`
|
||||
ResultHits int64 `json:"result_hits"`
|
||||
ResultMisses int64 `json:"result_misses"`
|
||||
EmbeddingEvictions int64 `json:"embedding_evictions"`
|
||||
ResultEvictions int64 `json:"result_evictions"`
|
||||
}
|
||||
|
||||
// HitRate returns the cache hit rate as a percentage.
|
||||
func (s CacheStatsSnapshot) HitRate() float64 {
|
||||
total := s.EmbeddingHits + s.EmbeddingMisses + s.ResultHits + s.ResultMisses
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
hits := s.EmbeddingHits + s.ResultHits
|
||||
return float64(hits) / float64(total) * 100
|
||||
}
|
||||
|
||||
// HitRate returns the cache hit rate as a percentage.
|
||||
func (s *CacheStats) HitRate() float64 {
|
||||
embHits := s.embeddingHits.Load()
|
||||
embMisses := s.embeddingMisses.Load()
|
||||
resHits := s.resultHits.Load()
|
||||
resMisses := s.resultMisses.Load()
|
||||
total := embHits + embMisses + resHits + resMisses
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
hits := embHits + resHits
|
||||
return float64(hits) / float64(total) * 100
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of the current stats.
|
||||
func (s *CacheStats) Snapshot() CacheStatsSnapshot {
|
||||
return CacheStatsSnapshot{
|
||||
EmbeddingHits: s.embeddingHits.Load(),
|
||||
EmbeddingMisses: s.embeddingMisses.Load(),
|
||||
ResultHits: s.resultHits.Load(),
|
||||
ResultMisses: s.resultMisses.Load(),
|
||||
EmbeddingEvictions: s.embeddingEvictions.Load(),
|
||||
ResultEvictions: s.resultEvictions.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// Config holds configuration for the client.
|
||||
@@ -34,10 +120,23 @@ func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
|
||||
return nil, fmt.Errorf("embedding service required")
|
||||
}
|
||||
|
||||
return &Client{
|
||||
db: cfg.DB,
|
||||
embedSvc: embedSvc,
|
||||
}, nil
|
||||
c := &Client{
|
||||
db: cfg.DB,
|
||||
embedSvc: embedSvc,
|
||||
queryCache: make(map[string]embeddingCacheEntry),
|
||||
cacheMaxSize: 500, // Cache up to 500 query embeddings
|
||||
cacheTTLNano: 5 * 60 * 1e9, // 5 minute TTL for embeddings
|
||||
resultCache: make(map[string]resultCacheEntry),
|
||||
resultCacheMaxSize: 200, // Cache up to 200 search results
|
||||
resultCacheTTLNano: 60 * 1e9, // 1 minute TTL for results (shorter since data changes)
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start background cache cleanup goroutine
|
||||
c.cleanupWg.Add(1)
|
||||
go c.cacheCleanupLoop()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// AddDocuments adds documents with their embeddings to the vector store.
|
||||
@@ -46,8 +145,8 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Generate embeddings for all documents
|
||||
texts := make([]string, len(docs))
|
||||
@@ -75,7 +174,10 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
// Rollback failure is serious - indicates potential data corruption risk
|
||||
log.Error().Err(rbErr).Err(err).Msg("Failed to rollback transaction after error - data may be inconsistent")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -118,6 +220,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// Invalidate result cache since data changed
|
||||
c.InvalidateResultCache()
|
||||
|
||||
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
@@ -128,12 +233,12 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
args := make([]any, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
@@ -148,17 +253,25 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
return fmt.Errorf("delete documents: %w", err)
|
||||
}
|
||||
|
||||
// Invalidate result cache since data changed
|
||||
c.InvalidateResultCache()
|
||||
|
||||
log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs a vector similarity search.
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Build cache key from query + filters + limit
|
||||
cacheKey := c.buildResultCacheKey(query, limit, where)
|
||||
|
||||
// Generate query embedding
|
||||
queryEmb, err := c.embedSvc.Embed(query)
|
||||
// Check result cache first
|
||||
if results, ok := c.getResultFromCache(cacheKey); ok {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Generate query embedding OUTSIDE the lock for better concurrency
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -169,9 +282,13 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
// Now acquire read lock for the actual DB query
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Build query with filters
|
||||
// vec0 supports WHERE clauses on metadata columns
|
||||
args := []interface{}{queryBlob}
|
||||
args := []any{queryBlob}
|
||||
|
||||
sqlQuery := `
|
||||
SELECT
|
||||
@@ -232,6 +349,9 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||
}
|
||||
|
||||
// Cache the results
|
||||
c.cacheResults(cacheKey, results)
|
||||
|
||||
log.Debug().
|
||||
Str("query", truncateString(query, 50)).
|
||||
Int("results", len(results)).
|
||||
@@ -245,11 +365,196 @@ func (c *Client) IsConnected() bool {
|
||||
return c.db != nil
|
||||
}
|
||||
|
||||
// Close is a no-op (db managed externally).
|
||||
// Close stops the background cleanup goroutine (db managed externally).
|
||||
func (c *Client) Close() error {
|
||||
// Signal cleanup goroutine to stop
|
||||
close(c.stopCleanup)
|
||||
// Wait for cleanup to finish
|
||||
c.cleanupWg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheCleanupLoop periodically removes expired cache entries.
|
||||
func (c *Client) cacheCleanupLoop() {
|
||||
defer c.cleanupWg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second) // Cleanup every 30 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCleanup:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.cleanupExpiredCaches()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredCaches removes expired entries from both caches.
|
||||
func (c *Client) cleanupExpiredCaches() {
|
||||
now := time.Now().UnixNano()
|
||||
var embeddingExpired, resultExpired int64
|
||||
|
||||
// Cleanup embedding cache
|
||||
c.queryCacheMu.Lock()
|
||||
for key, entry := range c.queryCache {
|
||||
if now-entry.timestamp > c.cacheTTLNano {
|
||||
delete(c.queryCache, key)
|
||||
embeddingExpired++
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.Unlock()
|
||||
|
||||
// Cleanup result cache
|
||||
c.resultCacheMu.Lock()
|
||||
for key, entry := range c.resultCache {
|
||||
if now-entry.timestamp > c.resultCacheTTLNano {
|
||||
delete(c.resultCache, key)
|
||||
resultExpired++
|
||||
}
|
||||
}
|
||||
c.resultCacheMu.Unlock()
|
||||
|
||||
// Update stats atomically
|
||||
if embeddingExpired > 0 || resultExpired > 0 {
|
||||
c.stats.embeddingEvictions.Add(embeddingExpired)
|
||||
c.stats.resultEvictions.Add(resultExpired)
|
||||
|
||||
log.Debug().
|
||||
Int64("embedding_expired", embeddingExpired).
|
||||
Int64("result_expired", resultExpired).
|
||||
Msg("Cache cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// BatchQueryResult holds results from a batch query operation.
|
||||
type BatchQueryResult struct {
|
||||
Error error
|
||||
Query string
|
||||
Results []QueryResult
|
||||
}
|
||||
|
||||
// QueryBatch performs multiple vector searches concurrently.
|
||||
// Returns results in the same order as input queries.
|
||||
// Uses a worker pool to limit concurrent queries.
|
||||
func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, where map[string]any) []BatchQueryResult {
|
||||
if len(queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Limit concurrency to avoid overwhelming the database
|
||||
maxConcurrent := min(4, len(queries))
|
||||
|
||||
results := make([]BatchQueryResult, len(queries))
|
||||
sem := make(chan struct{}, maxConcurrent)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, query := range queries {
|
||||
wg.Add(1)
|
||||
go func(idx int, q string) {
|
||||
defer wg.Done()
|
||||
|
||||
// Acquire semaphore
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-ctx.Done():
|
||||
results[idx] = BatchQueryResult{
|
||||
Query: q,
|
||||
Error: ctx.Err(),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Execute query
|
||||
queryResults, err := c.Query(ctx, q, limit, where)
|
||||
results[idx] = BatchQueryResult{
|
||||
Query: q,
|
||||
Results: queryResults,
|
||||
Error: err,
|
||||
}
|
||||
}(i, query)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
// QueryMultiField searches across multiple fields for a single query.
|
||||
// Combines results from different field types and deduplicates by document ID.
|
||||
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
||||
// Generate embedding once
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
// Serialize query embedding
|
||||
queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Query with field type aggregation - get best match per document
|
||||
sqlQuery := `
|
||||
WITH ranked_results AS (
|
||||
SELECT
|
||||
doc_id,
|
||||
distance,
|
||||
sqlite_id,
|
||||
doc_type,
|
||||
field_type,
|
||||
project,
|
||||
scope,
|
||||
ROW_NUMBER() OVER (PARTITION BY sqlite_id ORDER BY distance ASC) as rn
|
||||
FROM vectors
|
||||
WHERE embedding MATCH ?
|
||||
AND doc_type = ?
|
||||
AND (project = ? OR scope = 'global')
|
||||
)
|
||||
SELECT doc_id, distance, sqlite_id, doc_type, field_type, project, scope
|
||||
FROM ranked_results
|
||||
WHERE rn = 1
|
||||
ORDER BY distance
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, sqlQuery, queryBlob, docType, project, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query vectors: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Pre-allocate with limit to avoid repeated slice growth
|
||||
results := make([]QueryResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var r QueryResult
|
||||
var sqliteID int64
|
||||
var docTypeVal, fieldType, projectVal, scope sql.NullString
|
||||
|
||||
if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docTypeVal, &fieldType, &projectVal, &scope); err != nil {
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
r.Similarity = DistanceToSimilarity(r.Distance)
|
||||
r.Metadata = map[string]any{
|
||||
"sqlite_id": float64(sqliteID),
|
||||
"doc_type": docTypeVal.String,
|
||||
"field_type": fieldType.String,
|
||||
"project": projectVal.String,
|
||||
"scope": scope.String,
|
||||
}
|
||||
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// truncateString truncates a string to maxLen characters.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -260,8 +565,8 @@ func truncateString(s string, maxLen int) string {
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
var count int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
@@ -281,8 +586,8 @@ func (c *Client) ModelVersion() string {
|
||||
// - The vectors table is empty
|
||||
// - Any vectors have a different model_version than the current model
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
@@ -329,8 +634,8 @@ type StaleVectorInfo struct {
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
@@ -372,6 +677,134 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// VectorHealthStats contains comprehensive health information about the vector store.
|
||||
type VectorHealthStats struct {
|
||||
CoverageByType map[string]int64 `json:"coverage_by_type"`
|
||||
ModelVersions map[string]int64 `json:"model_versions"`
|
||||
ProjectCounts map[string]int64 `json:"project_counts"`
|
||||
CurrentModel string `json:"current_model"`
|
||||
RebuildReason string `json:"rebuild_reason,omitempty"`
|
||||
EmbeddingCache CacheStatsSnapshot `json:"embedding_cache"`
|
||||
TotalVectors int64 `json:"total_vectors"`
|
||||
StaleVectors int64 `json:"stale_vectors"`
|
||||
NeedsRebuild bool `json:"needs_rebuild"`
|
||||
}
|
||||
|
||||
// GetHealthStats returns comprehensive health statistics about the vector store.
|
||||
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
stats := &VectorHealthStats{
|
||||
CurrentModel: c.embedSvc.Version(),
|
||||
CoverageByType: make(map[string]int64),
|
||||
ModelVersions: make(map[string]int64),
|
||||
ProjectCounts: make(map[string]int64),
|
||||
EmbeddingCache: c.stats.Snapshot(),
|
||||
}
|
||||
|
||||
// Get total count
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&stats.TotalVectors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count total vectors: %w", err)
|
||||
}
|
||||
|
||||
// Get stale count
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
stats.CurrentModel,
|
||||
).Scan(&stats.StaleVectors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count stale vectors: %w", err)
|
||||
}
|
||||
|
||||
// Check if rebuild needed
|
||||
stats.NeedsRebuild, stats.RebuildReason = c.needsRebuildUnlocked(ctx, stats.CurrentModel)
|
||||
|
||||
// Get coverage by doc_type
|
||||
rows, err := c.db.QueryContext(ctx, "SELECT doc_type, COUNT(*) FROM vectors GROUP BY doc_type")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query doc types: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var docType sql.NullString
|
||||
var count int64
|
||||
if err := rows.Scan(&docType, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan doc type: %w", err)
|
||||
}
|
||||
if docType.Valid {
|
||||
stats.CoverageByType[docType.String] = count
|
||||
} else {
|
||||
stats.CoverageByType["unknown"] = count
|
||||
}
|
||||
}
|
||||
|
||||
// Get model version distribution
|
||||
rows2, err := c.db.QueryContext(ctx, "SELECT COALESCE(model_version, 'unknown'), COUNT(*) FROM vectors GROUP BY model_version")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query model versions: %w", err)
|
||||
}
|
||||
defer rows2.Close()
|
||||
|
||||
for rows2.Next() {
|
||||
var version string
|
||||
var count int64
|
||||
if err := rows2.Scan(&version, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan model version: %w", err)
|
||||
}
|
||||
stats.ModelVersions[version] = count
|
||||
}
|
||||
|
||||
// Get project counts (top 10)
|
||||
rows3, err := c.db.QueryContext(ctx,
|
||||
"SELECT COALESCE(project, 'global'), COUNT(*) FROM vectors GROUP BY project ORDER BY COUNT(*) DESC LIMIT 10")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query projects: %w", err)
|
||||
}
|
||||
defer rows3.Close()
|
||||
|
||||
for rows3.Next() {
|
||||
var project string
|
||||
var count int64
|
||||
if err := rows3.Scan(&project, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan project: %w", err)
|
||||
}
|
||||
stats.ProjectCounts[project] = count
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// needsRebuildUnlocked checks if rebuild is needed without acquiring lock (caller must hold lock).
|
||||
func (c *Client) needsRebuildUnlocked(ctx context.Context, currentModel string) (bool, string) {
|
||||
var totalCount int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if totalCount == 0 {
|
||||
return true, "empty"
|
||||
}
|
||||
|
||||
var staleCount int64
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
currentModel,
|
||||
).Scan(&staleCount)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if staleCount > 0 {
|
||||
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
|
||||
// Used for granular rebuild - delete stale vectors before re-adding.
|
||||
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||
@@ -379,12 +812,12 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(docIDs))
|
||||
args := make([]interface{}, len(docIDs))
|
||||
args := make([]any, len(docIDs))
|
||||
for i, id := range docIDs {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
@@ -402,3 +835,249 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
|
||||
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByObservationID removes all vectors associated with an observation ID.
|
||||
// Vectors are stored with doc_ids that include the observation ID, e.g., "obs_123_narrative".
|
||||
func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Vectors have doc_ids like "obs_123_narrative", "obs_123_facts_0", etc.
|
||||
pattern := fmt.Sprintf("obs_%d_%%", obsID)
|
||||
|
||||
_, err := c.db.ExecContext(ctx, "DELETE FROM vectors WHERE doc_id LIKE ?", pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete vectors for observation %d: %w", obsID, err)
|
||||
}
|
||||
|
||||
log.Debug().Int64("observation_id", obsID).Msg("Deleted vectors for observation")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
||||
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
||||
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// Check cache first (read lock)
|
||||
c.queryCacheMu.RLock()
|
||||
if entry, ok := c.queryCache[query]; ok {
|
||||
if now-entry.timestamp < c.cacheTTLNano {
|
||||
c.queryCacheMu.RUnlock()
|
||||
c.stats.embeddingHits.Add(1)
|
||||
return entry.embedding, nil
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.RUnlock()
|
||||
|
||||
// Cache miss - use singleflight to deduplicate concurrent embedding requests
|
||||
result, err, _ := c.embeddingGroup.Do(query, func() (any, error) {
|
||||
// Double-check cache inside singleflight (another goroutine may have just cached it)
|
||||
c.queryCacheMu.RLock()
|
||||
if entry, ok := c.queryCache[query]; ok {
|
||||
if time.Now().UnixNano()-entry.timestamp < c.cacheTTLNano {
|
||||
c.queryCacheMu.RUnlock()
|
||||
return entry.embedding, nil
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.RUnlock()
|
||||
|
||||
// Record cache miss
|
||||
c.stats.embeddingMisses.Add(1)
|
||||
|
||||
// Compute embedding
|
||||
emb, err := c.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (write lock)
|
||||
c.queryCacheMu.Lock()
|
||||
nowCache := time.Now().UnixNano()
|
||||
// Evict old entries if cache is full or near capacity (80% threshold)
|
||||
evictionThreshold := (c.cacheMaxSize * 8) / 10
|
||||
if len(c.queryCache) >= evictionThreshold {
|
||||
// Phase 1: Remove ALL expired entries first (not just 10%)
|
||||
evicted := int64(0)
|
||||
for k, v := range c.queryCache {
|
||||
if nowCache-v.timestamp > c.cacheTTLNano {
|
||||
delete(c.queryCache, k)
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: If still at capacity, evict 10% using random iteration (O(n) instead of O(n log n))
|
||||
// Go map iteration order is randomized, providing good cache behavior without sorting
|
||||
if len(c.queryCache) >= c.cacheMaxSize {
|
||||
evictCount := max(c.cacheMaxSize/10, 1)
|
||||
for k := range c.queryCache {
|
||||
delete(c.queryCache, k)
|
||||
evicted++
|
||||
evictCount--
|
||||
if evictCount <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
c.stats.embeddingEvictions.Add(evicted)
|
||||
}
|
||||
}
|
||||
c.queryCache[query] = embeddingCacheEntry{
|
||||
embedding: emb,
|
||||
timestamp: nowCache,
|
||||
}
|
||||
c.queryCacheMu.Unlock()
|
||||
|
||||
return emb, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.([]float32), nil
|
||||
}
|
||||
|
||||
// ClearCache clears the embedding cache.
|
||||
func (c *Client) ClearCache() {
|
||||
c.queryCacheMu.Lock()
|
||||
c.queryCache = make(map[string]embeddingCacheEntry)
|
||||
c.queryCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// GetCacheStats returns comprehensive cache statistics.
|
||||
func (c *Client) GetCacheStats() CacheStatsSnapshot {
|
||||
return c.stats.Snapshot()
|
||||
}
|
||||
|
||||
// CacheStats returns basic cache size info for backward compatibility.
|
||||
// Deprecated: Use GetCacheStats() for comprehensive statistics.
|
||||
func (c *Client) CacheStats() (size int, maxSize int) {
|
||||
c.queryCacheMu.RLock()
|
||||
size = len(c.queryCache)
|
||||
c.queryCacheMu.RUnlock()
|
||||
return size, c.cacheMaxSize
|
||||
}
|
||||
|
||||
// EmbeddingCacheSize returns the current embedding cache size.
|
||||
func (c *Client) EmbeddingCacheSize() int {
|
||||
c.queryCacheMu.RLock()
|
||||
defer c.queryCacheMu.RUnlock()
|
||||
return len(c.queryCache)
|
||||
}
|
||||
|
||||
// ResultCacheSize returns the current result cache size.
|
||||
func (c *Client) ResultCacheSize() int {
|
||||
c.resultCacheMu.RLock()
|
||||
defer c.resultCacheMu.RUnlock()
|
||||
return len(c.resultCache)
|
||||
}
|
||||
|
||||
// buildResultCacheKey creates a unique key for caching query results.
|
||||
// Uses strings.Builder to avoid intermediate allocations.
|
||||
func (c *Client) buildResultCacheKey(query string, limit int, where map[string]any) string {
|
||||
// Pre-allocate with typical key size to avoid reallocation
|
||||
var b strings.Builder
|
||||
b.Grow(len(query) + 32) // query + typical prefix/suffix overhead
|
||||
|
||||
b.WriteString("q:")
|
||||
b.WriteString(query)
|
||||
b.WriteString(":l:")
|
||||
b.WriteString(strconv.Itoa(limit))
|
||||
|
||||
if docType, ok := where["doc_type"].(string); ok {
|
||||
b.WriteString(":dt:")
|
||||
b.WriteString(docType)
|
||||
}
|
||||
if project, ok := where["project"].(string); ok {
|
||||
b.WriteString(":p:")
|
||||
b.WriteString(project)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// getResultFromCache retrieves cached results if available and not expired.
|
||||
func (c *Client) getResultFromCache(cacheKey string) ([]QueryResult, bool) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
c.resultCacheMu.RLock()
|
||||
entry, ok := c.resultCache[cacheKey]
|
||||
c.resultCacheMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
c.stats.resultMisses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check if entry is expired
|
||||
if now-entry.timestamp > c.resultCacheTTLNano {
|
||||
c.stats.resultMisses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.stats.resultHits.Add(1)
|
||||
|
||||
// Return a copy to prevent mutation
|
||||
results := make([]QueryResult, len(entry.results))
|
||||
copy(results, entry.results)
|
||||
return results, true
|
||||
}
|
||||
|
||||
// cacheResults stores query results in the cache.
|
||||
func (c *Client) cacheResults(cacheKey string, results []QueryResult) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
c.resultCacheMu.Lock()
|
||||
defer c.resultCacheMu.Unlock()
|
||||
|
||||
// Evict old entries if cache is full
|
||||
if len(c.resultCache) >= c.resultCacheMaxSize {
|
||||
// Two-phase eviction: (1) TTL-expired entries, (2) random if still over capacity
|
||||
evicted := 0
|
||||
targetSize := c.resultCacheMaxSize * 8 / 10 // Target 80% capacity
|
||||
|
||||
// Phase 1: Remove all TTL-expired entries
|
||||
for k, v := range c.resultCache {
|
||||
if now-v.timestamp > c.resultCacheTTLNano {
|
||||
delete(c.resultCache, k)
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: If still over target, remove random entries until at target
|
||||
if len(c.resultCache) >= targetSize {
|
||||
evictCount := len(c.resultCache) - targetSize + 1
|
||||
for k := range c.resultCache {
|
||||
delete(c.resultCache, k)
|
||||
evicted++
|
||||
evictCount--
|
||||
if evictCount <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
c.stats.resultEvictions.Add(int64(evicted))
|
||||
}
|
||||
}
|
||||
|
||||
// Make a copy of results to store
|
||||
resultsCopy := make([]QueryResult, len(results))
|
||||
copy(resultsCopy, results)
|
||||
|
||||
c.resultCache[cacheKey] = resultCacheEntry{
|
||||
results: resultsCopy,
|
||||
timestamp: now,
|
||||
queryHash: cacheKey,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateResultCache clears the result cache.
|
||||
// Should be called after write operations that modify vectors.
|
||||
func (c *Client) InvalidateResultCache() {
|
||||
c.resultCacheMu.Lock()
|
||||
c.resultCache = make(map[string]resultCacheEntry)
|
||||
c.resultCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -338,3 +338,187 @@ func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchSyncConfig configures batch synchronization behavior.
|
||||
type BatchSyncConfig struct {
|
||||
BatchSize int // Number of items per batch (default: 50)
|
||||
ProgressLogFreq int // Log progress every N items (default: 100)
|
||||
}
|
||||
|
||||
// DefaultBatchSyncConfig returns sensible defaults for batch sync.
|
||||
func DefaultBatchSyncConfig() BatchSyncConfig {
|
||||
return BatchSyncConfig{
|
||||
BatchSize: 50,
|
||||
ProgressLogFreq: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// BatchSyncObservations syncs multiple observations efficiently in batches.
|
||||
// This reduces memory pressure during large rebuilds by processing in chunks.
|
||||
func (s *Sync) BatchSyncObservations(ctx context.Context, observations []*models.Observation, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(observations) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(observations); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(observations)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(observations))
|
||||
|
||||
batch := observations[i:end]
|
||||
var docs []Document
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, obs := range batch {
|
||||
docs = append(docs, s.formatObservationDocs(obs)...)
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync observation batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(observations) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(observations)).Msg("Observation batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
// BatchSyncSummaries syncs multiple summaries efficiently in batches.
|
||||
func (s *Sync) BatchSyncSummaries(ctx context.Context, summaries []*models.SessionSummary, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(summaries) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(summaries); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(summaries)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(summaries))
|
||||
|
||||
batch := summaries[i:end]
|
||||
var docs []Document
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, summary := range batch {
|
||||
docs = append(docs, s.formatSummaryDocs(summary)...)
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync summary batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(summaries) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(summaries)).Msg("Summary batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
// BatchSyncPrompts syncs multiple user prompts efficiently in batches.
|
||||
func (s *Sync) BatchSyncPrompts(ctx context.Context, prompts []*models.UserPromptWithSession, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(prompts) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(prompts); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(prompts)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(prompts))
|
||||
|
||||
batch := prompts[i:end]
|
||||
docs := make([]Document, 0, len(batch))
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, prompt := range batch {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("prompt_%d", prompt.ID),
|
||||
Content: prompt.PromptText,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": prompt.ID,
|
||||
"doc_type": "user_prompt",
|
||||
"sdk_session_id": prompt.SDKSessionID,
|
||||
"project": prompt.Project,
|
||||
"scope": "",
|
||||
"created_at_epoch": prompt.CreatedAtEpoch,
|
||||
"prompt_number": prompt.PromptNumber,
|
||||
"field_type": "prompt",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync prompt batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(prompts) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(prompts)).Msg("Prompt batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user