Make things 'betterer' across the board (#23)

* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
2026-01-11 01:51:20 +00:00
committed by GitHub
parent 3107eddeb2
commit d04b60517a
46 changed files with 12710 additions and 2068 deletions
+707 -28
View File
@@ -5,19 +5,105 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/rs/zerolog/log"
"golang.org/x/sync/singleflight"
)
// embeddingCacheEntry stores a cached embedding with its timestamp.
type embeddingCacheEntry struct {
embedding []float32
timestamp int64 // Unix nano
}
// resultCacheEntry stores cached query results.
type resultCacheEntry struct {
queryHash string
results []QueryResult
timestamp int64
}
// Client provides vector operations via sqlite-vec.
type Client struct {
db *sql.DB
embedSvc *embedding.Service
mu sync.Mutex
embeddingGroup singleflight.Group
resultCache map[string]resultCacheEntry
db *sql.DB
embedSvc *embedding.Service
queryCache map[string]embeddingCacheEntry
stopCleanup chan struct{}
stats CacheStats
cleanupWg sync.WaitGroup
resultCacheTTLNano int64
cacheTTLNano int64
resultCacheMaxSize int
cacheMaxSize int
resultCacheMu sync.RWMutex
queryCacheMu sync.RWMutex
readMu sync.RWMutex
writeMu sync.Mutex
}
// CacheStats tracks cache performance metrics using atomic counters for lock-free updates.
type CacheStats struct {
embeddingHits atomic.Int64
embeddingMisses atomic.Int64
resultHits atomic.Int64
resultMisses atomic.Int64
embeddingEvictions atomic.Int64
resultEvictions atomic.Int64
}
// CacheStatsSnapshot is the exported version of CacheStats for JSON marshaling.
type CacheStatsSnapshot struct {
EmbeddingHits int64 `json:"embedding_hits"`
EmbeddingMisses int64 `json:"embedding_misses"`
ResultHits int64 `json:"result_hits"`
ResultMisses int64 `json:"result_misses"`
EmbeddingEvictions int64 `json:"embedding_evictions"`
ResultEvictions int64 `json:"result_evictions"`
}
// HitRate returns the cache hit rate as a percentage.
func (s CacheStatsSnapshot) HitRate() float64 {
total := s.EmbeddingHits + s.EmbeddingMisses + s.ResultHits + s.ResultMisses
if total == 0 {
return 0
}
hits := s.EmbeddingHits + s.ResultHits
return float64(hits) / float64(total) * 100
}
// HitRate returns the cache hit rate as a percentage.
func (s *CacheStats) HitRate() float64 {
embHits := s.embeddingHits.Load()
embMisses := s.embeddingMisses.Load()
resHits := s.resultHits.Load()
resMisses := s.resultMisses.Load()
total := embHits + embMisses + resHits + resMisses
if total == 0 {
return 0
}
hits := embHits + resHits
return float64(hits) / float64(total) * 100
}
// Snapshot returns a copy of the current stats.
func (s *CacheStats) Snapshot() CacheStatsSnapshot {
return CacheStatsSnapshot{
EmbeddingHits: s.embeddingHits.Load(),
EmbeddingMisses: s.embeddingMisses.Load(),
ResultHits: s.resultHits.Load(),
ResultMisses: s.resultMisses.Load(),
EmbeddingEvictions: s.embeddingEvictions.Load(),
ResultEvictions: s.resultEvictions.Load(),
}
}
// Config holds configuration for the client.
@@ -34,10 +120,23 @@ func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
return nil, fmt.Errorf("embedding service required")
}
return &Client{
db: cfg.DB,
embedSvc: embedSvc,
}, nil
c := &Client{
db: cfg.DB,
embedSvc: embedSvc,
queryCache: make(map[string]embeddingCacheEntry),
cacheMaxSize: 500, // Cache up to 500 query embeddings
cacheTTLNano: 5 * 60 * 1e9, // 5 minute TTL for embeddings
resultCache: make(map[string]resultCacheEntry),
resultCacheMaxSize: 200, // Cache up to 200 search results
resultCacheTTLNano: 60 * 1e9, // 1 minute TTL for results (shorter since data changes)
stopCleanup: make(chan struct{}),
}
// Start background cache cleanup goroutine
c.cleanupWg.Add(1)
go c.cacheCleanupLoop()
return c, nil
}
// AddDocuments adds documents with their embeddings to the vector store.
@@ -46,8 +145,8 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Generate embeddings for all documents
texts := make([]string, len(docs))
@@ -75,7 +174,10 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
}
defer func() {
if err != nil {
_ = tx.Rollback()
if rbErr := tx.Rollback(); rbErr != nil {
// Rollback failure is serious - indicates potential data corruption risk
log.Error().Err(rbErr).Err(err).Msg("Failed to rollback transaction after error - data may be inconsistent")
}
}
}()
@@ -118,6 +220,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("commit transaction: %w", err)
}
// Invalidate result cache since data changed
c.InvalidateResultCache()
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
return nil
}
@@ -128,12 +233,12 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Build placeholder string
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
args := make([]any, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
@@ -148,17 +253,25 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
return fmt.Errorf("delete documents: %w", err)
}
// Invalidate result cache since data changed
c.InvalidateResultCache()
log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
return nil
}
// Query performs a vector similarity search.
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
c.mu.Lock()
defer c.mu.Unlock()
// Build cache key from query + filters + limit
cacheKey := c.buildResultCacheKey(query, limit, where)
// Generate query embedding
queryEmb, err := c.embedSvc.Embed(query)
// Check result cache first
if results, ok := c.getResultFromCache(cacheKey); ok {
return results, nil
}
// Generate query embedding OUTSIDE the lock for better concurrency
queryEmb, err := c.getOrComputeEmbedding(query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
@@ -169,9 +282,13 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
// Now acquire read lock for the actual DB query
c.readMu.RLock()
defer c.readMu.RUnlock()
// Build query with filters
// vec0 supports WHERE clauses on metadata columns
args := []interface{}{queryBlob}
args := []any{queryBlob}
sqlQuery := `
SELECT
@@ -232,6 +349,9 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("iterate rows: %w", err)
}
// Cache the results
c.cacheResults(cacheKey, results)
log.Debug().
Str("query", truncateString(query, 50)).
Int("results", len(results)).
@@ -245,11 +365,196 @@ func (c *Client) IsConnected() bool {
return c.db != nil
}
// Close is a no-op (db managed externally).
// Close stops the background cleanup goroutine (db managed externally).
func (c *Client) Close() error {
// Signal cleanup goroutine to stop
close(c.stopCleanup)
// Wait for cleanup to finish
c.cleanupWg.Wait()
return nil
}
// cacheCleanupLoop periodically removes expired cache entries.
func (c *Client) cacheCleanupLoop() {
defer c.cleanupWg.Done()
ticker := time.NewTicker(30 * time.Second) // Cleanup every 30 seconds
defer ticker.Stop()
for {
select {
case <-c.stopCleanup:
return
case <-ticker.C:
c.cleanupExpiredCaches()
}
}
}
// cleanupExpiredCaches removes expired entries from both caches.
func (c *Client) cleanupExpiredCaches() {
now := time.Now().UnixNano()
var embeddingExpired, resultExpired int64
// Cleanup embedding cache
c.queryCacheMu.Lock()
for key, entry := range c.queryCache {
if now-entry.timestamp > c.cacheTTLNano {
delete(c.queryCache, key)
embeddingExpired++
}
}
c.queryCacheMu.Unlock()
// Cleanup result cache
c.resultCacheMu.Lock()
for key, entry := range c.resultCache {
if now-entry.timestamp > c.resultCacheTTLNano {
delete(c.resultCache, key)
resultExpired++
}
}
c.resultCacheMu.Unlock()
// Update stats atomically
if embeddingExpired > 0 || resultExpired > 0 {
c.stats.embeddingEvictions.Add(embeddingExpired)
c.stats.resultEvictions.Add(resultExpired)
log.Debug().
Int64("embedding_expired", embeddingExpired).
Int64("result_expired", resultExpired).
Msg("Cache cleanup completed")
}
}
// BatchQueryResult holds results from a batch query operation.
type BatchQueryResult struct {
Error error
Query string
Results []QueryResult
}
// QueryBatch performs multiple vector searches concurrently.
// Returns results in the same order as input queries.
// Uses a worker pool to limit concurrent queries.
func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, where map[string]any) []BatchQueryResult {
if len(queries) == 0 {
return nil
}
// Limit concurrency to avoid overwhelming the database
maxConcurrent := min(4, len(queries))
results := make([]BatchQueryResult, len(queries))
sem := make(chan struct{}, maxConcurrent)
var wg sync.WaitGroup
for i, query := range queries {
wg.Add(1)
go func(idx int, q string) {
defer wg.Done()
// Acquire semaphore
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-ctx.Done():
results[idx] = BatchQueryResult{
Query: q,
Error: ctx.Err(),
}
return
}
// Execute query
queryResults, err := c.Query(ctx, q, limit, where)
results[idx] = BatchQueryResult{
Query: q,
Results: queryResults,
Error: err,
}
}(i, query)
}
wg.Wait()
return results
}
// QueryMultiField searches across multiple fields for a single query.
// Combines results from different field types and deduplicates by document ID.
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
// Generate embedding once
queryEmb, err := c.getOrComputeEmbedding(query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
// Serialize query embedding
queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb)
if err != nil {
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
c.readMu.RLock()
defer c.readMu.RUnlock()
// Query with field type aggregation - get best match per document
sqlQuery := `
WITH ranked_results AS (
SELECT
doc_id,
distance,
sqlite_id,
doc_type,
field_type,
project,
scope,
ROW_NUMBER() OVER (PARTITION BY sqlite_id ORDER BY distance ASC) as rn
FROM vectors
WHERE embedding MATCH ?
AND doc_type = ?
AND (project = ? OR scope = 'global')
)
SELECT doc_id, distance, sqlite_id, doc_type, field_type, project, scope
FROM ranked_results
WHERE rn = 1
ORDER BY distance
LIMIT ?
`
rows, err := c.db.QueryContext(ctx, sqlQuery, queryBlob, docType, project, limit)
if err != nil {
return nil, fmt.Errorf("query vectors: %w", err)
}
defer rows.Close()
// Pre-allocate with limit to avoid repeated slice growth
results := make([]QueryResult, 0, limit)
for rows.Next() {
var r QueryResult
var sqliteID int64
var docTypeVal, fieldType, projectVal, scope sql.NullString
if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docTypeVal, &fieldType, &projectVal, &scope); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
r.Similarity = DistanceToSimilarity(r.Distance)
r.Metadata = map[string]any{
"sqlite_id": float64(sqliteID),
"doc_type": docTypeVal.String,
"field_type": fieldType.String,
"project": projectVal.String,
"scope": scope.String,
}
results = append(results, r)
}
return results, rows.Err()
}
// truncateString truncates a string to maxLen characters.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
@@ -260,8 +565,8 @@ func truncateString(s string, maxLen int) string {
// Count returns the total number of vectors in the store.
func (c *Client) Count(ctx context.Context) (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
var count int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
@@ -281,8 +586,8 @@ func (c *Client) ModelVersion() string {
// - The vectors table is empty
// - Any vectors have a different model_version than the current model
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -329,8 +634,8 @@ type StaleVectorInfo struct {
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
// This enables granular rebuild - only re-embedding documents that need updating.
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -372,6 +677,134 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error)
return results, nil
}
// VectorHealthStats contains comprehensive health information about the vector store.
type VectorHealthStats struct {
CoverageByType map[string]int64 `json:"coverage_by_type"`
ModelVersions map[string]int64 `json:"model_versions"`
ProjectCounts map[string]int64 `json:"project_counts"`
CurrentModel string `json:"current_model"`
RebuildReason string `json:"rebuild_reason,omitempty"`
EmbeddingCache CacheStatsSnapshot `json:"embedding_cache"`
TotalVectors int64 `json:"total_vectors"`
StaleVectors int64 `json:"stale_vectors"`
NeedsRebuild bool `json:"needs_rebuild"`
}
// GetHealthStats returns comprehensive health statistics about the vector store.
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
c.readMu.RLock()
defer c.readMu.RUnlock()
stats := &VectorHealthStats{
CurrentModel: c.embedSvc.Version(),
CoverageByType: make(map[string]int64),
ModelVersions: make(map[string]int64),
ProjectCounts: make(map[string]int64),
EmbeddingCache: c.stats.Snapshot(),
}
// Get total count
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&stats.TotalVectors)
if err != nil {
return nil, fmt.Errorf("count total vectors: %w", err)
}
// Get stale count
err = c.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
stats.CurrentModel,
).Scan(&stats.StaleVectors)
if err != nil {
return nil, fmt.Errorf("count stale vectors: %w", err)
}
// Check if rebuild needed
stats.NeedsRebuild, stats.RebuildReason = c.needsRebuildUnlocked(ctx, stats.CurrentModel)
// Get coverage by doc_type
rows, err := c.db.QueryContext(ctx, "SELECT doc_type, COUNT(*) FROM vectors GROUP BY doc_type")
if err != nil {
return nil, fmt.Errorf("query doc types: %w", err)
}
defer rows.Close()
for rows.Next() {
var docType sql.NullString
var count int64
if err := rows.Scan(&docType, &count); err != nil {
return nil, fmt.Errorf("scan doc type: %w", err)
}
if docType.Valid {
stats.CoverageByType[docType.String] = count
} else {
stats.CoverageByType["unknown"] = count
}
}
// Get model version distribution
rows2, err := c.db.QueryContext(ctx, "SELECT COALESCE(model_version, 'unknown'), COUNT(*) FROM vectors GROUP BY model_version")
if err != nil {
return nil, fmt.Errorf("query model versions: %w", err)
}
defer rows2.Close()
for rows2.Next() {
var version string
var count int64
if err := rows2.Scan(&version, &count); err != nil {
return nil, fmt.Errorf("scan model version: %w", err)
}
stats.ModelVersions[version] = count
}
// Get project counts (top 10)
rows3, err := c.db.QueryContext(ctx,
"SELECT COALESCE(project, 'global'), COUNT(*) FROM vectors GROUP BY project ORDER BY COUNT(*) DESC LIMIT 10")
if err != nil {
return nil, fmt.Errorf("query projects: %w", err)
}
defer rows3.Close()
for rows3.Next() {
var project string
var count int64
if err := rows3.Scan(&project, &count); err != nil {
return nil, fmt.Errorf("scan project: %w", err)
}
stats.ProjectCounts[project] = count
}
return stats, nil
}
// needsRebuildUnlocked checks if rebuild is needed without acquiring lock (caller must hold lock).
func (c *Client) needsRebuildUnlocked(ctx context.Context, currentModel string) (bool, string) {
var totalCount int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
if err != nil {
return false, ""
}
if totalCount == 0 {
return true, "empty"
}
var staleCount int64
err = c.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
currentModel,
).Scan(&staleCount)
if err != nil {
return false, ""
}
if staleCount > 0 {
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
}
return false, ""
}
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
// Used for granular rebuild - delete stale vectors before re-adding.
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
@@ -379,12 +812,12 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Build placeholder string
placeholders := make([]string, len(docIDs))
args := make([]interface{}, len(docIDs))
args := make([]any, len(docIDs))
for i, id := range docIDs {
placeholders[i] = "?"
args[i] = id
@@ -402,3 +835,249 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
return nil
}
// DeleteByObservationID removes all vectors associated with an observation ID.
// Vectors are stored with doc_ids that include the observation ID, e.g., "obs_123_narrative".
func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Vectors have doc_ids like "obs_123_narrative", "obs_123_facts_0", etc.
pattern := fmt.Sprintf("obs_%d_%%", obsID)
_, err := c.db.ExecContext(ctx, "DELETE FROM vectors WHERE doc_id LIKE ?", pattern)
if err != nil {
return fmt.Errorf("delete vectors for observation %d: %w", obsID, err)
}
log.Debug().Int64("observation_id", obsID).Msg("Deleted vectors for observation")
return nil
}
// getOrComputeEmbedding returns a cached embedding or computes a new one.
// Uses singleflight to prevent duplicate concurrent computations for the same query.
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
now := time.Now().UnixNano()
// Check cache first (read lock)
c.queryCacheMu.RLock()
if entry, ok := c.queryCache[query]; ok {
if now-entry.timestamp < c.cacheTTLNano {
c.queryCacheMu.RUnlock()
c.stats.embeddingHits.Add(1)
return entry.embedding, nil
}
}
c.queryCacheMu.RUnlock()
// Cache miss - use singleflight to deduplicate concurrent embedding requests
result, err, _ := c.embeddingGroup.Do(query, func() (any, error) {
// Double-check cache inside singleflight (another goroutine may have just cached it)
c.queryCacheMu.RLock()
if entry, ok := c.queryCache[query]; ok {
if time.Now().UnixNano()-entry.timestamp < c.cacheTTLNano {
c.queryCacheMu.RUnlock()
return entry.embedding, nil
}
}
c.queryCacheMu.RUnlock()
// Record cache miss
c.stats.embeddingMisses.Add(1)
// Compute embedding
emb, err := c.embedSvc.Embed(query)
if err != nil {
return nil, err
}
// Store in cache (write lock)
c.queryCacheMu.Lock()
nowCache := time.Now().UnixNano()
// Evict old entries if cache is full or near capacity (80% threshold)
evictionThreshold := (c.cacheMaxSize * 8) / 10
if len(c.queryCache) >= evictionThreshold {
// Phase 1: Remove ALL expired entries first (not just 10%)
evicted := int64(0)
for k, v := range c.queryCache {
if nowCache-v.timestamp > c.cacheTTLNano {
delete(c.queryCache, k)
evicted++
}
}
// Phase 2: If still at capacity, evict 10% using random iteration (O(n) instead of O(n log n))
// Go map iteration order is randomized, providing good cache behavior without sorting
if len(c.queryCache) >= c.cacheMaxSize {
evictCount := max(c.cacheMaxSize/10, 1)
for k := range c.queryCache {
delete(c.queryCache, k)
evicted++
evictCount--
if evictCount <= 0 {
break
}
}
}
if evicted > 0 {
c.stats.embeddingEvictions.Add(evicted)
}
}
c.queryCache[query] = embeddingCacheEntry{
embedding: emb,
timestamp: nowCache,
}
c.queryCacheMu.Unlock()
return emb, nil
})
if err != nil {
return nil, err
}
return result.([]float32), nil
}
// ClearCache clears the embedding cache.
func (c *Client) ClearCache() {
c.queryCacheMu.Lock()
c.queryCache = make(map[string]embeddingCacheEntry)
c.queryCacheMu.Unlock()
}
// GetCacheStats returns comprehensive cache statistics.
func (c *Client) GetCacheStats() CacheStatsSnapshot {
return c.stats.Snapshot()
}
// CacheStats returns basic cache size info for backward compatibility.
// Deprecated: Use GetCacheStats() for comprehensive statistics.
func (c *Client) CacheStats() (size int, maxSize int) {
c.queryCacheMu.RLock()
size = len(c.queryCache)
c.queryCacheMu.RUnlock()
return size, c.cacheMaxSize
}
// EmbeddingCacheSize returns the current embedding cache size.
func (c *Client) EmbeddingCacheSize() int {
c.queryCacheMu.RLock()
defer c.queryCacheMu.RUnlock()
return len(c.queryCache)
}
// ResultCacheSize returns the current result cache size.
func (c *Client) ResultCacheSize() int {
c.resultCacheMu.RLock()
defer c.resultCacheMu.RUnlock()
return len(c.resultCache)
}
// buildResultCacheKey creates a unique key for caching query results.
// Uses strings.Builder to avoid intermediate allocations.
func (c *Client) buildResultCacheKey(query string, limit int, where map[string]any) string {
// Pre-allocate with typical key size to avoid reallocation
var b strings.Builder
b.Grow(len(query) + 32) // query + typical prefix/suffix overhead
b.WriteString("q:")
b.WriteString(query)
b.WriteString(":l:")
b.WriteString(strconv.Itoa(limit))
if docType, ok := where["doc_type"].(string); ok {
b.WriteString(":dt:")
b.WriteString(docType)
}
if project, ok := where["project"].(string); ok {
b.WriteString(":p:")
b.WriteString(project)
}
return b.String()
}
// getResultFromCache retrieves cached results if available and not expired.
func (c *Client) getResultFromCache(cacheKey string) ([]QueryResult, bool) {
now := time.Now().UnixNano()
c.resultCacheMu.RLock()
entry, ok := c.resultCache[cacheKey]
c.resultCacheMu.RUnlock()
if !ok {
c.stats.resultMisses.Add(1)
return nil, false
}
// Check if entry is expired
if now-entry.timestamp > c.resultCacheTTLNano {
c.stats.resultMisses.Add(1)
return nil, false
}
c.stats.resultHits.Add(1)
// Return a copy to prevent mutation
results := make([]QueryResult, len(entry.results))
copy(results, entry.results)
return results, true
}
// cacheResults stores query results in the cache.
func (c *Client) cacheResults(cacheKey string, results []QueryResult) {
now := time.Now().UnixNano()
c.resultCacheMu.Lock()
defer c.resultCacheMu.Unlock()
// Evict old entries if cache is full
if len(c.resultCache) >= c.resultCacheMaxSize {
// Two-phase eviction: (1) TTL-expired entries, (2) random if still over capacity
evicted := 0
targetSize := c.resultCacheMaxSize * 8 / 10 // Target 80% capacity
// Phase 1: Remove all TTL-expired entries
for k, v := range c.resultCache {
if now-v.timestamp > c.resultCacheTTLNano {
delete(c.resultCache, k)
evicted++
}
}
// Phase 2: If still over target, remove random entries until at target
if len(c.resultCache) >= targetSize {
evictCount := len(c.resultCache) - targetSize + 1
for k := range c.resultCache {
delete(c.resultCache, k)
evicted++
evictCount--
if evictCount <= 0 {
break
}
}
}
if evicted > 0 {
c.stats.resultEvictions.Add(int64(evicted))
}
}
// Make a copy of results to store
resultsCopy := make([]QueryResult, len(results))
copy(resultsCopy, results)
c.resultCache[cacheKey] = resultCacheEntry{
results: resultsCopy,
timestamp: now,
queryHash: cacheKey,
}
}
// InvalidateResultCache clears the result cache.
// Should be called after write operations that modify vectors.
func (c *Client) InvalidateResultCache() {
c.resultCacheMu.Lock()
c.resultCache = make(map[string]resultCacheEntry)
c.resultCacheMu.Unlock()
}
+184
View File
@@ -338,3 +338,187 @@ func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
return nil
}
// BatchSyncConfig configures batch synchronization behavior.
type BatchSyncConfig struct {
BatchSize int // Number of items per batch (default: 50)
ProgressLogFreq int // Log progress every N items (default: 100)
}
// DefaultBatchSyncConfig returns sensible defaults for batch sync.
func DefaultBatchSyncConfig() BatchSyncConfig {
return BatchSyncConfig{
BatchSize: 50,
ProgressLogFreq: 100,
}
}
// BatchSyncObservations syncs multiple observations efficiently in batches.
// This reduces memory pressure during large rebuilds by processing in chunks.
func (s *Sync) BatchSyncObservations(ctx context.Context, observations []*models.Observation, cfg BatchSyncConfig) (synced int, errors int) {
if len(observations) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(observations); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(observations)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(observations))
batch := observations[i:end]
var docs []Document
// Collect all documents for this batch
for _, obs := range batch {
docs = append(docs, s.formatObservationDocs(obs)...)
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync observation batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(observations) {
log.Debug().Int("synced", synced).Int("total", len(observations)).Msg("Observation batch sync progress")
}
}
return synced, errors
}
// BatchSyncSummaries syncs multiple summaries efficiently in batches.
func (s *Sync) BatchSyncSummaries(ctx context.Context, summaries []*models.SessionSummary, cfg BatchSyncConfig) (synced int, errors int) {
if len(summaries) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(summaries); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(summaries)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(summaries))
batch := summaries[i:end]
var docs []Document
// Collect all documents for this batch
for _, summary := range batch {
docs = append(docs, s.formatSummaryDocs(summary)...)
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync summary batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(summaries) {
log.Debug().Int("synced", synced).Int("total", len(summaries)).Msg("Summary batch sync progress")
}
}
return synced, errors
}
// BatchSyncPrompts syncs multiple user prompts efficiently in batches.
func (s *Sync) BatchSyncPrompts(ctx context.Context, prompts []*models.UserPromptWithSession, cfg BatchSyncConfig) (synced int, errors int) {
if len(prompts) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(prompts); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(prompts)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(prompts))
batch := prompts[i:end]
docs := make([]Document, 0, len(batch))
// Collect all documents for this batch
for _, prompt := range batch {
docs = append(docs, Document{
ID: fmt.Sprintf("prompt_%d", prompt.ID),
Content: prompt.PromptText,
Metadata: map[string]any{
"sqlite_id": prompt.ID,
"doc_type": "user_prompt",
"sdk_session_id": prompt.SDKSessionID,
"project": prompt.Project,
"scope": "",
"created_at_epoch": prompt.CreatedAtEpoch,
"prompt_number": prompt.PromptNumber,
"field_type": "prompt",
},
})
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync prompt batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(prompts) {
log.Debug().Int("synced", synced).Int("total", len(prompts)).Msg("Prompt batch sync progress")
}
}
return synced, errors
}