package sqlitevec

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
	"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
	_ "github.com/mattn/go-sqlite3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testDB creates a SQLite database in a fresh temporary directory and creates
// the vec0 "vectors" virtual table (matching the production schema, with
// 384-dimensional float embeddings) inside it.
//
// It returns the open database and a cleanup function that closes the database
// and removes the temporary directory. The same cleanup steps are also
// registered with t.Cleanup as soon as each resource exists. Resources are
// therefore released even when a require assertion in this helper aborts the
// test before the cleanup function is handed back to the caller. Callers may
// still defer the returned cleanup. Running it again from t.Cleanup is
// harmless: sql.DB.Close and os.RemoveAll both tolerate repeated calls.
func testDB(t *testing.T) (*sql.DB, func()) {
	t.Helper()

	tmpDir, err := os.MkdirTemp("", "sqlitevec-test-*")
	require.NoError(t, err)
	t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })

	// Enable sqlite-vec. sql.Open below is lazy (no connection is made until
	// first use), so registering the extension here, before Open, is
	// equivalent to the original call placement but reads in setup order.
	sqlite_vec.Auto()

	db, err := sql.Open("sqlite3", filepath.Join(tmpDir, "test.db"))
	require.NoError(t, err)
	// t.Cleanup runs LIFO, so this close runs before the directory removal
	// registered above: the database file is closed before it is deleted.
	t.Cleanup(func() { _ = db.Close() })

	// Create vectors table (matches production schema)
	_, err = db.Exec(`
		CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
			doc_id TEXT PRIMARY KEY,
			embedding float[384],
			sqlite_id INTEGER,
			doc_type TEXT,
			field_type TEXT,
			project TEXT,
			scope TEXT,
			model_version TEXT
		)
	`)
	require.NoError(t, err)

	cleanup := func() {
		_ = db.Close()
		_ = os.RemoveAll(tmpDir)
	}
	return db, cleanup
}

// testEmbeddingService creates a test embedding service.
func testEmbeddingService(t *testing.T) (*embedding.Service, func()) { t.Helper() svc, err := embedding.NewService() require.NoError(t, err) cleanup := func() { svc.Close() } return svc, cleanup } func TestNewClient_Success(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) assert.NotNil(t, client) } func TestNewClient_NilDB(t *testing.T) { embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: nil}, embedSvc) assert.Error(t, err) assert.Nil(t, client) assert.Contains(t, err.Error(), "database connection required") } func TestNewClient_NilEmbedding(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() client, err := NewClient(Config{DB: db}, nil) assert.Error(t, err) assert.Nil(t, client) assert.Contains(t, err.Error(), "embedding service required") } func TestClient_AddDocuments_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) err = client.AddDocuments(context.Background(), []Document{}) require.NoError(t, err) } func TestClient_AddDocuments_Single(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) docs := []Document{ { ID: "obs-1-title", Content: "This is a test observation about authentication.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "title", "project": "test-project", "scope": "project", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Verify document was inserted var count int err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", 
"obs-1-title").Scan(&count) require.NoError(t, err) assert.Equal(t, 1, count) } func TestClient_AddDocuments_Multiple(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) docs := []Document{ { ID: "obs-1-title", Content: "Authentication flow implementation.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "title", "project": "test-project", "scope": "project", }, }, { ID: "obs-1-narrative", Content: "We implemented JWT-based authentication.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "narrative", "project": "test-project", "scope": "project", }, }, { ID: "obs-2-title", Content: "Database optimization.", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "field_type": "title", "project": "test-project", "scope": "global", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Verify all documents were inserted var count int err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count) require.NoError(t, err) assert.Equal(t, 3, count) } func TestClient_DeleteDocuments_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) err = client.DeleteDocuments(context.Background(), []string{}) require.NoError(t, err) } func TestClient_DeleteDocuments_Existing(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents first docs := []Document{ { ID: "doc-1", Content: "First document.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", 
}, }, { ID: "doc-2", Content: "Second document.", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Delete one document err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) require.NoError(t, err) // Verify only one remains var count int err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count) require.NoError(t, err) assert.Equal(t, 1, count) } func TestClient_Query_Basic(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some test documents docs := []Document{ { ID: "obs-1", Content: "Authentication and login security implementation.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "project": "test-project", "scope": "project", }, }, { ID: "obs-2", Content: "Database query optimization techniques.", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "project": "test-project", "scope": "project", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query for authentication-related content results, err := client.Query(context.Background(), "login authentication", 10, nil) require.NoError(t, err) assert.NotEmpty(t, results) assert.LessOrEqual(t, len(results), 10) // First result should be the authentication document (higher similarity) assert.Equal(t, "obs-1", results[0].ID) } func TestClient_Query_WithDocTypeFilter(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents of different types docs := []Document{ { ID: "obs-1", Content: "Test content for observation.", Metadata: map[string]any{ "sqlite_id": int64(1), 
"doc_type": "observation", "project": "test-project", }, }, { ID: "summary-1", Content: "Test content for summary.", Metadata: map[string]any{ "sqlite_id": int64(10), "doc_type": "session_summary", "project": "test-project", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query with doc_type filter where := map[string]any{"doc_type": "observation"} results, err := client.Query(context.Background(), "test content", 10, where) require.NoError(t, err) // Should only return observation documents for _, r := range results { docType, _ := r.Metadata["doc_type"].(string) assert.Equal(t, "observation", docType) } } func TestClient_Query_WithProjectFilter(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents from different projects docs := []Document{ { ID: "obs-1", Content: "Project A authentication content.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "project": "project-a", "scope": "project", }, }, { ID: "obs-2", Content: "Project B database content.", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "project": "project-b", "scope": "project", }, }, { ID: "obs-3", Content: "Global security best practices.", Metadata: map[string]any{ "sqlite_id": int64(3), "doc_type": "observation", "project": "project-b", "scope": "global", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query without project filter to verify all docs are there results, err := client.Query(context.Background(), "authentication security", 10, nil) require.NoError(t, err) assert.NotEmpty(t, results, "Should find some results") } func TestClient_IsConnected(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err 
:= NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) assert.True(t, client.IsConnected()) } func TestClient_Close(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) err = client.Close() require.NoError(t, err) } func TestConfig_Fields(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() cfg := Config{DB: db} assert.Equal(t, db, cfg.DB) } func TestClient_UpdateDocument_DeleteThenAdd(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add document docs1 := []Document{ { ID: "doc-1", Content: "Original content.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", }, }, } err = client.AddDocuments(context.Background(), docs1) require.NoError(t, err) // Delete then add with new content (proper update pattern) err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) require.NoError(t, err) docs2 := []Document{ { ID: "doc-1", Content: "Updated content.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", }, }, } err = client.AddDocuments(context.Background(), docs2) require.NoError(t, err) // Should have exactly 1 document var count int err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-1").Scan(&count) require.NoError(t, err) assert.Equal(t, 1, count) } func TestClient_DeleteDocuments_NonExistent(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Deleting non-existent document should not error err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"}) require.NoError(t, err) } func 
TestClient_Count_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) count, err := client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(0), count) } func TestClient_Count_WithVectors(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some documents docs := []Document{ {ID: "doc-1", Content: "test content 1"}, {ID: "doc-2", Content: "test content 2"}, {ID: "doc-3", Content: "test content 3"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) count, err := client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(3), count) } func TestClient_ModelVersion(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) version := client.ModelVersion() assert.NotEmpty(t, version) // Should match the embedding service version assert.Equal(t, embedSvc.Version(), version) } func TestClient_NeedsRebuild_EmptyDatabase(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) needsRebuild, reason := client.NeedsRebuild(context.Background()) assert.True(t, needsRebuild) assert.Equal(t, "empty", reason) } func TestClient_NeedsRebuild_ModelMismatch(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Insert vectors with wrong model version 
embedding := make([]float32, 384) for i := range embedding { embedding[i] = 0.1 } embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding) require.NoError(t, err) _, err = db.Exec(` INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) VALUES (?, ?, ?, ?, ?, ?, ?, ?) `, "doc-1", embeddingBytes, "old-model-v1", 1, "observation", "content", "test", "project") require.NoError(t, err) _, err = db.Exec(` INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) VALUES (?, ?, ?, ?, ?, ?, ?, ?) `, "doc-2", embeddingBytes, "old-model-v1", 2, "observation", "content", "test", "project") require.NoError(t, err) needsRebuild, reason := client.NeedsRebuild(context.Background()) assert.True(t, needsRebuild) assert.Contains(t, reason, "model_mismatch:2") } func TestClient_NeedsRebuild_CurrentModel(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents with current model version docs := []Document{ {ID: "doc-1", Content: "test content 1"}, {ID: "doc-2", Content: "test content 2"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) needsRebuild, reason := client.NeedsRebuild(context.Background()) assert.False(t, needsRebuild) assert.Empty(t, reason) } func TestClient_GetStaleVectors_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) stale, err := client.GetStaleVectors(context.Background()) require.NoError(t, err) assert.Empty(t, stale) } func TestClient_GetStaleVectors_WithMismatch(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := 
NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Insert vectors with wrong model version embedding := make([]float32, 384) embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding) require.NoError(t, err) _, err = db.Exec(` INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) VALUES (?, ?, ?, ?, ?, ?, ?, ?) `, "doc-1", embeddingBytes, "old-model", 1, "observation", "content", "project-1", "project") require.NoError(t, err) _, err = db.Exec(` INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) VALUES (?, ?, ?, ?, ?, ?, ?, ?) `, "doc-2", embeddingBytes, embedSvc.Version(), 2, "observation", "title", "project-1", "project") require.NoError(t, err) stale, err := client.GetStaleVectors(context.Background()) require.NoError(t, err) assert.Len(t, stale, 1) assert.Equal(t, "doc-1", stale[0].DocID) assert.Equal(t, int64(1), stale[0].SQLiteID) assert.Equal(t, "observation", stale[0].DocType) assert.Equal(t, "content", stale[0].FieldType) assert.Equal(t, "project-1", stale[0].Project) assert.Equal(t, "project", stale[0].Scope) } func TestClient_DeleteVectorsByDocIDs_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Deleting empty slice should not error err = client.DeleteVectorsByDocIDs(context.Background(), []string{}) require.NoError(t, err) } func TestClient_DeleteVectorsByDocIDs_Success(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents docs := []Document{ {ID: "doc-1", Content: "test 1"}, {ID: "doc-2", Content: "test 2"}, {ID: "doc-3", Content: "test 3"}, } err = client.AddDocuments(context.Background(), docs) 
require.NoError(t, err) // Verify 3 documents exist count, err := client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(3), count) // Delete doc-1 and doc-3 err = client.DeleteVectorsByDocIDs(context.Background(), []string{"doc-1", "doc-3"}) require.NoError(t, err) // Should have 1 document remaining count, err = client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(1), count) // Verify doc-2 still exists var exists int err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-2").Scan(&exists) require.NoError(t, err) assert.Equal(t, 1, exists) } func TestClient_DeleteVectorsByDocIDs_NonExistent(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Deleting non-existent IDs should not error err = client.DeleteVectorsByDocIDs(context.Background(), []string{"non-existent-1", "non-existent-2"}) require.NoError(t, err) } // ============================================================================= // TESTS FOR CacheStats // ============================================================================= func TestCacheStatsSnapshot_HitRate_NoOperations(t *testing.T) { snapshot := CacheStatsSnapshot{} assert.Equal(t, float64(0), snapshot.HitRate()) } func TestCacheStatsSnapshot_HitRate_WithOperations(t *testing.T) { tests := []struct { name string stats CacheStatsSnapshot expected float64 }{ { name: "all_hits", stats: CacheStatsSnapshot{ EmbeddingHits: 50, ResultHits: 50, }, expected: 100.0, }, { name: "no_hits", stats: CacheStatsSnapshot{ EmbeddingMisses: 50, ResultMisses: 50, }, expected: 0.0, }, { name: "50_percent_hits", stats: CacheStatsSnapshot{ EmbeddingHits: 25, EmbeddingMisses: 25, ResultHits: 25, ResultMisses: 25, }, expected: 50.0, }, { name: "75_percent_hits", stats: CacheStatsSnapshot{ EmbeddingHits: 30, EmbeddingMisses: 10, ResultHits: 
30, ResultMisses: 10, }, expected: 75.0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.stats.HitRate() assert.InDelta(t, tt.expected, result, 0.01) }) } } func TestCacheStats_HitRate_NoOperations(t *testing.T) { stats := &CacheStats{} assert.Equal(t, float64(0), stats.HitRate()) } func TestCacheStats_HitRate_WithOperations(t *testing.T) { stats := &CacheStats{} stats.embeddingHits.Add(10) stats.embeddingMisses.Add(10) stats.resultHits.Add(10) stats.resultMisses.Add(10) // 20 hits / 40 total = 50% assert.InDelta(t, 50.0, stats.HitRate(), 0.01) } func TestCacheStats_Snapshot(t *testing.T) { stats := &CacheStats{} stats.embeddingHits.Add(10) stats.embeddingMisses.Add(5) stats.resultHits.Add(20) stats.resultMisses.Add(15) stats.embeddingEvictions.Add(2) stats.resultEvictions.Add(3) snapshot := stats.Snapshot() assert.Equal(t, int64(10), snapshot.EmbeddingHits) assert.Equal(t, int64(5), snapshot.EmbeddingMisses) assert.Equal(t, int64(20), snapshot.ResultHits) assert.Equal(t, int64(15), snapshot.ResultMisses) assert.Equal(t, int64(2), snapshot.EmbeddingEvictions) assert.Equal(t, int64(3), snapshot.ResultEvictions) } // ============================================================================= // TESTS FOR Cache Methods // ============================================================================= func TestClient_ClearCache(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document and query to populate cache docs := []Document{ {ID: "doc-1", Content: "test content for caching"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query to populate cache _, err = client.Query(context.Background(), "test content", 5, nil) require.NoError(t, err) // Verify cache has entries initialSize := client.EmbeddingCacheSize() assert.Greater(t, 
initialSize, 0) // Clear cache client.ClearCache() // Verify cache is empty assert.Equal(t, 0, client.EmbeddingCacheSize()) } func TestClient_GetCacheStats(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Get stats before any operations stats := client.GetCacheStats() assert.Equal(t, int64(0), stats.EmbeddingHits) assert.Equal(t, int64(0), stats.EmbeddingMisses) // Add a document and query to generate cache activity docs := []Document{ {ID: "doc-1", Content: "test content for caching"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query - should be a miss first time _, err = client.Query(context.Background(), "test content", 5, nil) require.NoError(t, err) // Query again - should be a hit _, err = client.Query(context.Background(), "test content", 5, nil) require.NoError(t, err) // Get stats after operations stats = client.GetCacheStats() assert.Greater(t, stats.EmbeddingMisses+stats.EmbeddingHits, int64(0)) } func TestClient_CacheStats(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Get initial stats size, maxSize := client.CacheStats() assert.Equal(t, 0, size) assert.Greater(t, maxSize, 0) // Add a document and query to populate cache docs := []Document{ {ID: "doc-1", Content: "test content"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) _, err = client.Query(context.Background(), "test content", 5, nil) require.NoError(t, err) // Check stats after operations size, _ = client.CacheStats() assert.Greater(t, size, 0) } func TestClient_EmbeddingCacheSize(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer 
embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Initially empty assert.Equal(t, 0, client.EmbeddingCacheSize()) // Add a document and query docs := []Document{ {ID: "doc-1", Content: "test content"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) _, err = client.Query(context.Background(), "unique query", 5, nil) require.NoError(t, err) // Should have at least one entry assert.Greater(t, client.EmbeddingCacheSize(), 0) } func TestClient_ResultCacheSize(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Initially empty assert.Equal(t, 0, client.ResultCacheSize()) } // ============================================================================= // TESTS FOR QueryBatch // ============================================================================= func TestClient_QueryBatch_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) results := client.QueryBatch(context.Background(), []string{}, 10, nil) assert.Nil(t, results) } func TestClient_QueryBatch_Single(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some documents docs := []Document{ { ID: "obs-1", Content: "Authentication and security implementation.", Metadata: map[string]any{"doc_type": "observation"}, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query batch with single query results := client.QueryBatch(context.Background(), []string{"authentication"}, 10, nil) assert.Len(t, results, 1) assert.NoError(t, 
results[0].Error) assert.Equal(t, "authentication", results[0].Query) } func TestClient_QueryBatch_Multiple(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some documents docs := []Document{ {ID: "obs-1", Content: "Authentication and security implementation."}, {ID: "obs-2", Content: "Database optimization and indexing."}, {ID: "obs-3", Content: "API rate limiting and throttling."}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query batch with multiple queries queries := []string{"authentication", "database", "API"} results := client.QueryBatch(context.Background(), queries, 10, nil) assert.Len(t, results, 3) for i, r := range results { assert.NoError(t, r.Error) assert.Equal(t, queries[i], r.Query) } } func TestClient_QueryBatch_WithContextCancellation(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Cancel context immediately ctx, cancel := context.WithCancel(context.Background()) cancel() // Queries should fail due to cancelled context queries := []string{"query1", "query2", "query3"} results := client.QueryBatch(ctx, queries, 10, nil) assert.Len(t, results, 3) // At least some should have context cancellation error hasError := false for _, r := range results { if r.Error != nil { hasError = true } } assert.True(t, hasError, "Should have at least one error due to cancelled context") } // ============================================================================= // TESTS FOR QueryMultiField // ============================================================================= func TestClient_QueryMultiField_Basic(t *testing.T) { t.Skip("QueryMultiField SQL query needs 'k' parameter fix for sqlite-vec") db, 
dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents with different field types for same sqlite_id docs := []Document{ { ID: "obs-1-title", Content: "Authentication implementation", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "title", "project": "test-project", "scope": "project", }, }, { ID: "obs-1-narrative", Content: "We implemented JWT-based authentication for the API.", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "narrative", "project": "test-project", "scope": "project", }, }, { ID: "obs-2-title", Content: "Database optimization", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "field_type": "title", "project": "test-project", "scope": "project", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query multi-field results, err := client.QueryMultiField(context.Background(), "authentication JWT", 10, "observation", "test-project") require.NoError(t, err) // Should return deduplicated results (one per sqlite_id) assert.NotEmpty(t, results) // Each result should have unique sqlite_id seenIDs := make(map[float64]bool) for _, r := range results { sqliteID, ok := r.Metadata["sqlite_id"].(float64) if ok { assert.False(t, seenIDs[sqliteID], "Should not have duplicate sqlite_ids") seenIDs[sqliteID] = true } } } func TestClient_QueryMultiField_WithGlobalScope(t *testing.T) { t.Skip("QueryMultiField SQL query needs 'k' parameter fix for sqlite-vec") db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents: one project-scoped, one global docs := []Document{ { ID: "obs-1-title", Content: "Security best practices", 
Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "field_type": "title", "project": "project-a", "scope": "project", }, }, { ID: "obs-2-title", Content: "Security patterns for all projects", Metadata: map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "field_type": "title", "project": "project-b", "scope": "global", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query from project-a - should get project-a doc and global doc results, err := client.QueryMultiField(context.Background(), "security", 10, "observation", "project-a") require.NoError(t, err) // Should include both project-scoped (matching project) and global assert.NotEmpty(t, results) } // ============================================================================= // TESTS FOR GetHealthStats // ============================================================================= func TestClient_GetHealthStats_Empty(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) stats, err := client.GetHealthStats(context.Background()) require.NoError(t, err) assert.NotNil(t, stats) assert.Equal(t, int64(0), stats.TotalVectors) assert.Equal(t, int64(0), stats.StaleVectors) assert.Equal(t, embedSvc.Version(), stats.CurrentModel) assert.True(t, stats.NeedsRebuild) assert.Equal(t, "empty", stats.RebuildReason) } func TestClient_GetHealthStats_WithData(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some documents docs := []Document{ { ID: "obs-1", Content: "Test content 1", Metadata: map[string]any{ "sqlite_id": int64(1), "doc_type": "observation", "project": "project-a", }, }, { ID: "obs-2", Content: "Test content 2", Metadata: 
map[string]any{ "sqlite_id": int64(2), "doc_type": "observation", "project": "project-a", }, }, { ID: "sum-1", Content: "Summary content", Metadata: map[string]any{ "sqlite_id": int64(10), "doc_type": "session_summary", "project": "project-b", }, }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) stats, err := client.GetHealthStats(context.Background()) require.NoError(t, err) assert.NotNil(t, stats) assert.Equal(t, int64(3), stats.TotalVectors) assert.Equal(t, int64(0), stats.StaleVectors) // All fresh assert.False(t, stats.NeedsRebuild) // Coverage by type assert.Equal(t, int64(2), stats.CoverageByType["observation"]) assert.Equal(t, int64(1), stats.CoverageByType["session_summary"]) // Model versions assert.Equal(t, int64(3), stats.ModelVersions[embedSvc.Version()]) // Project counts assert.Equal(t, int64(2), stats.ProjectCounts["project-a"]) assert.Equal(t, int64(1), stats.ProjectCounts["project-b"]) } func TestClient_GetHealthStats_WithStaleVectors(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document with current model docs := []Document{ {ID: "doc-1", Content: "Fresh content"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Insert a stale vector directly embedding := make([]float32, 384) embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding) require.NoError(t, err) _, err = db.Exec(` INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
`, "stale-doc", embeddingBytes, "old-model", 999, "observation", "content", "test-project", "project") require.NoError(t, err) stats, err := client.GetHealthStats(context.Background()) require.NoError(t, err) assert.Equal(t, int64(2), stats.TotalVectors) assert.Equal(t, int64(1), stats.StaleVectors) assert.True(t, stats.NeedsRebuild) assert.Contains(t, stats.RebuildReason, "model_mismatch") } // ============================================================================= // TESTS FOR DeleteByObservationID // ============================================================================= func TestClient_DeleteByObservationID_NoMatches(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Delete non-existent observation - should not error err = client.DeleteByObservationID(context.Background(), 999) require.NoError(t, err) } func TestClient_DeleteByObservationID_WithMatches(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add documents with observation IDs in doc_id docs := []Document{ {ID: "obs_123_narrative", Content: "Narrative for observation 123"}, {ID: "obs_123_facts_0", Content: "Fact 0 for observation 123"}, {ID: "obs_123_facts_1", Content: "Fact 1 for observation 123"}, {ID: "obs_456_narrative", Content: "Narrative for observation 456"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Verify 4 documents exist count, err := client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(4), count) // Delete observation 123 err = client.DeleteByObservationID(context.Background(), 123) require.NoError(t, err) // Should have 1 document remaining (obs_456) count, err = 
client.Count(context.Background()) require.NoError(t, err) assert.Equal(t, int64(1), count) // Verify obs_456 still exists var exists int err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id LIKE 'obs_456_%'").Scan(&exists) require.NoError(t, err) assert.Equal(t, 1, exists) } // ============================================================================= // TESTS FOR cacheCleanupLoop and cleanupExpiredCaches // ============================================================================= func TestClient_CleanupExpiredCaches(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document and query to populate cache docs := []Document{ {ID: "doc-1", Content: "test content"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) _, err = client.Query(context.Background(), "test", 5, nil) require.NoError(t, err) // Verify cache has entries assert.Greater(t, client.EmbeddingCacheSize(), 0) // Call cleanup (will only clean expired entries) client.cleanupExpiredCaches() // Fresh cache entries should still exist assert.Greater(t, client.EmbeddingCacheSize(), 0) } func TestClient_CacheCleanupLoop_StopsOnClose(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close should stop the cleanup loop err = client.Close() require.NoError(t, err) } // ============================================================================= // TESTS FOR EMBEDDING CACHE BEHAVIOR // ============================================================================= func TestClient_EmbeddingCache_HitAfterMiss(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := 
NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document so we can query docs := []Document{ {ID: "test-1", Content: "Hello world test content"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // First query - cache miss _, err = client.Query(context.Background(), "hello world", 10, nil) require.NoError(t, err) stats1 := client.GetCacheStats() assert.Equal(t, int64(1), stats1.EmbeddingMisses) // Invalidate result cache to force embedding cache usage on second query client.InvalidateResultCache() // Second query with same text - should be embedding cache hit (result cache miss) _, err = client.Query(context.Background(), "hello world", 10, nil) require.NoError(t, err) stats2 := client.GetCacheStats() assert.Equal(t, int64(1), stats2.EmbeddingMisses) // Same miss count assert.Equal(t, int64(1), stats2.EmbeddingHits) // One hit } func TestClient_ResultCache_HitAfterMiss(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document docs := []Document{ { ID: "test-1", Content: "Testing result cache behavior", }, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // First query - result cache miss _, err = client.Query(context.Background(), "testing cache", 10, nil) require.NoError(t, err) stats1 := client.GetCacheStats() assert.Equal(t, int64(1), stats1.ResultMisses) // Second identical query - should be result cache hit _, err = client.Query(context.Background(), "testing cache", 10, nil) require.NoError(t, err) stats2 := client.GetCacheStats() assert.Equal(t, int64(1), stats2.ResultMisses) // Same miss count assert.Equal(t, int64(1), stats2.ResultHits) // One hit } func TestClient_Query_WithContextCancel(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer 
embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Create cancelled context ctx, cancel := context.WithCancel(context.Background()) cancel() // Query with cancelled context _, err = client.Query(ctx, "test query", 10, nil) require.Error(t, err) assert.Contains(t, err.Error(), "context canceled") } func TestClient_AddDocuments_WithContextCancel(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Create cancelled context ctx, cancel := context.WithCancel(context.Background()) cancel() docs := []Document{{ID: "test", Content: "test content"}} err = client.AddDocuments(ctx, docs) require.Error(t, err) assert.Contains(t, err.Error(), "context canceled") } func TestClient_InvalidateResultCache(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document docs := []Document{ {ID: "test-1", Content: "Test invalidation"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Query to populate result cache _, err = client.Query(context.Background(), "invalidation", 10, nil) require.NoError(t, err) assert.Greater(t, client.ResultCacheSize(), 0) // Invalidate the result cache client.InvalidateResultCache() assert.Equal(t, 0, client.ResultCacheSize()) } func TestClient_Count_WithError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() _, err = client.Count(context.Background()) require.Error(t, err) } func TestClient_NeedsRebuild_ReturnsReason(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() 
embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Empty database should need rebuild needsRebuild, reason := client.NeedsRebuild(context.Background()) assert.True(t, needsRebuild) assert.NotEmpty(t, reason) } func TestClient_GetStaleVectors_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() _, err = client.GetStaleVectors(context.Background()) require.Error(t, err) } func TestClient_DeleteVectorsByDocIDs_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() err = client.DeleteVectorsByDocIDs(context.Background(), []string{"doc-1"}) require.Error(t, err) } func TestClient_DeleteByObservationID_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() err = client.DeleteByObservationID(context.Background(), 123) require.Error(t, err) } func TestClient_Query_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add document first docs := []Document{{ID: "test", Content: "test content"}} err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Close DB to cause error on query db.Close() // Clear the cache so it has to hit the DB client.InvalidateResultCache() client.ClearCache() _, err = 
client.Query(context.Background(), "test", 10, nil) require.Error(t, err) } func TestClient_AddDocuments_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() docs := []Document{{ID: "test", Content: "test content"}} err = client.AddDocuments(context.Background(), docs) require.Error(t, err) } func TestClient_GetHealthStats_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() _, err = client.GetHealthStats(context.Background()) require.Error(t, err) } func TestClient_QueryBatch_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() results := client.QueryBatch(context.Background(), []string{"test1", "test2"}, 10, nil) require.Len(t, results, 2) assert.Error(t, results[0].Error) assert.Error(t, results[1].Error) } func TestClient_DeleteDocuments_DBError(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Close DB to cause error db.Close() err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) require.Error(t, err) } func TestClient_Query_WithEmptyResults(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Query with no documents - should return empty 
results results, err := client.Query(context.Background(), "nonexistent query", 10, nil) require.NoError(t, err) assert.Empty(t, results) } func TestClient_QueryBatch_AllSucceed(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add some documents docs := []Document{ {ID: "doc-1", Content: "Test content for batch query one"}, {ID: "doc-2", Content: "Test content for batch query two"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Run batch query with multiple queries results := client.QueryBatch(context.Background(), []string{"batch one", "batch two", "batch three"}, 10, nil) // All queries should succeed require.Len(t, results, 3) for i, r := range results { assert.NoError(t, r.Error, "Query %d should not fail", i) } } // ============================================================================= // TESTS FOR HELPER FUNCTIONS EDGE CASES // ============================================================================= func TestExtractObservationIDs_Int64Metadata(t *testing.T) { // Test the int64 fallback path for sqlite_id metadata results := []QueryResult{ { ID: "obs-1", Similarity: 0.9, Metadata: map[string]any{ "sqlite_id": int64(123), // int64 instead of float64 "doc_type": "observation", "project": "test-project", }, }, } ids := ExtractObservationIDs(results, "test-project") assert.Len(t, ids, 1) assert.Equal(t, int64(123), ids[0]) } func TestExtractSummaryIDs_Int64Metadata(t *testing.T) { // Test the int64 fallback path for sqlite_id metadata results := []QueryResult{ { ID: "sum-1", Similarity: 0.9, Metadata: map[string]any{ "sqlite_id": int64(456), // int64 instead of float64 "doc_type": "session_summary", "project": "test-project", }, }, } ids := ExtractSummaryIDs(results, "test-project") assert.Len(t, ids, 1) assert.Equal(t, int64(456), ids[0]) } func 
TestExtractPromptIDs_Int64Metadata(t *testing.T) { // Test the int64 fallback path for sqlite_id metadata results := []QueryResult{ { ID: "prompt-1", Similarity: 0.9, Metadata: map[string]any{ "sqlite_id": int64(789), // int64 instead of float64 "doc_type": "user_prompt", "project": "test-project", }, }, } ids := ExtractPromptIDs(results, "test-project") assert.Len(t, ids, 1) assert.Equal(t, int64(789), ids[0]) } func TestExtractObservationIDs_GlobalScope(t *testing.T) { // Test that global scope observations are included for any project results := []QueryResult{ { ID: "obs-1", Similarity: 0.9, Metadata: map[string]any{ "sqlite_id": float64(123), "doc_type": "observation", "project": "other-project", "scope": "global", // Global scope should be included }, }, } ids := ExtractObservationIDs(results, "test-project") assert.Len(t, ids, 1) assert.Equal(t, int64(123), ids[0]) } // ============================================================================= // REGRESSION TESTS FOR acquireRLockWithContext (Fix #45) // ============================================================================= // TestAcquireRLockWithContext_Cancel verifies that when a write lock is held // and the context times out, acquireRLockWithContext returns context.DeadlineExceeded // promptly and the cleanup goroutine eventually releases the lock. func TestAcquireRLockWithContext_Cancel(t *testing.T) { var mu sync.RWMutex // Hold write lock so any RLock() call blocks. locked := make(chan struct{}) release := make(chan struct{}) go func() { mu.Lock() close(locked) <-release mu.Unlock() }() <-locked // write lock is held // Context with a tight deadline — must expire before we release the write lock. 
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() start := time.Now() err := acquireRLockWithContext(ctx, &mu) elapsed := time.Since(start) assert.ErrorIs(t, err, context.DeadlineExceeded, "should return DeadlineExceeded") assert.Less(t, elapsed, 200*time.Millisecond, "should return within ~100ms of deadline") // Release the write lock so the cleanup goroutine can finish. close(release) // After the write lock is released the cleanup goroutine acquires+releases // the RLock. Wait long enough for it to drain. time.Sleep(100 * time.Millisecond) // Now an uncontended RLock should succeed immediately. ctx2 := context.Background() err2 := acquireRLockWithContext(ctx2, &mu) assert.NoError(t, err2, "should succeed when uncontended after cleanup") if err2 == nil { mu.RUnlock() } } // TestAcquireRLockWithContext_Success verifies that an uncontended mutex is // acquired without error and can be properly unlocked. func TestAcquireRLockWithContext_Success(t *testing.T) { var mu sync.RWMutex err := acquireRLockWithContext(context.Background(), &mu) assert.NoError(t, err, "should succeed on uncontended mutex") if err == nil { // Panics if not held — validates that the lock was actually taken. mu.RUnlock() } } // TestAcquireRLockWithContext_CleanupOnCancel verifies that when // acquireRLockWithContext returns an error due to context cancellation, the // cleanup goroutine eventually releases the RLock so the mutex can be write- // locked again without deadlock. func TestAcquireRLockWithContext_CleanupOnCancel(t *testing.T) { var mu sync.RWMutex // Hold write lock to force RLock to block. release := make(chan struct{}) locked := make(chan struct{}) go func() { mu.Lock() close(locked) <-release mu.Unlock() }() <-locked // Context cancels after 10ms — way before we release the write lock. 
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() err := acquireRLockWithContext(ctx, &mu) assert.Error(t, err, "should fail due to cancelled context") // Release the write lock; the cleanup goroutine inside acquireRLockWithContext // will now acquire the RLock and immediately release it. close(release) // Give the cleanup goroutine time to run. time.Sleep(50 * time.Millisecond) // Acquire a write lock — would deadlock if cleanup goroutine left an RLock // dangling. Use a done channel and select to avoid hanging the test. done := make(chan struct{}) go func() { mu.Lock() defer mu.Unlock() close(done) //nolint:SA2001 // intentional: proves no deadlock from leaked RLock }() select { case <-done: // Success — write lock acquired without deadlock. case <-time.After(2 * time.Second): t.Fatal("write lock acquisition timed out: cleanup goroutine may have leaked an RLock") } } // ============================================================================= // REGRESSION TESTS FOR Fix #1: Embedding outside writeMu in AddDocuments // ============================================================================= // TestAddDocuments_EmbeddingOutsideWriteLock verifies that AddDocuments does NOT // hold the write lock during embedding computation. A concurrent Query call // should complete quickly while AddDocuments is computing embeddings. 
func TestAddDocuments_EmbeddingOutsideWriteLock(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Seed the DB with one document so Query has something to search seedDocs := []Document{ {ID: "seed-1", Content: "Seed document for concurrency test"}, } err = client.AddDocuments(context.Background(), seedDocs) require.NoError(t, err) // Pre-warm the embedding cache for the query text so the Query call // itself doesn't need the embedding mutex — it only needs the DB read lock. _, err = client.Query(context.Background(), "concurrency test", 5, nil) require.NoError(t, err) // Prepare a batch of documents to trigger a slow AddDocuments call batchDocs := make([]Document, 10) for i := range batchDocs { batchDocs[i] = Document{ ID: fmt.Sprintf("batch-%d", i), Content: fmt.Sprintf("Batch document number %d for write lock test with unique content", i), } } // Launch AddDocuments in background — embedding will take time addDone := make(chan error, 1) go func() { addDone <- client.AddDocuments(context.Background(), batchDocs) }() // Give AddDocuments a moment to start embedding computation time.Sleep(10 * time.Millisecond) // Invalidate result cache so Query must go through to DB (tests read lock) client.InvalidateResultCache() // A concurrent Query should NOT be blocked by AddDocuments. // If the old code held writeMu during embedding, this would block until // embedding finishes. With the fix, it should complete quickly. 
queryCtx, queryCancel := context.WithTimeout(context.Background(), 2*time.Second) defer queryCancel() start := time.Now() _, err = client.Query(queryCtx, "concurrency test", 5, nil) queryDuration := time.Since(start) require.NoError(t, err, "Query should succeed while AddDocuments is embedding") // The query should complete well within the timeout if writeMu is not held assert.Less(t, queryDuration, 1*time.Second, "Query should complete quickly when writeMu is not held during embedding") // Wait for AddDocuments to finish err = <-addDone require.NoError(t, err, "AddDocuments should succeed") } // ============================================================================= // REGRESSION TESTS FOR Fix #2a: DoChan + context select in getOrComputeEmbedding // ============================================================================= // TestGetOrComputeEmbedding_ContextCancelDuringSingleflight verifies that when // a singleflight embedding computation is in progress, a second caller with a // short-lived context returns context.DeadlineExceeded promptly rather than // waiting for the slow first call to finish. 
func TestGetOrComputeEmbedding_ContextCancelDuringSingleflight(t *testing.T) { db, dbCleanup := testDB(t) defer dbCleanup() embedSvc, embedCleanup := testEmbeddingService(t) defer embedCleanup() client, err := NewClient(Config{DB: db}, embedSvc) require.NoError(t, err) // Add a document so we can query docs := []Document{ {ID: "sf-test-1", Content: "Singleflight context cancellation test document"}, } err = client.AddDocuments(context.Background(), docs) require.NoError(t, err) // Clear cache to force embedding computation client.ClearCache() client.InvalidateResultCache() queryText := "unique singleflight context test query" // First call: start a normal query in background (will trigger singleflight) firstDone := make(chan struct{}) go func() { defer close(firstDone) _, _ = client.Query(context.Background(), queryText, 5, nil) }() // Give the first call a moment to start the singleflight computation time.Sleep(5 * time.Millisecond) // Second call: use a very short context that should expire quickly shortCtx, shortCancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer shortCancel() // Clear result cache again so the second call hits getOrComputeEmbedding client.InvalidateResultCache() start := time.Now() _, err = client.Query(shortCtx, queryText, 5, nil) elapsed := time.Since(start) // With DoChan + select, the second caller should return quickly with context error. // Note: If the embedding completes fast enough, the second call may succeed // via singleflight sharing. That's also valid — the test primarily checks // that it doesn't BLOCK for the full embedding duration on context cancel. if err != nil { assert.ErrorIs(t, err, context.DeadlineExceeded, "Should return DeadlineExceeded when context expires during singleflight") assert.Less(t, elapsed, 500*time.Millisecond, "Should return promptly on context cancellation, not wait for slow computation") } // If err == nil, the singleflight completed before the context expired — also valid. 
// Wait for first call to finish <-firstDone }