diff --git a/internal/db/sqlite/helpers_test.go b/internal/db/sqlite/helpers_test.go
new file mode 100644
index 0000000..da7e99f
--- /dev/null
+++ b/internal/db/sqlite/helpers_test.go
@@ -0,0 +1,254 @@
+package sqlite
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNullString(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected string
+ valid bool
+ }{
+ {"empty_string", "", "", false},
+ {"non_empty_string", "hello", "hello", true},
+ {"whitespace", " ", " ", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := nullString(tt.input)
+ assert.Equal(t, tt.expected, result.String)
+ assert.Equal(t, tt.valid, result.Valid)
+ })
+ }
+}
+
+func TestNullInt(t *testing.T) {
+ tests := []struct {
+ name string
+ input int
+ expected int64
+ valid bool
+ }{
+ {"zero", 0, 0, false},
+ {"positive", 42, 42, true},
+ {"negative", -1, -1, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := nullInt(tt.input)
+ assert.Equal(t, tt.expected, result.Int64)
+ assert.Equal(t, tt.valid, result.Valid)
+ })
+ }
+}
+
+func TestRepeatPlaceholders(t *testing.T) {
+ tests := []struct {
+ name string
+ n int
+ expected string
+ }{
+ {"zero", 0, ""},
+ {"negative", -1, ""},
+ {"one", 1, ", ?"},
+ {"two", 2, ", ?, ?"},
+ {"three", 3, ", ?, ?, ?"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := repeatPlaceholders(tt.n)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestInt64SliceToInterface(t *testing.T) {
+ tests := []struct {
+ name string
+ input []int64
+ expected []interface{}
+ }{
+ {"empty", []int64{}, []interface{}{}},
+ {"single", []int64{42}, []interface{}{int64(42)}},
+ {"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := int64SliceToInterface(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestParseLimitParam(t *testing.T) {
+ tests := []struct {
+ name string
+ query string
+ defaultLimit int
+ expected int
+ }{
+ {"no_param_uses_default", "", 10, 10},
+ {"valid_limit", "limit=20", 10, 20},
+ {"invalid_limit_uses_default", "limit=abc", 10, 10},
+ {"zero_limit_uses_default", "limit=0", 10, 10},
+ {"negative_limit_uses_default", "limit=-5", 10, 10},
+ {"large_limit", "limit=1000", 10, 1000},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/?"+tt.query, nil)
+ result := ParseLimitParam(req, tt.defaultLimit)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestScanSummary(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+ createBaseTables(t, db)
+ seedSession(t, db, "claude-123", "sdk-123", "test-project")
+
+ // Insert a test summary
+ _, err := db.Exec(`
+ INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
+ VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000)
+ `)
+ require.NoError(t, err)
+
+ // Query and scan
+ row := db.QueryRow(`
+ SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
+ FROM session_summaries WHERE sdk_session_id = ?
+ `, "sdk-123")
+
+ summary, err := scanSummary(row)
+ require.NoError(t, err)
+ assert.NotNil(t, summary)
+ assert.Equal(t, "sdk-123", summary.SDKSessionID)
+ assert.Equal(t, "test-project", summary.Project)
+ assert.Equal(t, "test request", summary.Request.String)
+ assert.Equal(t, "test investigated", summary.Investigated.String)
+ assert.Equal(t, "test learned", summary.Learned.String)
+ assert.Equal(t, "test completed", summary.Completed.String)
+ assert.Equal(t, "test next steps", summary.NextSteps.String)
+ assert.Equal(t, "test notes", summary.Notes.String)
+}
+
+func TestScanSummaryRows(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+ createBaseTables(t, db)
+ seedSession(t, db, "claude-123", "sdk-123", "test-project")
+
+ // Insert multiple summaries
+ _, err := db.Exec(`
+ INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
+ VALUES
+ ('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000),
+ ('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000)
+ `)
+ require.NoError(t, err)
+
+ rows, err := db.Query(`
+ SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
+ FROM session_summaries WHERE sdk_session_id = ? ORDER BY id
+ `, "sdk-123")
+ require.NoError(t, err)
+ defer rows.Close()
+
+ summaries, err := scanSummaryRows(rows)
+ require.NoError(t, err)
+ assert.Len(t, summaries, 2)
+ assert.Equal(t, "request 1", summaries[0].Request.String)
+ assert.Equal(t, "request 2", summaries[1].Request.String)
+}
+
+func TestScanPromptWithSession(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+ createBaseTables(t, db)
+ seedSession(t, db, "claude-123", "sdk-123", "test-project")
+
+ // Insert a test prompt
+ _, err := db.Exec(`
+ INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
+ VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000)
+ `)
+ require.NoError(t, err)
+
+ // Query with session join
+ row := db.QueryRow(`
+ SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
+ FROM user_prompts p
+ JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
+ WHERE p.claude_session_id = ?
+ `, "claude-123")
+
+ prompt, err := scanPromptWithSession(row)
+ require.NoError(t, err)
+ assert.NotNil(t, prompt)
+ assert.Equal(t, "claude-123", prompt.ClaudeSessionID)
+ assert.Equal(t, 1, prompt.PromptNumber)
+ assert.Equal(t, "test prompt", prompt.PromptText)
+ assert.Equal(t, 5, prompt.MatchedObservations)
+ assert.Equal(t, "test-project", prompt.Project)
+ assert.Equal(t, "sdk-123", prompt.SDKSessionID)
+}
+
+func TestScanPromptWithSessionRows(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+ createBaseTables(t, db)
+ seedSession(t, db, "claude-123", "sdk-123", "test-project")
+
+ // Insert multiple prompts
+ _, err := db.Exec(`
+ INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
+ VALUES
+ ('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000),
+ ('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000)
+ `)
+ require.NoError(t, err)
+
+ rows, err := db.Query(`
+ SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
+ FROM user_prompts p
+ JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
+ WHERE p.claude_session_id = ? ORDER BY p.id
+ `, "claude-123")
+ require.NoError(t, err)
+ defer rows.Close()
+
+ prompts, err := scanPromptWithSessionRows(rows)
+ require.NoError(t, err)
+ assert.Len(t, prompts, 2)
+ assert.Equal(t, "prompt one", prompts[0].PromptText)
+ assert.Equal(t, "prompt two", prompts[1].PromptText)
+}
+
+func TestParseLimitParam_HTTPRequest(t *testing.T) {
+ // Test with an actual HTTP request
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ limit := ParseLimitParam(r, 25)
+ if limit != 50 {
+ t.Errorf("Expected limit 50, got %d", limit)
+ }
+ })
+
+ req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil)
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, req)
+}
diff --git a/internal/db/sqlite/migrations_test.go b/internal/db/sqlite/migrations_test.go
new file mode 100644
index 0000000..129be43
--- /dev/null
+++ b/internal/db/sqlite/migrations_test.go
@@ -0,0 +1,196 @@
+package sqlite
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewMigrationManager(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ require.NotNil(t, manager)
+ assert.Equal(t, db, manager.db)
+}
+
+func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+
+ // Should create table without error
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Table should exist
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 0, count) // Empty table
+
+ // Calling again should not error (IF NOT EXISTS)
+ err = manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+}
+
+func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ versions, err := manager.GetAppliedVersions()
+ require.NoError(t, err)
+ assert.Empty(t, versions)
+}
+
+func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Insert some versions
+ _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')")
+ require.NoError(t, err)
+ _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')")
+ require.NoError(t, err)
+
+ versions, err := manager.GetAppliedVersions()
+ require.NoError(t, err)
+ assert.Len(t, versions, 2)
+ assert.True(t, versions[1])
+ assert.True(t, versions[2])
+ assert.False(t, versions[3]) // Not applied
+}
+
+func TestMigrationManager_ApplyMigration(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Apply a simple migration
+ migration := Migration{
+ Version: 100,
+ Name: "test_migration",
+ SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)",
+ }
+
+ err = manager.ApplyMigration(migration)
+ require.NoError(t, err)
+
+ // Verify table was created
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 1, count)
+
+ // Verify migration was recorded
+ var version int
+ err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version)
+ require.NoError(t, err)
+ assert.Equal(t, 100, version)
+}
+
+func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Try to apply invalid migration
+ migration := Migration{
+ Version: 100,
+ Name: "invalid_migration",
+ SQL: "INVALID SQL SYNTAX",
+ }
+
+ err = manager.ApplyMigration(migration)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "execute migration 100")
+}
+
+func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ // Create a test migration manager with a subset of migrations
+ manager := NewMigrationManager(db)
+
+ // First ensure schema versions table exists
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Apply first migration manually
+ err = manager.ApplyMigration(Migrations[0])
+ require.NoError(t, err)
+
+ // Verify the first migration version was recorded
+ versions, err := manager.GetAppliedVersions()
+ require.NoError(t, err)
+ assert.True(t, versions[Migrations[0].Version])
+}
+
+func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) {
+ db, _, cleanup := testDB(t)
+ defer cleanup()
+
+ manager := NewMigrationManager(db)
+ err := manager.EnsureSchemaVersionsTable()
+ require.NoError(t, err)
+
+ // Mark some migrations as already applied
+ _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')")
+ require.NoError(t, err)
+
+ // Get applied versions
+ versions, err := manager.GetAppliedVersions()
+ require.NoError(t, err)
+ assert.True(t, versions[4])
+}
+
+func TestMigration_Struct(t *testing.T) {
+ m := Migration{
+ Version: 1,
+ Name: "test",
+ SQL: "SELECT 1",
+ }
+
+ assert.Equal(t, 1, m.Version)
+ assert.Equal(t, "test", m.Name)
+ assert.Equal(t, "SELECT 1", m.SQL)
+}
+
+func TestMigrations_List(t *testing.T) {
+ // Verify migrations are ordered correctly
+ assert.NotEmpty(t, Migrations)
+
+ // Verify all migrations have required fields
+ for i, m := range Migrations {
+ assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i)
+ assert.NotEmpty(t, m.Name, "Migration %d has empty name", i)
+ assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i)
+ }
+
+ // Verify key migrations exist
+ versionSet := make(map[int]bool)
+ for _, m := range Migrations {
+ versionSet[m.Version] = true
+ }
+
+ assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration")
+ assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration")
+}
diff --git a/internal/db/sqlite/observation_test.go b/internal/db/sqlite/observation_test.go
index 20ddf48..6b4972c 100644
--- a/internal/db/sqlite/observation_test.go
+++ b/internal/db/sqlite/observation_test.go
@@ -800,3 +800,148 @@ func TestExtractKeywords(t *testing.T) {
})
}
}
+
+func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
+ obsStore, _, cleanup := testObservationStore(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create project-scoped observation for project-a
+ projectObs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: "Project A specific",
+ Narrative: "Only for project-a",
+ Concepts: []string{"local-concept"},
+ }
+ _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
+ require.NoError(t, err)
+
+ // Create global observation from project-a
+ globalObs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: "Global security practice",
+ Narrative: "Best practice for all",
+ Concepts: []string{"security", "best-practice"},
+ }
+ _, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100)
+ require.NoError(t, err)
+
+ // Create observation for project-b
+ projectBObs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: "Project B specific",
+ Narrative: "Only for project-b",
+ }
+ _, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100)
+ require.NoError(t, err)
+
+ // GetObservationsByProjectStrict for project-a should only return project-a observations
+ // This is different from GetRecentObservations which includes globals from other projects
+ results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
+ require.NoError(t, err)
+ assert.Len(t, results, 2) // Only observations created in project-a
+
+ // Verify both are from project-a
+ for _, obs := range results {
+ assert.Equal(t, "project-a", obs.Project)
+ }
+
+ // GetObservationsByProjectStrict for project-b should only return project-b observations
+ results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10)
+ require.NoError(t, err)
+ assert.Len(t, results, 1)
+ assert.Equal(t, "Project B specific", results[0].Title.String)
+}
+
+func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) {
+ obsStore, _, cleanup := testObservationStore(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create an observation
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: "Test observation",
+ Narrative: "Some content here",
+ }
+ _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
+ require.NoError(t, err)
+
+ // Search with only stop words (should return nil)
+ results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10)
+ require.NoError(t, err)
+ assert.Nil(t, results)
+
+ // Search with empty query
+ results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10)
+ require.NoError(t, err)
+ assert.Nil(t, results)
+}
+
+func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) {
+ obsStore, _, cleanup := testObservationStore(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create observations
+ for i := 0; i < 15; i++ {
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: "Authentication test " + string(rune('A'+i)),
+ Narrative: "Auth related content",
+ }
+ _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
+ require.NoError(t, err)
+ time.Sleep(time.Millisecond)
+ }
+
+ // Search with limit 0 (should default to 10)
+ results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0)
+ require.NoError(t, err)
+ assert.LessOrEqual(t, len(results), 10)
+
+ // Search with negative limit (should default to 10)
+ results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5)
+ require.NoError(t, err)
+ assert.LessOrEqual(t, len(results), 10)
+}
+
+func TestObservationStore_GetAllRecentObservations(t *testing.T) {
+ obsStore, _, cleanup := testObservationStore(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create observations across different projects
+ projects := []string{"project-a", "project-b", "project-c"}
+ for _, proj := range projects {
+ for i := 0; i < 3; i++ {
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Title: proj + " observation " + string(rune('A'+i)),
+ Narrative: "Content for " + proj,
+ }
+ _, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100)
+ require.NoError(t, err)
+ time.Sleep(time.Millisecond)
+ }
+ }
+
+ // Get all recent observations
+ results, err := obsStore.GetAllRecentObservations(ctx, 100)
+ require.NoError(t, err)
+ assert.Len(t, results, 9) // 3 projects * 3 observations
+
+ // Verify they are in descending order by epoch
+ for i := 1; i < len(results); i++ {
+ assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch)
+ }
+
+ // Test with limit
+ results, err = obsStore.GetAllRecentObservations(ctx, 5)
+ require.NoError(t, err)
+ assert.Len(t, results, 5)
+}
diff --git a/internal/embedding/service_test.go b/internal/embedding/service_test.go
new file mode 100644
index 0000000..cd04351
--- /dev/null
+++ b/internal/embedding/service_test.go
@@ -0,0 +1,332 @@
+package embedding
+
+import (
+ "math"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestEmbeddingDim verifies the embedding dimension constant.
+func TestEmbeddingDim(t *testing.T) {
+ assert.Equal(t, 384, EmbeddingDim)
+}
+
+// TestNewService tests creating a new embedding service.
+func TestNewService(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ require.NotNil(t, svc)
+
+ defer svc.Close()
+
+ assert.NotNil(t, svc.tk)
+ assert.NotNil(t, svc.session)
+}
+
+// TestEmbed_SingleText tests embedding a single text.
+func TestEmbed_SingleText(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ embedding, err := svc.Embed("Hello, world!")
+ require.NoError(t, err)
+
+ assert.Len(t, embedding, EmbeddingDim)
+
+ // Verify non-zero embedding
+ var sum float32
+ for _, v := range embedding {
+ sum += v * v
+ }
+ assert.Greater(t, sum, float32(0), "Embedding should not be all zeros")
+}
+
+// TestEmbed_EmptyText tests embedding an empty string.
+func TestEmbed_EmptyText(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ embedding, err := svc.Embed("")
+ require.NoError(t, err)
+
+ assert.Len(t, embedding, EmbeddingDim)
+
+ // Empty text should return zero vector
+ for _, v := range embedding {
+ assert.Equal(t, float32(0), v)
+ }
+}
+
+// TestEmbed_SimilarTexts tests that similar texts produce similar embeddings.
+func TestEmbed_SimilarTexts(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ emb1, err := svc.Embed("The quick brown fox jumps over the lazy dog.")
+ require.NoError(t, err)
+
+ emb2, err := svc.Embed("A fast brown fox leaps over a sleepy dog.")
+ require.NoError(t, err)
+
+ emb3, err := svc.Embed("Go programming language concurrency patterns.")
+ require.NoError(t, err)
+
+ // Calculate cosine similarity
+ sim12 := cosineSimilarity(emb1, emb2)
+ sim13 := cosineSimilarity(emb1, emb3)
+
+ // Similar texts should have higher similarity
+ assert.Greater(t, sim12, sim13, "Similar sentences should have higher similarity than dissimilar ones")
+ assert.Greater(t, sim12, float64(0.7), "Similar sentences should have high similarity")
+}
+
+// TestEmbedBatch_MultipleTexts tests batch embedding.
+func TestEmbedBatch_MultipleTexts(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ texts := []string{
+ "First text about programming.",
+ "Second text about databases.",
+ "Third text about machine learning.",
+ }
+
+ embeddings, err := svc.EmbedBatch(texts)
+ require.NoError(t, err)
+
+ assert.Len(t, embeddings, len(texts))
+ for i, emb := range embeddings {
+ assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
+ }
+}
+
+// TestEmbedBatch_EmptySlice tests batch embedding with empty slice.
+func TestEmbedBatch_EmptySlice(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ embeddings, err := svc.EmbedBatch([]string{})
+ require.NoError(t, err)
+ assert.Nil(t, embeddings)
+}
+
+// TestEmbedBatch_WithEmptyTexts tests batch embedding with some empty texts.
+func TestEmbedBatch_WithEmptyTexts(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ texts := []string{
+ "Valid text one.",
+ "",
+ "Valid text two.",
+ "",
+ }
+
+ embeddings, err := svc.EmbedBatch(texts)
+ require.NoError(t, err)
+
+ assert.Len(t, embeddings, 4)
+
+ // Non-empty texts should have non-zero embeddings
+ var sum0 float32
+ for _, v := range embeddings[0] {
+ sum0 += v * v
+ }
+ assert.Greater(t, sum0, float32(0))
+
+ // Empty texts should have zero embeddings
+ for _, v := range embeddings[1] {
+ assert.Equal(t, float32(0), v)
+ }
+
+ var sum2 float32
+ for _, v := range embeddings[2] {
+ sum2 += v * v
+ }
+ assert.Greater(t, sum2, float32(0))
+
+ for _, v := range embeddings[3] {
+ assert.Equal(t, float32(0), v)
+ }
+}
+
+// TestEmbedBatch_AllEmpty tests batch embedding with all empty texts.
+func TestEmbedBatch_AllEmpty(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ texts := []string{"", "", ""}
+
+ embeddings, err := svc.EmbedBatch(texts)
+ require.NoError(t, err)
+
+ assert.Len(t, embeddings, 3)
+ for i, emb := range embeddings {
+ assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
+ for j, v := range emb {
+ assert.Equal(t, float32(0), v, "Embedding %d[%d] should be zero", i, j)
+ }
+ }
+}
+
+// TestEmbed_Concurrent tests concurrent embedding calls.
+func TestEmbed_Concurrent(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ var wg sync.WaitGroup
+ numGoroutines := 10
+
+ errors := make(chan error, numGoroutines)
+ embeddings := make([][]float32, numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+
+ text := "Test text for concurrent embedding test"
+ emb, err := svc.Embed(text)
+ if err != nil {
+ errors <- err
+ return
+ }
+ embeddings[idx] = emb
+ }(i)
+ }
+
+ wg.Wait()
+ close(errors)
+
+ for err := range errors {
+ t.Errorf("Concurrent embedding error: %v", err)
+ }
+
+ // All embeddings should be valid
+ for i, emb := range embeddings {
+ if emb != nil {
+ assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
+ }
+ }
+}
+
+// TestEmbed_SpecialCharacters tests embedding text with special characters.
+func TestEmbed_SpecialCharacters(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ texts := []string{
+ "Text with unicode: 你好世界 🎉",
+ "Text with newlines:\nLine 1\nLine 2",
+ "Text with tabs:\tColumn1\tColumn2",
+ "Text with quotes: \"quoted\" and 'single'",
+ "Text with code: func main() { fmt.Println(\"hello\") }",
+ }
+
+ for _, text := range texts {
+ t.Run(text[:20], func(t *testing.T) {
+ emb, err := svc.Embed(text)
+ require.NoError(t, err)
+ assert.Len(t, emb, EmbeddingDim)
+ })
+ }
+}
+
+// TestEmbed_LongText tests embedding long text.
+func TestEmbed_LongText(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ // Create a long text (tokenizer should truncate appropriately)
+ longText := ""
+ for i := 0; i < 100; i++ {
+ longText += "This is a sentence to make the text very long. "
+ }
+
+ emb, err := svc.Embed(longText)
+ require.NoError(t, err)
+ assert.Len(t, emb, EmbeddingDim)
+}
+
+// TestClose tests closing the service.
+func TestClose(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+
+ err = svc.Close()
+ require.NoError(t, err)
+
+ // Session should be nil after close
+ assert.Nil(t, svc.session)
+}
+
+// TestEmbedBatch_SingleItem tests batch embedding with single item.
+func TestEmbedBatch_SingleItem(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ texts := []string{"Single text for batch embedding."}
+
+ embeddings, err := svc.EmbedBatch(texts)
+ require.NoError(t, err)
+
+ assert.Len(t, embeddings, 1)
+ assert.Len(t, embeddings[0], EmbeddingDim)
+}
+
+// TestEmbed_Deterministic tests that embedding is deterministic.
+func TestEmbed_Deterministic(t *testing.T) {
+ svc, err := NewService()
+ require.NoError(t, err)
+ defer svc.Close()
+
+ text := "Test text for deterministic embedding."
+
+ emb1, err := svc.Embed(text)
+ require.NoError(t, err)
+
+ emb2, err := svc.Embed(text)
+ require.NoError(t, err)
+
+ // Same text should produce same embedding
+ for i := 0; i < EmbeddingDim; i++ {
+ assert.Equal(t, emb1[i], emb2[i], "Embedding should be deterministic at index %d", i)
+ }
+}
+
+// Helper function to calculate cosine similarity
+func cosineSimilarity(a, b []float32) float64 {
+ if len(a) != len(b) {
+ return 0
+ }
+
+ var dotProduct float64
+ var normA float64
+ var normB float64
+
+ for i := range a {
+ dotProduct += float64(a[i]) * float64(b[i])
+ normA += float64(a[i]) * float64(a[i])
+ normB += float64(b[i]) * float64(b[i])
+ }
+
+ if normA == 0 || normB == 0 {
+ return 0
+ }
+
+ return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
+}
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go
index 1b4294d..6afce82 100644
--- a/internal/mcp/server_test.go
+++ b/internal/mcp/server_test.go
@@ -597,3 +597,363 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
assert.True(t, hasType, "tool %s schema should have type", tool.Name)
}
}
+
+// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
+func TestHandleToolsCall_UnknownTool(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ req := &Request{
+ JSONRPC: "2.0",
+ ID: 1,
+ Method: "tools/call",
+ Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`),
+ }
+
+ resp := server.handleToolsCall(ctx, req)
+ require.NotNil(t, resp.Error)
+ assert.Equal(t, -32000, resp.Error.Code)
+ assert.Contains(t, resp.Error.Data, "unknown tool")
+}
+
+// TestCallTool_ToolNameRecognition tests that valid tool names are recognized (not "unknown tool").
+func TestCallTool_ToolNameRecognition(t *testing.T) {
+ // Note: This test verifies tool routing logic, not execution (which requires searchMgr)
+ // All valid tool names should be in the handleToolsList response
+ server := NewServer(nil, "1.0.0")
+
+ req := &Request{
+ JSONRPC: "2.0",
+ ID: 1,
+ Method: "tools/list",
+ }
+
+ resp := server.handleToolsList(req)
+ result := resp.Result.(map[string]any)
+ tools := result["tools"].([]Tool)
+
+ // Verify all expected tools are registered
+ expectedTools := map[string]bool{
+ "search": true,
+ "timeline": true,
+ "decisions": true,
+ "changes": true,
+ "how_it_works": true,
+ "find_by_concept": true,
+ "find_by_file": true,
+ "find_by_type": true,
+ "get_recent_context": true,
+ "get_context_timeline": true,
+ "get_timeline_by_query": true,
+ }
+
+ foundTools := make(map[string]bool)
+ for _, tool := range tools {
+ foundTools[tool.Name] = true
+ }
+
+ for name := range expectedTools {
+ assert.True(t, foundTools[name], "tool %s should be registered", name)
+ }
+}
+
+// TestRun_MultipleRequests tests Run with multiple sequential requests.
+func TestRun_MultipleRequests(t *testing.T) {
+ var stdout bytes.Buffer
+ req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
+ req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
+ stdin := strings.NewReader(req1 + "\n" + req2 + "\n")
+
+ server := &Server{
+ stdin: stdin,
+ stdout: &stdout,
+ version: "1.0.0",
+ }
+
+ err := server.Run(context.Background())
+ require.NoError(t, err)
+
+ output := stdout.String()
+ // Should contain responses for both requests
+ assert.Contains(t, output, `"id":1`)
+ assert.Contains(t, output, `"id":2`)
+}
+
+// TestHandleTimeline_Defaults tests timeline default values.
+func TestHandleTimeline_Defaults(t *testing.T) {
+ // Test that handleTimeline sets default before/after values
+ params := TimelineParams{
+ AnchorID: 0,
+ Query: "",
+ Before: 0,
+ After: 0,
+ }
+
+ // Simulate the default value assignment from handleTimeline
+ if params.Before <= 0 {
+ params.Before = 10
+ }
+ if params.After <= 0 {
+ params.After = 10
+ }
+
+ assert.Equal(t, 10, params.Before)
+ assert.Equal(t, 10, params.After)
+}
+
+// TestTimelineParams_Complete tests complete TimelineParams parsing.
+func TestTimelineParams_Complete(t *testing.T) {
+ input := `{
+ "anchor_id": 100,
+ "query": "test query",
+ "before": 5,
+ "after": 15,
+ "project": "my-project",
+ "obs_type": "bugfix",
+ "concepts": "security,auth",
+ "files": "main.go,handler.go",
+ "dateStart": 1700000000000,
+ "dateEnd": 1700100000000,
+ "format": "full"
+ }`
+
+ var params TimelineParams
+ err := json.Unmarshal([]byte(input), ¶ms)
+ require.NoError(t, err)
+
+ assert.Equal(t, int64(100), params.AnchorID)
+ assert.Equal(t, "test query", params.Query)
+ assert.Equal(t, 5, params.Before)
+ assert.Equal(t, 15, params.After)
+ assert.Equal(t, "my-project", params.Project)
+ assert.Equal(t, "bugfix", params.ObsType)
+ assert.Equal(t, "security,auth", params.Concepts)
+ assert.Equal(t, "main.go,handler.go", params.Files)
+ assert.Equal(t, int64(1700000000000), params.DateStart)
+ assert.Equal(t, int64(1700100000000), params.DateEnd)
+ assert.Equal(t, "full", params.Format)
+}
+
+// TestServerStdinStdoutConfig tests that server stdin/stdout can be configured.
+func TestServerStdinStdoutConfig(t *testing.T) {
+ var stdout bytes.Buffer
+ var stdin bytes.Buffer
+
+ server := &Server{
+ stdin: &stdin,
+ stdout: &stdout,
+ version: "test-version",
+ }
+
+ assert.Equal(t, &stdin, server.stdin)
+ assert.Equal(t, &stdout, server.stdout)
+ assert.Equal(t, "test-version", server.version)
+}
+
+// TestResponseIDTypes tests that response IDs can be various types.
+func TestResponseIDTypes(t *testing.T) {
+ tests := []struct {
+ name string
+ id any
+ }{
+ {"integer id", 1},
+ {"string id", "abc-123"},
+ {"float id", 1.5},
+ {"null id", nil},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var buf bytes.Buffer
+ server := &Server{stdout: &buf}
+
+ resp := &Response{
+ JSONRPC: "2.0",
+ ID: tt.id,
+ Result: "ok",
+ }
+
+ server.sendResponse(resp)
+ output := buf.String()
+ assert.Contains(t, output, `"jsonrpc":"2.0"`)
+ })
+ }
+}
+
+// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
+func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ // Empty query should error
+ _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`))
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "query is required")
+}
+
+// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
+func TestHandleTimeline_InvalidJSON(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid timeline params")
+}
+
+// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
+func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid timeline params")
+}
+
+// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
+func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ // No anchor_id and no query should return empty result
+ result, err := server.handleTimeline(ctx, json.RawMessage(`{}`))
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.Empty(t, result.Results)
+}
+
+// TestHandleTimeline_WithDefaults tests timeline default values are applied.
+func TestHandleTimeline_WithDefaults(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+ ctx := context.Background()
+
+ // With anchor_id but no before/after, defaults should be applied
+ // However, since searchMgr is nil, this will fail after defaults are applied
+ result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`))
+ // Should return empty result since anchor_id is 0
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+ assert.Empty(t, result.Results)
+}
+
+// TestServerFields tests Server struct fields.
+func TestServerFields(t *testing.T) {
+ server := NewServer(nil, "2.0.0")
+
+ assert.Equal(t, "2.0.0", server.version)
+ assert.Nil(t, server.searchMgr)
+ assert.NotNil(t, server.stdin)
+ assert.NotNil(t, server.stdout)
+}
+
+// TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID.
+func TestRequestUnmarshalWithNullID(t *testing.T) {
+ input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}`
+
+ var req Request
+ err := json.Unmarshal([]byte(input), &req)
+ require.NoError(t, err)
+ assert.Equal(t, "2.0", req.JSONRPC)
+ assert.Nil(t, req.ID)
+ assert.Equal(t, "initialize", req.Method)
+}
+
+// TestResponseWithNullError tests Response without error.
+func TestResponseWithNullError(t *testing.T) {
+ resp := Response{
+ JSONRPC: "2.0",
+ ID: 1,
+ Result: "success",
+ Error: nil,
+ }
+
+ data, err := json.Marshal(resp)
+ require.NoError(t, err)
+ assert.Contains(t, string(data), `"result":"success"`)
+ assert.NotContains(t, string(data), `"error"`)
+}
+
+// TestErrorWithNilData tests Error without data.
+func TestErrorWithNilData(t *testing.T) {
+ err := Error{
+ Code: -32600,
+ Message: "Invalid Request",
+ Data: nil,
+ }
+
+ data, errMarshal := json.Marshal(err)
+ require.NoError(t, errMarshal)
+ assert.Contains(t, string(data), `"code":-32600`)
+ assert.Contains(t, string(data), `"message":"Invalid Request"`)
+ assert.NotContains(t, string(data), `"data"`)
+}
+
+// TestToolInputSchema tests that tool input schemas have required fields.
+func TestToolInputSchema(t *testing.T) {
+ server := NewServer(nil, "1.0.0")
+
+ req := &Request{
+ JSONRPC: "2.0",
+ ID: 1,
+ Method: "tools/list",
+ }
+
+ resp := server.handleToolsList(req)
+ result := resp.Result.(map[string]any)
+ tools := result["tools"].([]Tool)
+
+ for _, tool := range tools {
+ schema := tool.InputSchema
+ schemaType, ok := schema["type"]
+ assert.True(t, ok, "tool %s schema should have type", tool.Name)
+ assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name)
+
+ // All tools should have properties
+ _, hasProperties := schema["properties"]
+ assert.True(t, hasProperties, "tool %s should have properties", tool.Name)
+ }
+}
+
+// TestRunMixedRequests tests Run with mixed valid and invalid requests.
+func TestRunMixedRequests(t *testing.T) {
+ var stdout bytes.Buffer
+ req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
+ req2 := `invalid json`
+ req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`
+ stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n")
+
+ server := &Server{
+ stdin: stdin,
+ stdout: &stdout,
+ version: "1.0.0",
+ }
+
+ err := server.Run(context.Background())
+ require.NoError(t, err)
+
+ output := stdout.String()
+ // Should have responses for all three requests
+ assert.Contains(t, output, `"id":1`)
+ assert.Contains(t, output, `"error"`) // Parse error for invalid json
+ assert.Contains(t, output, `"id":3`)
+}
+
+// TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments.
+func TestToolCallParamsWithComplexArgs(t *testing.T) {
+ input := `{
+ "name": "search",
+ "arguments": {
+ "query": "authentication bug",
+ "project": "my-project",
+ "limit": 10,
+ "type": "observations"
+ }
+ }`
+
+ var params ToolCallParams
+ err := json.Unmarshal([]byte(input), ¶ms)
+ require.NoError(t, err)
+ assert.Equal(t, "search", params.Name)
+ assert.NotEmpty(t, params.Arguments)
+}
diff --git a/internal/search/integration_test.go b/internal/search/integration_test.go
new file mode 100644
index 0000000..a1e0358
--- /dev/null
+++ b/internal/search/integration_test.go
@@ -0,0 +1,646 @@
+// Package search provides unified search capabilities for claude-mnemonic.
+package search
+
+import (
+ "context"
+ "database/sql"
+ "os"
+ "testing"
+
+ "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
+ "github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
+
+ // Import sqlite driver
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// Ensure context is used (for later tests)
+var _ = context.Background
+
+// hasFTS5 checks if FTS5 is available in the SQLite build.
+func hasFTS5(t *testing.T) bool {
+ t.Helper()
+
+ tmpDir, err := os.MkdirTemp("", "fts5-check-*")
+ if err != nil {
+ return false
+ }
+ defer func() { _ = os.RemoveAll(tmpDir) }()
+
+ dbPath := tmpDir + "/check.db"
+ db, err := sql.Open("sqlite3", dbPath)
+ if err != nil {
+ return false
+ }
+ defer func() { _ = db.Close() }()
+
+ _, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
+ if err != nil {
+ return false
+ }
+ _, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
+ return true
+}
+
+// testStore creates a sqlite.Store with a temporary database for testing.
+func testStore(t *testing.T) (*sqlite.Store, func()) {
+ t.Helper()
+
+ if !hasFTS5(t) {
+ t.Skip("FTS5 not available in this SQLite build")
+ }
+
+ tmpDir, err := os.MkdirTemp("", "search-integration-test-*")
+ require.NoError(t, err)
+
+ dbPath := tmpDir + "/test.db"
+
+ store, err := sqlite.NewStore(sqlite.StoreConfig{
+ Path: dbPath,
+ MaxConns: 1,
+ WALMode: true,
+ })
+ require.NoError(t, err)
+
+ cleanup := func() {
+ _ = store.Close()
+ _ = os.RemoveAll(tmpDir)
+ }
+
+ return store, cleanup
+}
+
+// SearchIntegrationSuite tests search with real SQLite stores.
+type SearchIntegrationSuite struct {
+ suite.Suite
+ store *sqlite.Store
+ cleanup func()
+ manager *Manager
+ obsStore *sqlite.ObservationStore
+ sumStore *sqlite.SummaryStore
+ prmStore *sqlite.PromptStore
+}
+
+func (s *SearchIntegrationSuite) SetupTest() {
+ if !hasFTS5(s.T()) {
+ s.T().Skip("FTS5 not available in this SQLite build")
+ }
+
+ s.store, s.cleanup = testStore(s.T())
+
+ // Create real stores backed by SQLite
+ s.obsStore = sqlite.NewObservationStore(s.store)
+ s.sumStore = sqlite.NewSummaryStore(s.store)
+ s.prmStore = sqlite.NewPromptStore(s.store)
+
+ // Create search manager with real stores (no vector client for now)
+ s.manager = NewManager(s.obsStore, s.sumStore, s.prmStore, nil)
+}
+
+func (s *SearchIntegrationSuite) TearDownTest() {
+ if s.cleanup != nil {
+ s.cleanup()
+ }
+}
+
+func TestSearchIntegrationSuite(t *testing.T) {
+ suite.Run(t, new(SearchIntegrationSuite))
+}
+
+// seedObservations inserts test observations into the database.
+func (s *SearchIntegrationSuite) seedObservations(ctx context.Context) []int64 {
+ var ids []int64
+
+ // Observation 1: Authentication bug fix
+ obs1 := &models.ParsedObservation{
+ Type: models.ObsTypeBugfix,
+ Scope: models.ScopeProject,
+ Title: "Fixed Authentication Bug",
+ Narrative: "Resolved JWT token validation issue that caused intermittent login failures",
+ Concepts: []string{"authentication", "jwt", "security"},
+ FilesRead: []string{"auth/handler.go", "auth/jwt.go"},
+ }
+ id1, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs1, 1, 100)
+ s.Require().NoError(err)
+ ids = append(ids, id1)
+
+ // Observation 2: Database optimization decision
+ obs2 := &models.ParsedObservation{
+ Type: models.ObsTypeDecision,
+ Scope: models.ScopeProject,
+ Title: "Database Query Optimization Decision",
+ Narrative: "Decided to add indexes on user_id and created_at columns for better performance",
+ Concepts: []string{"database", "performance", "decision"},
+ FilesRead: []string{"db/migrations/001.sql"},
+ }
+ id2, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs2, 2, 150)
+ s.Require().NoError(err)
+ ids = append(ids, id2)
+
+ // Observation 3: Global best practice
+ obs3 := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Scope: models.ScopeGlobal,
+ Title: "Error Handling Best Practice",
+ Narrative: "Use wrapped errors with context for better debugging: errors.Wrap(err, context)",
+ Concepts: []string{"best-practice", "errors", "patterns"},
+ FilesRead: []string{"pkg/errors/errors.go"},
+ }
+ id3, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-2", "other-project", obs3, 1, 80)
+ s.Require().NoError(err)
+ ids = append(ids, id3)
+
+ // Observation 4: Code change/refactoring
+ obs4 := &models.ParsedObservation{
+ Type: models.ObsTypeChange,
+ Scope: models.ScopeProject,
+ Title: "Refactored User Service",
+ Narrative: "Changed user service to use repository pattern, modified interfaces for better testability",
+ Concepts: []string{"refactoring", "architecture"},
+ FilesModified: []string{"services/user.go", "services/user_test.go"},
+ }
+ id4, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs4, 3, 200)
+ s.Require().NoError(err)
+ ids = append(ids, id4)
+
+ return ids
+}
+
+// seedSummaries inserts test session summaries into the database.
+func (s *SearchIntegrationSuite) seedSummaries(ctx context.Context) []int64 {
+ var ids []int64
+
+ // Summary 1
+ sum1 := &models.ParsedSummary{
+ Request: "Fix authentication bug",
+ Investigated: "JWT token validation and session handling",
+ Learned: "JWT validation requires algorithm check to prevent alg:none attacks",
+ Completed: "Fixed JWT validation, added tests",
+ NextSteps: "Review other security endpoints",
+ }
+ id1, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum1, 1, 100)
+ s.Require().NoError(err)
+ ids = append(ids, id1)
+
+ // Summary 2
+ sum2 := &models.ParsedSummary{
+ Request: "Optimize database queries",
+ Investigated: "Query execution plans and index usage",
+ Learned: "Composite indexes work better for range queries",
+ Completed: "Added indexes, verified performance improvement",
+ NextSteps: "Monitor query times in production",
+ }
+ id2, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum2, 2, 150)
+ s.Require().NoError(err)
+ ids = append(ids, id2)
+
+ return ids
+}
+
+// TestFilterSearch_WithRealStores tests filterSearch with seeded data.
+func (s *SearchIntegrationSuite) TestFilterSearch_WithRealStores() {
+ ctx := context.Background()
+
+ // Seed test data
+ obsIDs := s.seedObservations(ctx)
+ sumIDs := s.seedSummaries(ctx)
+ s.Require().Len(obsIDs, 4)
+ s.Require().Len(sumIDs, 2)
+
+ // Test filter search for observations only
+ result, err := s.manager.filterSearch(ctx, SearchParams{
+ Project: "test-project",
+ Type: "observations",
+ Limit: 10,
+ Format: "full",
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Should return project observations + global observation (4 total: 3 project + 1 global)
+ s.GreaterOrEqual(len(result.Results), 3)
+
+ // Verify result types
+ for _, r := range result.Results {
+ s.Equal("observation", r.Type)
+ }
+}
+
+// TestFilterSearch_SessionsOnly tests filterSearch for sessions.
+func (s *SearchIntegrationSuite) TestFilterSearch_SessionsOnly() {
+ ctx := context.Background()
+
+ // Seed test data
+ _ = s.seedObservations(ctx)
+ sumIDs := s.seedSummaries(ctx)
+ s.Require().Len(sumIDs, 2)
+
+ // Test filter search for sessions only
+ result, err := s.manager.filterSearch(ctx, SearchParams{
+ Project: "test-project",
+ Type: "sessions",
+ Limit: 10,
+ Format: "full",
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Should return 2 summaries
+ s.Len(result.Results, 2)
+
+ // Verify result types
+ for _, r := range result.Results {
+ s.Equal("session", r.Type)
+ s.NotEmpty(r.Title) // Title should be populated from Request
+ }
+}
+
+// TestFilterSearch_AllTypes tests filterSearch for all types.
+func (s *SearchIntegrationSuite) TestFilterSearch_AllTypes() {
+ ctx := context.Background()
+
+ // Seed test data
+ obsIDs := s.seedObservations(ctx)
+ sumIDs := s.seedSummaries(ctx)
+ s.Require().Len(obsIDs, 4)
+ s.Require().Len(sumIDs, 2)
+
+ // Test filter search for all types (Type = "")
+ result, err := s.manager.filterSearch(ctx, SearchParams{
+ Project: "test-project",
+ Type: "", // All types
+ Limit: 20,
+ Format: "full",
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Should return both observations and sessions
+ hasObservations := false
+ hasSessions := false
+ for _, r := range result.Results {
+ if r.Type == "observation" {
+ hasObservations = true
+ }
+ if r.Type == "session" {
+ hasSessions = true
+ }
+ }
+ s.True(hasObservations, "Should have observation results")
+ s.True(hasSessions, "Should have session results")
+}
+
+// TestUnifiedSearch_DefaultLimit tests UnifiedSearch with default limit.
+func (s *SearchIntegrationSuite) TestUnifiedSearch_DefaultLimit() {
+ ctx := context.Background()
+
+ // Seed test data
+ s.seedObservations(ctx)
+ s.seedSummaries(ctx)
+
+ // Test with no limit specified (should default to 20)
+ result, err := s.manager.UnifiedSearch(ctx, SearchParams{
+ Project: "test-project",
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+ s.LessOrEqual(len(result.Results), 20)
+}
+
+// TestUnifiedSearch_LimitCapping tests UnifiedSearch limit capping.
+func (s *SearchIntegrationSuite) TestUnifiedSearch_LimitCapping() {
+ ctx := context.Background()
+
+ // Seed test data
+ s.seedObservations(ctx)
+ s.seedSummaries(ctx)
+
+ // Test with limit > 100 (should be capped to 100)
+ result, err := s.manager.UnifiedSearch(ctx, SearchParams{
+ Project: "test-project",
+ Limit: 500,
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+ s.LessOrEqual(len(result.Results), 100)
+}
+
+// TestDecisions_WithRealStores tests the Decisions method falls back to filterSearch.
+func (s *SearchIntegrationSuite) TestDecisions_WithRealStores() {
+ ctx := context.Background()
+
+ // Seed test data
+ s.seedObservations(ctx)
+
+ // Test Decisions search (without vector client, falls back to filterSearch)
+ result, err := s.manager.Decisions(ctx, SearchParams{
+ Project: "test-project",
+ Query: "database",
+ Limit: 10,
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Without vector client, falls back to filterSearch which returns observations
+ // All results should be observations (type is forced to "observations" in Decisions)
+ for _, r := range result.Results {
+ s.Equal("observation", r.Type)
+ }
+}
+
+// TestChanges_WithRealStores tests the Changes method falls back to filterSearch.
+func (s *SearchIntegrationSuite) TestChanges_WithRealStores() {
+ ctx := context.Background()
+
+ // Seed test data
+ s.seedObservations(ctx)
+
+ // Test Changes search (without vector client, falls back to filterSearch)
+ result, err := s.manager.Changes(ctx, SearchParams{
+ Project: "test-project",
+ Query: "user service",
+ Limit: 10,
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Without vector client, falls back to filterSearch which returns observations
+ // All results should be observations (type is forced to "observations" in Changes)
+ for _, r := range result.Results {
+ s.Equal("observation", r.Type)
+ }
+}
+
+// TestHowItWorks_WithRealStores tests the HowItWorks method falls back to filterSearch.
+func (s *SearchIntegrationSuite) TestHowItWorks_WithRealStores() {
+ ctx := context.Background()
+
+ // Seed test data
+ s.seedObservations(ctx)
+
+ // Test HowItWorks search (without vector client, falls back to filterSearch)
+ result, err := s.manager.HowItWorks(ctx, SearchParams{
+ Project: "test-project",
+ Query: "authentication",
+ Limit: 10,
+ })
+ s.Require().NoError(err)
+ s.NotNil(result)
+
+ // Without vector client, falls back to filterSearch which returns observations
+ // All results should be observations (type is forced to "observations" in HowItWorks)
+ for _, r := range result.Results {
+ s.Equal("observation", r.Type)
+ }
+}
+
+// TestObservationToResult tests observation to result conversion with full format.
+func (s *SearchIntegrationSuite) TestObservationToResult_FullFormat() {
+ ctx := context.Background()
+
+ // Insert single observation
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeDiscovery,
+ Scope: models.ScopeProject,
+ Title: "Test Title",
+ Narrative: "Detailed narrative content for testing",
+ Concepts: []string{"testing", "content"},
+ }
+ id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
+ s.Require().NoError(err)
+
+ // Retrieve and convert
+ retrieved, err := s.obsStore.GetObservationByID(ctx, id)
+ s.Require().NoError(err)
+ s.Require().NotNil(retrieved)
+
+ result := s.manager.observationToResult(retrieved, "full")
+
+ s.Equal("observation", result.Type)
+ s.Equal(id, result.ID)
+ s.Equal("Test Title", result.Title)
+ s.Equal("Detailed narrative content for testing", result.Content)
+ s.Equal("test-project", result.Project)
+ s.Equal("project", result.Scope)
+ s.NotNil(result.Metadata)
+ s.Equal("discovery", result.Metadata["obs_type"])
+}
+
+// TestObservationToResult_IndexFormat tests index format (no content).
+func (s *SearchIntegrationSuite) TestObservationToResult_IndexFormat() {
+ ctx := context.Background()
+
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeBugfix,
+ Scope: models.ScopeGlobal,
+ Title: "Bug Fix Title",
+ Narrative: "This should not appear in index format",
+ }
+ id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
+ s.Require().NoError(err)
+
+ retrieved, err := s.obsStore.GetObservationByID(ctx, id)
+ s.Require().NoError(err)
+
+ result := s.manager.observationToResult(retrieved, "index")
+
+ s.Equal("observation", result.Type)
+ s.Equal("Bug Fix Title", result.Title)
+ s.Empty(result.Content, "Index format should not include content")
+ s.Equal("global", result.Scope)
+}
+
+// TestSummaryToResult_FullFormat tests summary to result conversion.
+func (s *SearchIntegrationSuite) TestSummaryToResult_FullFormat() {
+ ctx := context.Background()
+
+ sum := &models.ParsedSummary{
+ Request: "Implement new feature",
+ Learned: "Learned important lessons about testing",
+ }
+ id, _, err := s.sumStore.StoreSummary(ctx, "sdk-test", "test-project", sum, 1, 50)
+ s.Require().NoError(err)
+
+ // Retrieve via GetRecentSummaries since there's no GetByID
+ summaries, err := s.sumStore.GetRecentSummaries(ctx, "test-project", 10)
+ s.Require().NoError(err)
+ s.Require().NotEmpty(summaries)
+
+ var retrieved *models.SessionSummary
+ for _, s := range summaries {
+ if s.ID == id {
+ retrieved = s
+ break
+ }
+ }
+ s.Require().NotNil(retrieved)
+
+ result := s.manager.summaryToResult(retrieved, "full")
+
+ s.Equal("session", result.Type)
+ s.Equal(id, result.ID)
+ s.Contains(result.Title, "Implement new feature")
+ s.Equal("Learned important lessons about testing", result.Content)
+ s.Equal("test-project", result.Project)
+}
+
+// TestPromptToResult_FullFormat tests prompt to result conversion.
+func (s *SearchIntegrationSuite) TestPromptToResult_FullFormat() {
+ // First create a session
+ ctx := context.Background()
+ sessionStore := sqlite.NewSessionStore(s.store)
+ _, err := sessionStore.CreateSDKSession(ctx, "sdk-prompt-test", "test-project", "initial prompt")
+ s.Require().NoError(err)
+
+ // Save a user prompt
+ promptID, err := s.prmStore.SaveUserPromptWithMatches(ctx, "sdk-prompt-test", 1, "Help me fix this authentication bug", 3)
+ s.Require().NoError(err)
+
+ // Retrieve prompts
+ prompts, err := s.prmStore.GetPromptsByIDs(ctx, []int64{promptID}, "date_desc", 10)
+ s.Require().NoError(err)
+ s.Require().NotEmpty(prompts)
+
+ result := s.manager.promptToResult(prompts[0], "full")
+
+ s.Equal("prompt", result.Type)
+ s.Equal(promptID, result.ID)
+ s.Contains(result.Title, "Help me fix")
+ s.Equal("Help me fix this authentication bug", result.Content)
+}
+
+// TestTruncate_TableDriven tests truncation with various inputs.
+func TestTruncate_TableDriven(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ expected string
+ }{
+ {"short_string", "hello", 10, "hello"},
+ {"exact_length", "hello", 5, "hello"},
+ {"long_string", "hello world", 5, "hello..."},
+ {"empty_string", "", 10, ""},
+ {"whitespace_only", " ", 10, ""},
+ {"with_leading_space", " hello ", 10, "hello"},
+ {"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncate(tt.input, tt.maxLen)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestManagerWithNilStores tests that Manager handles nil stores gracefully.
+func TestManagerWithNilStores(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+ assert.NotNil(t, m)
+ assert.Nil(t, m.observationStore)
+ assert.Nil(t, m.summaryStore)
+ assert.Nil(t, m.promptStore)
+ assert.Nil(t, m.vectorClient)
+}
+
+// TestSearchResultMetadataFields tests all metadata fields with real data.
+func (s *SearchIntegrationSuite) TestSearchResultMetadataFields() {
+ ctx := context.Background()
+
+ obs := &models.ParsedObservation{
+ Type: models.ObsTypeDecision,
+ Scope: models.ScopeGlobal,
+ Title: "Architecture Decision",
+ Concepts: []string{"auth", "security"},
+ FilesRead: []string{"handler.go", "auth.go"},
+ }
+ id, _, err := s.obsStore.StoreObservation(ctx, "sdk-meta-test", "test-project", obs, 1, 50)
+ s.Require().NoError(err)
+
+ retrieved, err := s.obsStore.GetObservationByID(ctx, id)
+ s.Require().NoError(err)
+
+ result := s.manager.observationToResult(retrieved, "full")
+
+ // Check metadata fields
+ s.NotNil(result.Metadata)
+ s.Equal("decision", result.Metadata["obs_type"])
+ s.Equal("global", result.Metadata["scope"])
+ s.Equal("global", result.Scope)
+}
+
+// TestObservationToResult_AllFormats tests different format options.
+func TestObservationToResult_AllFormats(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ obs := &models.Observation{
+ ID: 1,
+ Project: "test",
+ Type: models.ObsTypeBugfix,
+ Scope: models.ScopeProject,
+ Title: sql.NullString{String: "Bug Fix Title", Valid: true},
+ Narrative: sql.NullString{String: "Detailed bug fix narrative", Valid: true},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ tests := []struct {
+ name string
+ format string
+ expectContent bool
+ }{
+ {"full_format", "full", true},
+ {"index_format", "index", false},
+ {"empty_format", "", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := m.observationToResult(obs, tt.format)
+ assert.Equal(t, "observation", result.Type)
+ assert.Equal(t, int64(1), result.ID)
+ if tt.expectContent {
+ assert.NotEmpty(t, result.Content)
+ }
+ })
+ }
+}
+
+// TestSummaryToResult_AllFormats tests different format options for summaries.
+func TestSummaryToResult_AllFormats(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ summary := &models.SessionSummary{
+ ID: 1,
+ Project: "test",
+ Request: sql.NullString{String: "Test request", Valid: true},
+ Learned: sql.NullString{String: "Test learned", Valid: true},
+ Completed: sql.NullString{String: "Test completed", Valid: true},
+ NextSteps: sql.NullString{String: "Test next steps", Valid: true},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ tests := []struct {
+ name string
+ format string
+ expectContent bool
+ }{
+ {"full_format", "full", true},
+ {"index_format", "index", false},
+ {"empty_format", "", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := m.summaryToResult(summary, tt.format)
+ assert.Equal(t, "session", result.Type)
+ assert.Equal(t, int64(1), result.ID)
+ if tt.expectContent {
+ assert.NotEmpty(t, result.Content)
+ }
+ })
+ }
+}
diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go
index 427469e..bdd077b 100644
--- a/internal/search/manager_test.go
+++ b/internal/search/manager_test.go
@@ -4,6 +4,7 @@ package search
import (
"database/sql"
"testing"
+ "time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
@@ -594,3 +595,482 @@ func TestSearchTypeMapping(t *testing.T) {
})
}
}
+
+// TestFilterSearchWithObservations tests filter search when observations exist.
+func TestFilterSearchWithObservations(t *testing.T) {
+ // Create mock observation
+ obs := &models.Observation{
+ ID: 1,
+ Project: "test-project",
+ Type: models.ObsTypeDiscovery,
+ Scope: models.ScopeProject,
+ Title: sql.NullString{String: "Test Title", Valid: true},
+ Narrative: sql.NullString{String: "Test narrative content", Valid: true},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ m := NewManager(nil, nil, nil, nil)
+ result := m.observationToResult(obs, "full")
+
+ assert.Equal(t, "observation", result.Type)
+ assert.Equal(t, int64(1), result.ID)
+ assert.Equal(t, "Test Title", result.Title)
+ assert.Equal(t, "Test narrative content", result.Content)
+ assert.Equal(t, "test-project", result.Project)
+ assert.Equal(t, "project", result.Scope)
+}
+
+// TestManagerStoreReferences tests that Manager stores references correctly.
+func TestManagerStoreReferences(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ assert.Nil(t, m.observationStore)
+ assert.Nil(t, m.summaryStore)
+ assert.Nil(t, m.promptStore)
+ assert.Nil(t, m.vectorClient)
+}
+
+// TestObservationToResultWithMetadata tests metadata inclusion in results.
+func TestObservationToResultWithMetadata(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ tests := []struct {
+ name string
+ obsType models.ObservationType
+ scope models.ObservationScope
+ }{
+ {"bugfix_project", models.ObsTypeBugfix, models.ScopeProject},
+ {"feature_global", models.ObsTypeFeature, models.ScopeGlobal},
+ {"discovery_project", models.ObsTypeDiscovery, models.ScopeProject},
+ {"decision_global", models.ObsTypeDecision, models.ScopeGlobal},
+ {"refactor_project", models.ObsTypeRefactor, models.ScopeProject},
+ {"change_global", models.ObsTypeChange, models.ScopeGlobal},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ obs := &models.Observation{
+ ID: 1,
+ Project: "test-project",
+ Type: tt.obsType,
+ Scope: tt.scope,
+ Title: sql.NullString{String: "Title", Valid: true},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ result := m.observationToResult(obs, "full")
+
+ assert.Equal(t, string(tt.obsType), result.Metadata["obs_type"])
+ assert.Equal(t, string(tt.scope), result.Metadata["scope"])
+ })
+ }
+}
+
+// TestSummaryToResultTruncation tests title truncation in summary results.
+func TestSummaryToResultTruncation(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ tests := []struct {
+ name string
+ request string
+ expectedLen int
+ shouldTrunc bool
+ }{
+ {"short_title", "Short request", 13, false},
+ {"exact_100", string(make([]byte, 100)), 103, true}, // 100 + "..."
+ {"over_100", string(make([]byte, 150)), 103, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ summary := &models.SessionSummary{
+ ID: 1,
+ Project: "test-project",
+ Request: sql.NullString{String: tt.request, Valid: true},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ result := m.summaryToResult(summary, "full")
+
+ if tt.shouldTrunc {
+ assert.LessOrEqual(t, len(result.Title), tt.expectedLen)
+ assert.True(t, len(result.Title) <= 103) // max 100 + "..."
+ } else {
+ assert.Equal(t, tt.request, result.Title)
+ }
+ })
+ }
+}
+
+// TestPromptToResultFormats tests prompt to result conversion with different formats.
+func TestPromptToResultFormats(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ prompt := &models.UserPromptWithSession{
+ UserPrompt: models.UserPrompt{
+ ID: 123,
+ PromptText: "What is the meaning of life?",
+ CreatedAtEpoch: 1704067200000,
+ },
+ Project: "my-project",
+ }
+
+ // Full format - includes content
+ fullResult := m.promptToResult(prompt, "full")
+ assert.Equal(t, "What is the meaning of life?", fullResult.Content)
+
+ // Index format - no content
+ indexResult := m.promptToResult(prompt, "index")
+ assert.Equal(t, "", indexResult.Content)
+
+ // Both should have same title
+ assert.Equal(t, fullResult.Title, indexResult.Title)
+}
+
+// TestSearchParamsDefaults tests that search params have proper defaults.
+func TestSearchParamsDefaults(t *testing.T) {
+ tests := []struct {
+ name string
+ initialLimit int
+ initialOrder string
+ expectedLimit int
+ expectedOrder string
+ }{
+ {"zero_limit", 0, "", 20, "date_desc"},
+ {"negative_limit", -5, "", 20, "date_desc"},
+ {"over_100_limit", 150, "", 100, "date_desc"},
+ {"valid_limit_50", 50, "relevance", 50, "relevance"},
+ {"custom_order", 30, "date_asc", 30, "date_asc"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "project",
+ Limit: tt.initialLimit,
+ OrderBy: tt.initialOrder,
+ }
+
+ // Simulate the normalization that happens in UnifiedSearch
+ if params.Limit <= 0 {
+ params.Limit = 20
+ }
+ if params.Limit > 100 {
+ params.Limit = 100
+ }
+ if params.OrderBy == "" {
+ params.OrderBy = "date_desc"
+ }
+
+ assert.Equal(t, tt.expectedLimit, params.Limit)
+ assert.Equal(t, tt.expectedOrder, params.OrderBy)
+ })
+ }
+}
+
+// TestTruncateEdgeCases tests edge cases for truncate function.
+func TestTruncateEdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ expected string
+ }{
+ // Unicode strings - uses byte length so ensure maxLen accommodates full string
+ {"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"},
+ {"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"},
+ // ASCII truncation
+ {"ascii_truncate", "Hello World", 5, "Hello..."},
+ {"only_whitespace", " ", 10, ""},
+ {"tabs_and_newlines", "\t\n \t", 10, ""},
+ {"newlines_with_content", "\n\nhello\n\n", 10, "hello"},
+ {"zero_max_len", "hello", 0, "..."},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncate(tt.input, tt.maxLen)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestUnifiedSearchResultEmpty tests empty UnifiedSearchResult.
+func TestUnifiedSearchResultEmpty(t *testing.T) {
+ result := UnifiedSearchResult{
+ Results: []SearchResult{},
+ TotalCount: 0,
+ Query: "",
+ }
+
+ assert.Empty(t, result.Results)
+ assert.Equal(t, 0, result.TotalCount)
+ assert.Equal(t, "", result.Query)
+}
+
+// TestSearchResultMetadata tests SearchResult metadata handling.
+func TestSearchResultMetadata(t *testing.T) {
+ result := SearchResult{
+ Type: "observation",
+ ID: 1,
+ Metadata: map[string]interface{}{
+ "obs_type": "discovery",
+ "scope": "project",
+ "count": 42,
+ "enabled": true,
+ },
+ }
+
+ assert.Equal(t, "discovery", result.Metadata["obs_type"])
+ assert.Equal(t, "project", result.Metadata["scope"])
+ assert.Equal(t, 42, result.Metadata["count"])
+ assert.Equal(t, true, result.Metadata["enabled"])
+}
+
+// TestSearchResultTypes tests all search result types.
+func TestSearchResultTypes(t *testing.T) {
+ types := []string{"observation", "session", "prompt"}
+
+ for _, typ := range types {
+ t.Run(typ, func(t *testing.T) {
+ result := SearchResult{
+ Type: typ,
+ ID: 1,
+ Project: "test",
+ CreatedAt: time.Now().UnixMilli(),
+ }
+ assert.Equal(t, typ, result.Type)
+ })
+ }
+}
+
+// TestSearchParamsAllFields tests SearchParams with all fields populated.
+func TestSearchParamsAllFields(t *testing.T) {
+ params := SearchParams{
+ Query: "authentication bug",
+ Type: "observations",
+ Project: "my-project",
+ ObsType: "bugfix",
+ Concepts: "security,auth",
+ Files: "handler.go,auth.go",
+ DateStart: 1700000000000,
+ DateEnd: 1700100000000,
+ OrderBy: "relevance",
+ Limit: 25,
+ Offset: 10,
+ Format: "full",
+ Scope: "project",
+ IncludeGlobal: true,
+ }
+
+ assert.Equal(t, "authentication bug", params.Query)
+ assert.Equal(t, "observations", params.Type)
+ assert.Equal(t, "my-project", params.Project)
+ assert.Equal(t, "bugfix", params.ObsType)
+ assert.Equal(t, "security,auth", params.Concepts)
+ assert.Equal(t, "handler.go,auth.go", params.Files)
+ assert.Equal(t, int64(1700000000000), params.DateStart)
+ assert.Equal(t, int64(1700100000000), params.DateEnd)
+ assert.Equal(t, "relevance", params.OrderBy)
+ assert.Equal(t, 25, params.Limit)
+ assert.Equal(t, 10, params.Offset)
+ assert.Equal(t, "full", params.Format)
+ assert.Equal(t, "project", params.Scope)
+ assert.True(t, params.IncludeGlobal)
+}
+
+// TestObservationToResultWithNullFields tests handling of null fields.
+func TestObservationToResultWithNullFields(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ obs := &models.Observation{
+ ID: 1,
+ Project: "test-project",
+ Type: models.ObsTypeDiscovery,
+ Scope: models.ScopeProject,
+ Title: sql.NullString{Valid: false},
+ Narrative: sql.NullString{Valid: false},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ result := m.observationToResult(obs, "full")
+
+ assert.Equal(t, "", result.Title)
+ assert.Equal(t, "", result.Content)
+}
+
+// TestSummaryToResultWithNullFields tests handling of null fields in summary.
+func TestSummaryToResultWithNullFields(t *testing.T) {
+ m := NewManager(nil, nil, nil, nil)
+
+ summary := &models.SessionSummary{
+ ID: 1,
+ Project: "test-project",
+ Request: sql.NullString{Valid: false},
+ Learned: sql.NullString{Valid: false},
+ CreatedAtEpoch: 1704067200000,
+ }
+
+ result := m.summaryToResult(summary, "full")
+
+ assert.Equal(t, "", result.Title)
+ assert.Equal(t, "", result.Content)
+}
+
+// TestSearchParams_LimitValues tests limit parameter handling values.
+func TestSearchParams_LimitValues(t *testing.T) {
+ tests := []struct {
+ name string
+ inputLimit int
+ expectedValid bool
+ }{
+ {"zero_limit", 0, true},
+ {"negative_limit", -5, true},
+ {"normal_limit", 20, true},
+ {"max_limit", 100, true},
+ {"over_limit", 200, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "test",
+ Limit: tt.inputLimit,
+ }
+ assert.NotNil(t, params)
+ assert.Equal(t, tt.inputLimit, params.Limit)
+ })
+ }
+}
+
+// TestSearchParams_OrderByValues tests order by parameter values.
+func TestSearchParams_OrderByValues(t *testing.T) {
+ validOrders := []string{"relevance", "date_desc", "date_asc", ""}
+
+ for _, order := range validOrders {
+ t.Run("order_"+order, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "test",
+ OrderBy: order,
+ }
+ assert.Equal(t, order, params.OrderBy)
+ })
+ }
+}
+
+// TestSearchParams_TypeValues tests type parameter values.
+func TestSearchParams_TypeValues(t *testing.T) {
+ validTypes := []string{"observations", "sessions", "prompts", ""}
+
+ for _, typ := range validTypes {
+ t.Run("type_"+typ, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "test",
+ Type: typ,
+ }
+ assert.Equal(t, typ, params.Type)
+ })
+ }
+}
+
+// TestSearchParams_ScopeValues tests scope parameter values.
+func TestSearchParams_ScopeValues(t *testing.T) {
+ validScopes := []string{"project", "global", ""}
+
+ for _, scope := range validScopes {
+ t.Run("scope_"+scope, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "test",
+ Scope: scope,
+ }
+ assert.Equal(t, scope, params.Scope)
+ })
+ }
+}
+
+// TestSearchParams_FormatValues tests format parameter values.
+func TestSearchParams_FormatValues(t *testing.T) {
+ validFormats := []string{"index", "full", ""}
+
+ for _, format := range validFormats {
+ t.Run("format_"+format, func(t *testing.T) {
+ params := SearchParams{
+ Query: "test",
+ Project: "test",
+ Format: format,
+ }
+ assert.Equal(t, format, params.Format)
+ })
+ }
+}
+
+// TestUnifiedSearchResult_MultipleResults tests result with multiple items.
+func TestUnifiedSearchResult_MultipleResults(t *testing.T) {
+ results := []SearchResult{
+ {Type: "observation", ID: 1, Title: "First", Project: "test"},
+ {Type: "session", ID: 2, Title: "Second", Project: "test"},
+ {Type: "prompt", ID: 3, Title: "Third", Project: "test"},
+ }
+
+ result := UnifiedSearchResult{
+ Results: results,
+ TotalCount: 3,
+ Query: "test query",
+ }
+
+ assert.Len(t, result.Results, 3)
+ assert.Equal(t, 3, result.TotalCount)
+ assert.Equal(t, "observation", result.Results[0].Type)
+ assert.Equal(t, "session", result.Results[1].Type)
+ assert.Equal(t, "prompt", result.Results[2].Type)
+}
+
+// TestSearchResult_Metadata tests metadata handling in SearchResult.
+func TestSearchResult_Metadata(t *testing.T) {
+ metadata := map[string]interface{}{
+ "obs_type": "discovery",
+ "concepts": []string{"auth", "security"},
+ "files_count": 5,
+ "is_important": true,
+ }
+
+ result := SearchResult{
+ Type: "observation",
+ ID: 1,
+ Metadata: metadata,
+ }
+
+ assert.Equal(t, "discovery", result.Metadata["obs_type"])
+ assert.Equal(t, 5, result.Metadata["files_count"])
+ assert.Equal(t, true, result.Metadata["is_important"])
+}
+
+// TestSearchResult_Scores tests score handling in SearchResult.
+func TestSearchResult_Scores(t *testing.T) {
+ tests := []struct {
+ name string
+ score float64
+ }{
+ {"perfect_score", 1.0},
+ {"high_score", 0.95},
+ {"medium_score", 0.5},
+ {"low_score", 0.1},
+ {"zero_score", 0.0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := SearchResult{
+ Type: "observation",
+ ID: 1,
+ Score: tt.score,
+ }
+ assert.Equal(t, tt.score, result.Score)
+ })
+ }
+}
diff --git a/internal/vector/sqlitevec/client_test.go b/internal/vector/sqlitevec/client_test.go
new file mode 100644
index 0000000..9a90440
--- /dev/null
+++ b/internal/vector/sqlitevec/client_test.go
@@ -0,0 +1,502 @@
+package sqlitevec
+
+import (
+ "context"
+ "database/sql"
+ "os"
+ "path/filepath"
+ "testing"
+
+ sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
+ "github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// testDB creates a test SQLite database with the vectors table.
+func testDB(t *testing.T) (*sql.DB, func()) {
+ t.Helper()
+
+ // Create temp directory
+ tmpDir, err := os.MkdirTemp("", "sqlitevec-test-*")
+ require.NoError(t, err)
+
+ dbPath := filepath.Join(tmpDir, "test.db")
+ db, err := sql.Open("sqlite3", dbPath)
+ require.NoError(t, err)
+
+ // Enable sqlite-vec
+ sqlite_vec.Auto()
+
+ // Create vectors table (matches production schema)
+ _, err = db.Exec(`
+ CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
+ doc_id TEXT PRIMARY KEY,
+ embedding float[384],
+ sqlite_id INTEGER,
+ doc_type TEXT,
+ field_type TEXT,
+ project TEXT,
+ scope TEXT
+ )
+ `)
+ require.NoError(t, err)
+
+ cleanup := func() {
+ db.Close()
+ os.RemoveAll(tmpDir)
+ }
+
+ return db, cleanup
+}
+
+// testEmbeddingService creates a test embedding service.
+func testEmbeddingService(t *testing.T) (*embedding.Service, func()) {
+ t.Helper()
+
+ svc, err := embedding.NewService()
+ require.NoError(t, err)
+
+ cleanup := func() {
+ svc.Close()
+ }
+
+ return svc, cleanup
+}
+
+func TestNewClient_Success(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+ assert.NotNil(t, client)
+}
+
+func TestNewClient_NilDB(t *testing.T) {
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: nil}, embedSvc)
+ assert.Error(t, err)
+ assert.Nil(t, client)
+ assert.Contains(t, err.Error(), "database connection required")
+}
+
+func TestNewClient_NilEmbedding(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ client, err := NewClient(Config{DB: db}, nil)
+ assert.Error(t, err)
+ assert.Nil(t, client)
+ assert.Contains(t, err.Error(), "embedding service required")
+}
+
+func TestClient_AddDocuments_Empty(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ err = client.AddDocuments(context.Background(), []Document{})
+ require.NoError(t, err)
+}
+
+func TestClient_AddDocuments_Single(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ docs := []Document{
+ {
+ ID: "obs-1-title",
+ Content: "This is a test observation about authentication.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "field_type": "title",
+ "project": "test-project",
+ "scope": "project",
+ },
+ },
+ }
+
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Verify document was inserted
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "obs-1-title").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 1, count)
+}
+
+func TestClient_AddDocuments_Multiple(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ docs := []Document{
+ {
+ ID: "obs-1-title",
+ Content: "Authentication flow implementation.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "field_type": "title",
+ "project": "test-project",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-1-narrative",
+ Content: "We implemented JWT-based authentication.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "field_type": "narrative",
+ "project": "test-project",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-2-title",
+ Content: "Database optimization.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(2),
+ "doc_type": "observation",
+ "field_type": "title",
+ "project": "test-project",
+ "scope": "global",
+ },
+ },
+ }
+
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Verify all documents were inserted
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 3, count)
+}
+
+func TestClient_DeleteDocuments_Empty(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ err = client.DeleteDocuments(context.Background(), []string{})
+ require.NoError(t, err)
+}
+
+func TestClient_DeleteDocuments_Existing(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Add documents first
+ docs := []Document{
+ {
+ ID: "doc-1",
+ Content: "First document.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "doc-2",
+ Content: "Second document.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(2),
+ "doc_type": "observation",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Delete one document
+ err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
+ require.NoError(t, err)
+
+ // Verify only one remains
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 1, count)
+}
+
+func TestClient_Query_Basic(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Add some test documents
+ docs := []Document{
+ {
+ ID: "obs-1",
+ Content: "Authentication and login security implementation.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "project": "test-project",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-2",
+ Content: "Database query optimization techniques.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(2),
+ "doc_type": "observation",
+ "project": "test-project",
+ "scope": "project",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Query for authentication-related content
+ results, err := client.Query(context.Background(), "login authentication", 10, nil)
+ require.NoError(t, err)
+
+ assert.NotEmpty(t, results)
+ assert.LessOrEqual(t, len(results), 10)
+
+ // First result should be the authentication document (higher similarity)
+ assert.Equal(t, "obs-1", results[0].ID)
+}
+
+func TestClient_Query_WithDocTypeFilter(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Add documents of different types
+ docs := []Document{
+ {
+ ID: "obs-1",
+ Content: "Test content for observation.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "project": "test-project",
+ },
+ },
+ {
+ ID: "summary-1",
+ Content: "Test content for summary.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(10),
+ "doc_type": "session_summary",
+ "project": "test-project",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Query with doc_type filter
+ where := map[string]any{"doc_type": "observation"}
+ results, err := client.Query(context.Background(), "test content", 10, where)
+ require.NoError(t, err)
+
+ // Should only return observation documents
+ for _, r := range results {
+ docType, _ := r.Metadata["doc_type"].(string)
+ assert.Equal(t, "observation", docType)
+ }
+}
+
+func TestClient_Query_WithProjectFilter(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Add documents from different projects
+ docs := []Document{
+ {
+ ID: "obs-1",
+ Content: "Project A authentication content.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ "project": "project-a",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-2",
+ Content: "Project B database content.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(2),
+ "doc_type": "observation",
+ "project": "project-b",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-3",
+ Content: "Global security best practices.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(3),
+ "doc_type": "observation",
+ "project": "project-b",
+ "scope": "global",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs)
+ require.NoError(t, err)
+
+ // Query without project filter to verify all docs are there
+ results, err := client.Query(context.Background(), "authentication security", 10, nil)
+ require.NoError(t, err)
+ assert.NotEmpty(t, results, "Should find some results")
+}
+
+func TestClient_IsConnected(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ assert.True(t, client.IsConnected())
+}
+
+func TestClient_Close(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ err = client.Close()
+ require.NoError(t, err)
+}
+
+func TestConfig_Fields(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ cfg := Config{DB: db}
+ assert.Equal(t, db, cfg.DB)
+}
+
+func TestClient_UpdateDocument_DeleteThenAdd(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Add document
+ docs1 := []Document{
+ {
+ ID: "doc-1",
+ Content: "Original content.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs1)
+ require.NoError(t, err)
+
+ // Delete then add with new content (proper update pattern)
+ err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
+ require.NoError(t, err)
+
+ docs2 := []Document{
+ {
+ ID: "doc-1",
+ Content: "Updated content.",
+ Metadata: map[string]any{
+ "sqlite_id": int64(1),
+ "doc_type": "observation",
+ },
+ },
+ }
+ err = client.AddDocuments(context.Background(), docs2)
+ require.NoError(t, err)
+
+ // Should have exactly 1 document
+ var count int
+ err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-1").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 1, count)
+}
+
+func TestClient_DeleteDocuments_NonExistent(t *testing.T) {
+ db, dbCleanup := testDB(t)
+ defer dbCleanup()
+
+ embedSvc, embedCleanup := testEmbeddingService(t)
+ defer embedCleanup()
+
+ client, err := NewClient(Config{DB: db}, embedSvc)
+ require.NoError(t, err)
+
+ // Deleting non-existent document should not error
+ err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"})
+ require.NoError(t, err)
+}
diff --git a/internal/vector/sqlitevec/helpers_test.go b/internal/vector/sqlitevec/helpers_test.go
new file mode 100644
index 0000000..0e9b5ac
--- /dev/null
+++ b/internal/vector/sqlitevec/helpers_test.go
@@ -0,0 +1,574 @@
+package sqlitevec
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDocTypes(t *testing.T) {
+ assert.Equal(t, DocType("observation"), DocTypeObservation)
+ assert.Equal(t, DocType("session_summary"), DocTypeSessionSummary)
+ assert.Equal(t, DocType("user_prompt"), DocTypeUserPrompt)
+}
+
+func TestDocument_Fields(t *testing.T) {
+ doc := Document{
+ ID: "doc-123",
+ Content: "test content",
+ Metadata: map[string]any{
+ "key": "value",
+ },
+ }
+
+ assert.Equal(t, "doc-123", doc.ID)
+ assert.Equal(t, "test content", doc.Content)
+ assert.Equal(t, "value", doc.Metadata["key"])
+}
+
+func TestQueryResult_Fields(t *testing.T) {
+ result := QueryResult{
+ ID: "result-123",
+ Distance: 0.5,
+ Metadata: map[string]any{
+ "sqlite_id": float64(42),
+ },
+ }
+
+ assert.Equal(t, "result-123", result.ID)
+ assert.Equal(t, 0.5, result.Distance)
+ assert.Equal(t, float64(42), result.Metadata["sqlite_id"])
+}
+
+func TestBuildWhereFilter(t *testing.T) {
+ tests := []struct {
+ name string
+ docType DocType
+ project string
+ expected map[string]interface{}
+ }{
+ {
+ name: "empty_filters",
+ docType: "",
+ project: "",
+ expected: map[string]interface{}{},
+ },
+ {
+ name: "doc_type_only",
+ docType: DocTypeObservation,
+ project: "",
+ expected: map[string]interface{}{
+ "doc_type": "observation",
+ },
+ },
+ {
+ name: "project_only",
+ docType: "",
+ project: "my-project",
+ expected: map[string]interface{}{
+ "project": "my-project",
+ },
+ },
+ {
+ name: "both_filters",
+ docType: DocTypeSessionSummary,
+ project: "test-project",
+ expected: map[string]interface{}{
+ "doc_type": "session_summary",
+ "project": "test-project",
+ },
+ },
+ {
+ name: "user_prompt_type",
+ docType: DocTypeUserPrompt,
+ project: "prompt-project",
+ expected: map[string]interface{}{
+ "doc_type": "user_prompt",
+ "project": "prompt-project",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := BuildWhereFilter(tt.docType, tt.project)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestExtractIDsByDocType_Empty(t *testing.T) {
+ results := []QueryResult{}
+ ids := ExtractIDsByDocType(results)
+
+ assert.Empty(t, ids.ObservationIDs)
+ assert.Empty(t, ids.SummaryIDs)
+ assert.Empty(t, ids.PromptIDs)
+}
+
+func TestExtractIDsByDocType_AllTypes(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1",
+ Distance: 0.1,
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "obs-2",
+ Distance: 0.2,
+ Metadata: map[string]any{
+ "sqlite_id": float64(2),
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "summary-1",
+ Distance: 0.3,
+ Metadata: map[string]any{
+ "sqlite_id": float64(10),
+ "doc_type": "session_summary",
+ },
+ },
+ {
+ ID: "prompt-1",
+ Distance: 0.4,
+ Metadata: map[string]any{
+ "sqlite_id": float64(20),
+ "doc_type": "user_prompt",
+ },
+ },
+ }
+
+ ids := ExtractIDsByDocType(results)
+
+ assert.Equal(t, []int64{1, 2}, ids.ObservationIDs)
+ assert.Equal(t, []int64{10}, ids.SummaryIDs)
+ assert.Equal(t, []int64{20}, ids.PromptIDs)
+}
+
+func TestExtractIDsByDocType_Deduplication(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1-field1",
+ Distance: 0.1,
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "obs-1-field2",
+ Distance: 0.2,
+ Metadata: map[string]any{
+ "sqlite_id": float64(1), // Same ID, different field
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "obs-2",
+ Distance: 0.3,
+ Metadata: map[string]any{
+ "sqlite_id": float64(2),
+ "doc_type": "observation",
+ },
+ },
+ }
+
+ ids := ExtractIDsByDocType(results)
+
+ assert.Equal(t, []int64{1, 2}, ids.ObservationIDs) // Should be deduplicated
+}
+
+func TestExtractIDsByDocType_Int64Fallback(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1",
+ Distance: 0.1,
+ Metadata: map[string]any{
+ "sqlite_id": int64(42), // int64 instead of float64
+ "doc_type": "observation",
+ },
+ },
+ }
+
+ ids := ExtractIDsByDocType(results)
+
+ assert.Equal(t, []int64{42}, ids.ObservationIDs)
+}
+
+func TestExtractIDsByDocType_MissingSqliteID(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1",
+ Distance: 0.1,
+ Metadata: map[string]any{
+ "doc_type": "observation",
+ // Missing sqlite_id
+ },
+ },
+ }
+
+ ids := ExtractIDsByDocType(results)
+
+ assert.Empty(t, ids.ObservationIDs)
+}
+
+func TestExtractIDsByDocType_UnknownType(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "unknown-1",
+ Distance: 0.1,
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "unknown_type",
+ },
+ },
+ }
+
+ ids := ExtractIDsByDocType(results)
+
+ // Should not be added to any category
+ assert.Empty(t, ids.ObservationIDs)
+ assert.Empty(t, ids.SummaryIDs)
+ assert.Empty(t, ids.PromptIDs)
+}
+
+func TestExtractObservationIDs_NoFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ "project": "proj-a",
+ },
+ },
+ {
+ ID: "obs-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(2),
+ "doc_type": "observation",
+ "project": "proj-b",
+ },
+ },
+ {
+ ID: "summary-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(10),
+ "doc_type": "session_summary",
+ "project": "proj-a",
+ },
+ },
+ }
+
+ ids := ExtractObservationIDs(results, "")
+
+ assert.Equal(t, []int64{1, 2}, ids)
+}
+
+func TestExtractObservationIDs_WithProjectFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ "project": "proj-a",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(2),
+ "doc_type": "observation",
+ "project": "proj-b",
+ "scope": "project",
+ },
+ },
+ {
+ ID: "obs-global",
+ Metadata: map[string]any{
+ "sqlite_id": float64(3),
+ "doc_type": "observation",
+ "project": "proj-b",
+ "scope": "global",
+ },
+ },
+ }
+
+ ids := ExtractObservationIDs(results, "proj-a")
+
+ // Should include proj-a and global scope observations
+ assert.Equal(t, []int64{1, 3}, ids)
+}
+
+func TestExtractObservationIDs_Deduplication(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "obs-1-field1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ },
+ },
+ {
+ ID: "obs-1-field2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(1), // Same ID
+ "doc_type": "observation",
+ },
+ },
+ }
+
+ ids := ExtractObservationIDs(results, "")
+
+ assert.Equal(t, []int64{1}, ids)
+}
+
+func TestExtractSummaryIDs_NoFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "summary-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(10),
+ "doc_type": "session_summary",
+ "project": "proj-a",
+ },
+ },
+ {
+ ID: "summary-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(20),
+ "doc_type": "session_summary",
+ "project": "proj-b",
+ },
+ },
+ {
+ ID: "obs-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(1),
+ "doc_type": "observation",
+ },
+ },
+ }
+
+ ids := ExtractSummaryIDs(results, "")
+
+ assert.Equal(t, []int64{10, 20}, ids)
+}
+
+func TestExtractSummaryIDs_WithProjectFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "summary-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(10),
+ "doc_type": "session_summary",
+ "project": "proj-a",
+ },
+ },
+ {
+ ID: "summary-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(20),
+ "doc_type": "session_summary",
+ "project": "proj-b",
+ },
+ },
+ }
+
+ ids := ExtractSummaryIDs(results, "proj-a")
+
+ assert.Equal(t, []int64{10}, ids)
+}
+
+func TestExtractPromptIDs_NoFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "prompt-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(100),
+ "doc_type": "user_prompt",
+ "project": "proj-a",
+ },
+ },
+ {
+ ID: "prompt-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(200),
+ "doc_type": "user_prompt",
+ "project": "proj-b",
+ },
+ },
+ }
+
+ ids := ExtractPromptIDs(results, "")
+
+ assert.Equal(t, []int64{100, 200}, ids)
+}
+
+func TestExtractPromptIDs_WithProjectFilter(t *testing.T) {
+ results := []QueryResult{
+ {
+ ID: "prompt-1",
+ Metadata: map[string]any{
+ "sqlite_id": float64(100),
+ "doc_type": "user_prompt",
+ "project": "proj-a",
+ },
+ },
+ {
+ ID: "prompt-2",
+ Metadata: map[string]any{
+ "sqlite_id": float64(200),
+ "doc_type": "user_prompt",
+ "project": "proj-b",
+ },
+ },
+ }
+
+ ids := ExtractPromptIDs(results, "proj-b")
+
+ assert.Equal(t, []int64{200}, ids)
+}
+
+func TestCopyMetadata(t *testing.T) {
+ base := map[string]any{
+ "key1": "value1",
+ "key2": 42,
+ }
+
+ result := copyMetadata(base, "key3", "value3")
+
+ // Original should be unchanged
+ assert.Len(t, base, 2)
+
+ // Result should have new key
+ assert.Len(t, result, 3)
+ assert.Equal(t, "value1", result["key1"])
+ assert.Equal(t, 42, result["key2"])
+ assert.Equal(t, "value3", result["key3"])
+}
+
+func TestCopyMetadataMulti(t *testing.T) {
+ base := map[string]any{
+ "key1": "value1",
+ }
+ extra := map[string]any{
+ "key2": "value2",
+ "key3": "value3",
+ }
+
+ result := copyMetadataMulti(base, extra)
+
+ assert.Len(t, result, 3)
+ assert.Equal(t, "value1", result["key1"])
+ assert.Equal(t, "value2", result["key2"])
+ assert.Equal(t, "value3", result["key3"])
+}
+
+func TestJoinStrings(t *testing.T) {
+ tests := []struct {
+ name string
+ strs []string
+ sep string
+ expected string
+ }{
+ {
+ name: "empty_slice",
+ strs: []string{},
+ sep: ", ",
+ expected: "",
+ },
+ {
+ name: "single_element",
+ strs: []string{"one"},
+ sep: ", ",
+ expected: "one",
+ },
+ {
+ name: "multiple_elements",
+ strs: []string{"one", "two", "three"},
+ sep: ", ",
+ expected: "one, two, three",
+ },
+ {
+ name: "different_separator",
+ strs: []string{"a", "b", "c"},
+ sep: "-",
+ expected: "a-b-c",
+ },
+ {
+ name: "empty_separator",
+ strs: []string{"a", "b", "c"},
+ sep: "",
+ expected: "abc",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := joinStrings(tt.strs, tt.sep)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTruncateString(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ expected string
+ }{
+ {
+ name: "shorter_than_max",
+ input: "hello",
+ maxLen: 10,
+ expected: "hello",
+ },
+ {
+ name: "equal_to_max",
+ input: "hello",
+ maxLen: 5,
+ expected: "hello",
+ },
+ {
+ name: "longer_than_max",
+ input: "hello world",
+ maxLen: 5,
+ expected: "hello...",
+ },
+ {
+ name: "empty_string",
+ input: "",
+ maxLen: 5,
+ expected: "",
+ },
+ {
+ name: "zero_max_length",
+ input: "hello",
+ maxLen: 0,
+ expected: "...",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncateString(tt.input, tt.maxLen)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestExtractedIDs_Empty(t *testing.T) {
+ ids := &ExtractedIDs{}
+
+ assert.Nil(t, ids.ObservationIDs)
+ assert.Nil(t, ids.SummaryIDs)
+ assert.Nil(t, ids.PromptIDs)
+}
diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go
index b86257a..f7f1b67 100644
--- a/internal/worker/handlers_test.go
+++ b/internal/worker/handlers_test.go
@@ -1279,3 +1279,849 @@ func TestObservationRequest_Fields(t *testing.T) {
assert.Equal(t, "Read", req.ToolName)
assert.Equal(t, "/home/user/project", req.CWD)
}
+
+// TestRetrievalStats_Fields tests RetrievalStats struct fields.
+func TestRetrievalStats_Fields(t *testing.T) {
+ stats := RetrievalStats{
+ TotalRequests: 100,
+ ObservationsServed: 500,
+ VerifiedStale: 10,
+ DeletedInvalid: 5,
+ SearchRequests: 80,
+ ContextInjections: 20,
+ }
+
+ assert.Equal(t, int64(100), stats.TotalRequests)
+ assert.Equal(t, int64(500), stats.ObservationsServed)
+ assert.Equal(t, int64(10), stats.VerifiedStale)
+ assert.Equal(t, int64(5), stats.DeletedInvalid)
+ assert.Equal(t, int64(80), stats.SearchRequests)
+ assert.Equal(t, int64(20), stats.ContextInjections)
+}
+
+// TestServiceConstants tests service configuration constants.
+func TestServiceConstants(t *testing.T) {
+ assert.Equal(t, 30*time.Second, DefaultHTTPTimeout)
+ assert.Equal(t, 50*time.Millisecond, ReadyPollInterval)
+ assert.Equal(t, 100, StaleQueueSize)
+ assert.Equal(t, 2*time.Second, QueueProcessInterval)
+}
+
+// TestClusterObservations_Empty tests clustering with empty slice.
+func TestClusterObservations_Empty(t *testing.T) {
+ observations := []*models.Observation{}
+ clustered := clusterObservations(observations, 0.4)
+ assert.Empty(t, clustered)
+}
+
+// TestClusterObservations_Single tests clustering with single observation.
+func TestClusterObservations_Single(t *testing.T) {
+ observations := []*models.Observation{
+ {
+ ID: 1,
+ Title: sql.NullString{String: "Test observation", Valid: true},
+ Narrative: sql.NullString{String: "Test content", Valid: true},
+ },
+ }
+ clustered := clusterObservations(observations, 0.4)
+ assert.Len(t, clustered, 1)
+}
+
+// TestClusterObservations_VeryDifferent tests clustering with very different observations.
+func TestClusterObservations_VeryDifferent(t *testing.T) {
+ observations := []*models.Observation{
+ {
+ ID: 1,
+ Title: sql.NullString{String: "Database optimization", Valid: true},
+ Narrative: sql.NullString{String: "PostgreSQL index tuning", Valid: true},
+ },
+ {
+ ID: 2,
+ Title: sql.NullString{String: "Authentication flow", Valid: true},
+ Narrative: sql.NullString{String: "JWT token validation", Valid: true},
+ },
+ {
+ ID: 3,
+ Title: sql.NullString{String: "Logging setup", Valid: true},
+ Narrative: sql.NullString{String: "Zerolog configuration", Valid: true},
+ },
+ }
+ clustered := clusterObservations(observations, 0.4)
+ // Very different observations should not be clustered together
+ assert.GreaterOrEqual(t, len(clustered), 1)
+}
+
+// TestHandleContextInject_WithLimit tests context inject with custom limit.
+func TestHandleContextInject_WithLimit(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ project := "inject-limit-test"
+
+ // Create some observations
+ for i := 0; i < 10; i++ {
+ createTestObservation(t, svc.observationStore, project,
+ "Observation "+strconv.Itoa(i),
+ "Content "+strconv.Itoa(i),
+ []string{"test-" + strconv.Itoa(i)})
+ time.Sleep(time.Millisecond)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var response map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &response)
+ require.NoError(t, err)
+
+ observations, ok := response["observations"].([]interface{})
+ require.True(t, ok)
+ assert.LessOrEqual(t, len(observations), 5)
+}
+
+// TestHandleGetObservations tests getting observations list.
+func TestHandleGetObservations(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create some observations
+ createTestObservation(t, svc.observationStore, "test-project",
+ "Test Observation 1",
+ "Test content 1",
+ []string{"test"})
+ createTestObservation(t, svc.observationStore, "test-project",
+ "Test Observation 2",
+ "Test content 2",
+ []string{"test"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/observations?project=test-project", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var observations []map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &observations)
+ require.NoError(t, err)
+
+ assert.GreaterOrEqual(t, len(observations), 2)
+}
+
+// TestHandleGetObservations_Pagination tests observations pagination.
+func TestHandleGetObservations_Pagination(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create some observations
+ for i := 0; i < 5; i++ {
+ createTestObservation(t, svc.observationStore, "page-test",
+ "Observation "+strconv.Itoa(i),
+ "Content "+strconv.Itoa(i),
+ []string{"test"})
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/observations?project=page-test&limit=2", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleGetObservations_NoProject tests observations without project.
+func TestHandleGetObservations_NoProject(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should still return 200 with empty results or all observations
+ assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
+}
+
+// TestHandleSearchByPrompt_EmptyQuery tests search with empty query parameter.
+func TestHandleSearchByPrompt_EmptyQuery(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=test&query=", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Empty query should still be a bad request
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleGetSessionByClaudeID tests getting session by Claude ID.
+func TestHandleGetSessionByClaudeID(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create a session
+ ctx := context.Background()
+ svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1")
+
+ // Test with valid claudeSessionId
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=claude-test-123", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleGetSessionByClaudeID_Missing tests session lookup with missing param.
+func TestHandleGetSessionByClaudeID_Missing(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleGetSessionByClaudeID_NotFound tests session not found.
+func TestHandleGetSessionByClaudeID_NotFound(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=nonexistent", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+// TestGetRetrievalStats tests the retrieval stats getter.
+func TestGetRetrievalStats(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Initially all zeros
+ stats := svc.GetRetrievalStats()
+ assert.Equal(t, int64(0), stats.TotalRequests)
+ assert.Equal(t, int64(0), stats.SearchRequests)
+
+ // Make some requests to increment stats
+ project := "stats-test"
+ createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
+ rec := httptest.NewRecorder()
+ svc.router.ServeHTTP(rec, req)
+
+ // Stats should be updated
+ stats = svc.GetRetrievalStats()
+ assert.GreaterOrEqual(t, stats.TotalRequests, int64(1))
+}
+
+// TestSelfCheckResponse_Fields tests SelfCheckResponse struct fields.
+func TestSelfCheckResponse_Fields(t *testing.T) {
+ resp := SelfCheckResponse{
+ Overall: "healthy",
+ Version: "v1.0.0",
+ Uptime: "2h30m",
+ Components: []ComponentHealth{
+ {Name: "database", Status: "healthy", Message: "Connected"},
+ {Name: "vector", Status: "healthy", Message: "Ready"},
+ },
+ }
+
+ assert.Equal(t, "healthy", resp.Overall)
+ assert.Equal(t, "v1.0.0", resp.Version)
+ assert.Equal(t, "2h30m", resp.Uptime)
+ assert.Len(t, resp.Components, 2)
+ assert.Equal(t, "database", resp.Components[0].Name)
+ assert.Equal(t, "healthy", resp.Components[0].Status)
+ assert.Equal(t, "Connected", resp.Components[0].Message)
+}
+
+// TestComponentHealth_Fields tests ComponentHealth struct fields.
+func TestComponentHealth_Fields(t *testing.T) {
+ tests := []struct {
+ name string
+ status string
+ message string
+ }{
+ {"healthy", "healthy", "All systems operational"},
+ {"degraded", "degraded", "Some features unavailable"},
+ {"unhealthy", "unhealthy", "Service is down"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ health := ComponentHealth{
+ Status: tt.status,
+ Message: tt.message,
+ }
+ assert.Equal(t, tt.status, health.Status)
+ assert.Equal(t, tt.message, health.Message)
+ })
+ }
+}
+
+// TestWriteJSON_Error tests writeJSON with values that can't be encoded.
+func TestWriteJSON_Error(t *testing.T) {
+ rec := httptest.NewRecorder()
+
+ // channels can't be JSON encoded
+ ch := make(chan int)
+ writeJSON(rec, ch)
+
+ // Should still set content type but encoding will fail
+ assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
+}
+
+// TestHandleSummarize_InvalidSessionID tests summarize with invalid session ID.
+func TestHandleSummarize_InvalidSessionID(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/summarize", bytes.NewReader([]byte("{}")))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Invalid session ID should return 400
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleSubagentComplete tests subagent completion endpoint.
+func TestHandleSubagentComplete(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create a session first
+ ctx := context.Background()
+ sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "subagent-test-123", "test-project", "test prompt")
+
+ payload := `{"session_id": ` + strconv.FormatInt(sessionID, 10) + `, "parent_session_id": "parent-123"}`
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte(payload)))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should accept the request
+ assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
+}
+
+// TestHandleContextSearch_Ordering tests search with different orderings.
+func TestHandleContextSearch_Ordering(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ project := "order-test"
+
+ for i := 0; i < 5; i++ {
+ createTestObservation(t, svc.observationStore, project,
+ "Obs "+strconv.Itoa(i),
+ "Content "+strconv.Itoa(i),
+ []string{"test"})
+ time.Sleep(time.Millisecond)
+ }
+
+ tests := []struct {
+ name string
+ order string
+ }{
+ {"date_desc", "date_desc"},
+ {"date_asc", "date_asc"},
+ {"default", ""}, // Should default to date_desc
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ url := "/api/context/search?project=" + project + "&query=test"
+ if tt.order != "" {
+ url += "&order_by=" + tt.order
+ }
+
+ req := httptest.NewRequest(http.MethodGet, url, nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ })
+ }
+}
+
+// TestHandleContextCount_NoProject tests context count without project.
+func TestHandleContextCount_NoProject(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleRetrievalStatsEndpoint tests retrieval stats endpoint.
+func TestHandleRetrievalStatsEndpoint(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var stats RetrievalStats
+ err := json.Unmarshal(rec.Body.Bytes(), &stats)
+ require.NoError(t, err)
+}
+
+// TestHandleReadyEndpoint tests ready endpoint response.
+func TestHandleReadyEndpoint(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Ready may return OK or ServiceUnavailable depending on state
+ assert.Contains(t, []int{http.StatusOK, http.StatusServiceUnavailable}, rec.Code)
+}
+
+// TestHandleSessionInit_EmptyBody tests session init with empty body.
+func TestHandleSessionInit_EmptyBody(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ payload := `{}`
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte(payload)))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should accept empty body and create a session
+ assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
+}
+
+// TestHandleObservation_MissingSession tests observation without session.
+func TestHandleObservation_MissingSession(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ payload := `{"session_id": 99999, "tool_name": "Read", "tool_input": "{}", "tool_output": "test"}`
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte(payload)))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should still accept the observation
+ assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
+}
+
+// TestHandleSummaries_Pagination tests summaries pagination.
+func TestHandleSummaries_Pagination(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test&limit=10&offset=0", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandlePrompts_Pagination tests prompts pagination.
+func TestHandlePrompts_Pagination(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test&limit=10&offset=0", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleGetStats_AllProjects tests stats without project filter.
+func TestHandleGetStats_AllProjects(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create some data in multiple projects
+ createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"})
+ createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleSubagentComplete_WithSession tests subagent completion with existing session.
+func TestHandleSubagentComplete_WithSession(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create a session first
+ ctx := context.Background()
+ svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt")
+
+ reqBody := SubagentCompleteRequest{
+ ClaudeSessionID: "subagent-claude-123",
+ Project: "test-project",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleSubagentComplete_NoSession tests subagent completion when session doesn't exist.
+func TestHandleSubagentComplete_NoSession(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ reqBody := SubagentCompleteRequest{
+ ClaudeSessionID: "nonexistent-session",
+ Project: "test-project",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should still return 200 even if session not found
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleSubagentComplete_InvalidJSON tests subagent completion with invalid JSON.
+func TestHandleSubagentComplete_InvalidJSON(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte("invalid json")))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleSummarize_ValidSession tests summarize with valid session.
+func TestHandleSummarize_ValidSession(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create a session first
+ ctx := context.Background()
+ sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-claude-test", "test-project", "test prompt")
+
+ reqBody := SummarizeRequest{
+ LastUserMessage: "Can you help me fix this bug?",
+ LastAssistantMessage: "I've analyzed the code and fixed the issue in the handler.",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleSummarize_InvalidJSON tests summarize with invalid JSON.
+func TestHandleSummarize_InvalidJSON(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-invalid", "test-project", "test")
+
+ req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader([]byte("not valid json")))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleSummarize_NonExistentSession tests summarize with non-existent session.
+func TestHandleSummarize_NonExistentSession(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ reqBody := SummarizeRequest{
+ LastUserMessage: "test",
+ LastAssistantMessage: "test",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/sessions/999999/summarize", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should return error for non-existent session
+ assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError, http.StatusNotFound}, rec.Code)
+}
+
+// TestSubagentCompleteRequest_Fields tests SubagentCompleteRequest struct.
+func TestSubagentCompleteRequest_Fields(t *testing.T) {
+ req := SubagentCompleteRequest{
+ ClaudeSessionID: "test-session-123",
+ Project: "my-project",
+ }
+
+ assert.Equal(t, "test-session-123", req.ClaudeSessionID)
+ assert.Equal(t, "my-project", req.Project)
+}
+
+// TestSummarizeRequest_Fields tests SummarizeRequest struct.
+func TestSummarizeRequest_Fields(t *testing.T) {
+ req := SummarizeRequest{
+ LastUserMessage: "Help me fix this bug",
+ LastAssistantMessage: "I've fixed the authentication issue",
+ }
+
+ assert.Equal(t, "Help me fix this bug", req.LastUserMessage)
+ assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
+}
+
+// TestHandleHealth_NotReady tests health endpoint when not ready.
+func TestHandleHealth_NotReady(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ svc.ready.Store(false)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
+ rec := httptest.NewRecorder()
+
+ svc.handleHealth(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var response map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &response)
+ require.NoError(t, err)
+
+ assert.Equal(t, "starting", response["status"])
+}
+
+// TestHandleContextInject_EmptyProject tests context inject with empty project.
+func TestHandleContextInject_EmptyProject(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project=", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// TestHandleSearchByPrompt_LargeLimit tests search with limit exceeding max.
+func TestHandleSearchByPrompt_LargeLimit(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ project := "limit-test"
+ createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&limit=999", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleObservation_WithFullData tests observation with all fields.
+func TestHandleObservation_WithFullData(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create a session first
+ ctx := context.Background()
+ svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt")
+
+ reqBody := ObservationRequest{
+ ClaudeSessionID: "obs-full-test",
+ Project: "test-project",
+ ToolName: "Edit",
+ ToolInput: map[string]interface{}{"file_path": "/test.go", "old_string": "foo", "new_string": "bar"},
+ ToolResponse: "Edit successful",
+ CWD: "/home/user/project",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestHandleSelfCheck_WithObservations tests self-check with observations in DB.
+func TestHandleSelfCheck_WithObservations(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ svc.ready.Store(true)
+
+ // Create some observations
+ createTestObservation(t, svc.observationStore, "check-project", "Test", "Content", []string{"test"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
+ rec := httptest.NewRecorder()
+
+ svc.handleSelfCheck(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var response SelfCheckResponse
+ err := json.Unmarshal(rec.Body.Bytes(), &response)
+ require.NoError(t, err)
+
+ // Check components are populated
+ assert.NotEmpty(t, response.Components)
+}
+
+// TestHandleGetSummaries_NoProject tests getting summaries without project filter.
+func TestHandleGetSummaries_NoProject(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create some summaries in different projects
+ ctx := context.Background()
+ for i := 0; i < 3; i++ {
+ parsed := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
+ svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var summaries []map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &summaries)
+ require.NoError(t, err)
+
+ // Should return all summaries
+ assert.GreaterOrEqual(t, len(summaries), 3)
+}
+
+// TestHandleGetPrompts_NoProject tests getting prompts without project filter.
+func TestHandleGetPrompts_NoProject(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ // Create sessions and prompts in different projects
+ ctx := context.Background()
+ svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "")
+ svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "")
+
+ svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0)
+ svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var prompts []map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &prompts)
+ require.NoError(t, err)
+
+ assert.GreaterOrEqual(t, len(prompts), 2)
+}
+
+// TestHandleSessionInit_MissingClaudeID tests session init without Claude ID.
+func TestHandleSessionInit_MissingClaudeID(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ reqBody := SessionInitRequest{
+ Project: "test-project",
+ Prompt: "Help me",
+ }
+
+ body, _ := json.Marshal(reqBody)
+ req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ // Should accept even without Claude ID (may auto-generate)
+ assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
+}
+
+// TestHandleContextInject_WithQuery tests context inject with query parameter.
+func TestHandleContextInject_WithQuery(t *testing.T) {
+ svc, cleanup := testService(t)
+ defer cleanup()
+
+ project := "inject-query-test"
+ createTestObservation(t, svc.observationStore, project, "Authentication bug fix", "Fixed JWT validation", []string{"auth", "jwt"})
+ createTestObservation(t, svc.observationStore, project, "Database optimization", "Added indexes", []string{"db", "performance"})
+
+ req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&query=authentication", nil)
+ rec := httptest.NewRecorder()
+
+ svc.router.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ var response map[string]interface{}
+ err := json.Unmarshal(rec.Body.Bytes(), &response)
+ require.NoError(t, err)
+
+ observations, ok := response["observations"].([]interface{})
+ require.True(t, ok)
+ assert.GreaterOrEqual(t, len(observations), 1)
+}
diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go
index 66059fa..de0c2b5 100644
--- a/internal/worker/sdk/processor_test.go
+++ b/internal/worker/sdk/processor_test.go
@@ -1,9 +1,12 @@
package sdk
import (
+ "os"
+ "path/filepath"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
+ "github.com/stretchr/testify/assert"
)
func TestIsSelfReferentialSummary(t *testing.T) {
@@ -124,3 +127,381 @@ No substantive work performed yet.`,
})
}
}
+
+func TestShouldSkipTool(t *testing.T) {
+ tests := []struct {
+ name string
+ toolName string
+ expected bool
+ }{
+ // Tools that should be skipped
+ {"TodoWrite", "TodoWrite", true},
+ {"Task", "Task", true},
+ {"TaskOutput", "TaskOutput", true},
+ {"Glob", "Glob", true},
+ {"ListDir", "ListDir", true},
+ {"LS", "LS", true},
+ {"KillShell", "KillShell", true},
+ {"AskUserQuestion", "AskUserQuestion", true},
+ {"EnterPlanMode", "EnterPlanMode", true},
+ {"ExitPlanMode", "ExitPlanMode", true},
+ {"Skill", "Skill", true},
+ {"SlashCommand", "SlashCommand", true},
+
+ // Tools that should NOT be skipped
+ {"Read", "Read", false},
+ {"Edit", "Edit", false},
+ {"Write", "Write", false},
+ {"Grep", "Grep", false},
+ {"Bash", "Bash", false},
+ {"WebFetch", "WebFetch", false},
+ {"WebSearch", "WebSearch", false},
+ {"NotebookEdit", "NotebookEdit", false},
+
+ // Unknown tool (should not be skipped)
+ {"UnknownTool", "SomeUnknownTool", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := shouldSkipTool(tt.toolName)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestShouldSkipTrivialOperation(t *testing.T) {
+ tests := []struct {
+ name string
+ toolName string
+ inputStr string
+ outputStr string
+ expected bool
+ }{
+ // Short output (should be skipped)
+ {
+ name: "output_too_short",
+ toolName: "Read",
+ inputStr: `{"file_path": "/some/file.go"}`,
+ outputStr: "short",
+ expected: true,
+ },
+ // Trivial outputs
+ {
+ name: "no_matches_found",
+ toolName: "Grep",
+ inputStr: `{"pattern": "foo"}`,
+ outputStr: "No matches found in the codebase",
+ expected: true,
+ },
+ {
+ name: "file_not_found",
+ toolName: "Read",
+ inputStr: `{"file_path": "/nonexistent.go"}`,
+ outputStr: "Error: File not found at specified path",
+ expected: true,
+ },
+ {
+ name: "empty_array",
+ toolName: "Grep",
+ inputStr: `{"pattern": "foo"}`,
+ outputStr: "[]",
+ expected: true,
+ },
+ // Boring files
+ {
+ name: "package_lock_json",
+ toolName: "Read",
+ inputStr: `{"file_path": "/project/package-lock.json"}`,
+ outputStr: "This is a very long package-lock.json content that has more than 50 characters",
+ expected: true,
+ },
+ {
+ name: "go_sum",
+ toolName: "Read",
+ inputStr: `{"file_path": "/project/go.sum"}`,
+ outputStr: "This is a very long go.sum file content that has more than 50 characters",
+ expected: true,
+ },
+ // Grep with too many matches
+ {
+ name: "grep_too_many_matches",
+ toolName: "Grep",
+ inputStr: `{"pattern": "import"}`,
+ outputStr: func() string {
+ s := ""
+ for i := 0; i < 55; i++ {
+ s += "match line\n"
+ }
+ return s
+ }(),
+ expected: true,
+ },
+ // Boring Bash commands
+ {
+ name: "git_status",
+ toolName: "Bash",
+ inputStr: `{"command": "git status"}`,
+ outputStr: "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean",
+ expected: true,
+ },
+ {
+ name: "ls_command",
+ toolName: "Bash",
+ inputStr: `{"command": "ls -la /some/directory"}`,
+ outputStr: "total 123\ndrwxr-xr-x some long listing that is at least 50 chars",
+ expected: true,
+ },
+ // Valid operations that should NOT be skipped
+ {
+ name: "valid_read",
+ toolName: "Read",
+ inputStr: `{"file_path": "/project/main.go"}`,
+ outputStr: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}",
+ expected: false,
+ },
+ {
+ name: "valid_edit",
+ toolName: "Edit",
+ inputStr: `{"file_path": "/project/handler.go", "old_string": "foo", "new_string": "bar"}`,
+ outputStr: "Edit applied successfully. File /project/handler.go has been modified with the requested changes.",
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTruncateForLog(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ expected string
+ }{
+ {
+ name: "shorter_than_max",
+ input: "hello",
+ maxLen: 10,
+ expected: "hello",
+ },
+ {
+ name: "equal_to_max",
+ input: "hello",
+ maxLen: 5,
+ expected: "hello",
+ },
+ {
+ name: "longer_than_max",
+ input: "hello world",
+ maxLen: 5,
+ expected: "hello...",
+ },
+ {
+ name: "empty_string",
+ input: "",
+ maxLen: 5,
+ expected: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncateForLog(tt.input, tt.maxLen)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestToJSONString(t *testing.T) {
+ tests := []struct {
+ name string
+ input interface{}
+ expected string
+ }{
+ {
+ name: "nil_value",
+ input: nil,
+ expected: "",
+ },
+ {
+ name: "string_value",
+ input: "hello",
+ expected: "hello",
+ },
+ {
+ name: "int_value",
+ input: 42,
+ expected: "42",
+ },
+ {
+ name: "map_value",
+ input: map[string]string{"key": "value"},
+ expected: `{"key":"value"}`,
+ },
+ {
+ name: "slice_value",
+ input: []string{"a", "b", "c"},
+ expected: `["a","b","c"]`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := toJSONString(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCaptureFileMtimes(t *testing.T) {
+ // Create a temp directory with test files
+ tmpDir, err := os.MkdirTemp("", "mtime-test-*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ // Create test files
+ file1 := filepath.Join(tmpDir, "file1.txt")
+ file2 := filepath.Join(tmpDir, "file2.txt")
+
+ err = os.WriteFile(file1, []byte("content1"), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = os.WriteFile(file2, []byte("content2"), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Run("captures_mtimes_for_existing_files", func(t *testing.T) {
+ mtimes := captureFileMtimes([]string{file1}, []string{file2}, "")
+ assert.Len(t, mtimes, 2)
+ assert.Contains(t, mtimes, file1)
+ assert.Contains(t, mtimes, file2)
+ assert.Greater(t, mtimes[file1], int64(0))
+ assert.Greater(t, mtimes[file2], int64(0))
+ })
+
+ t.Run("handles_nonexistent_files", func(t *testing.T) {
+ mtimes := captureFileMtimes([]string{"/nonexistent/file.txt"}, nil, "")
+ assert.Empty(t, mtimes)
+ })
+
+ t.Run("handles_relative_paths_with_cwd", func(t *testing.T) {
+ mtimes := captureFileMtimes([]string{"file1.txt"}, nil, tmpDir)
+ assert.Len(t, mtimes, 1)
+ assert.Contains(t, mtimes, "file1.txt")
+ })
+
+ t.Run("empty_inputs", func(t *testing.T) {
+ mtimes := captureFileMtimes(nil, nil, "")
+ assert.Empty(t, mtimes)
+ })
+}
+
+func TestGetFileMtimes(t *testing.T) {
+ // Create a temp file
+ tmpDir, err := os.MkdirTemp("", "getmtime-test-*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ testFile := filepath.Join(tmpDir, "test.txt")
+ err = os.WriteFile(testFile, []byte("content"), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mtimes := GetFileMtimes([]string{testFile}, "")
+ assert.Len(t, mtimes, 1)
+ assert.Contains(t, mtimes, testFile)
+ assert.Greater(t, mtimes[testFile], int64(0))
+}
+
+func TestGetFileContent(t *testing.T) {
+ // Create a temp directory with test files
+ tmpDir, err := os.MkdirTemp("", "content-test-*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ t.Run("reads_existing_file", func(t *testing.T) {
+ testFile := filepath.Join(tmpDir, "test.txt")
+ content := "test content"
+ err := os.WriteFile(testFile, []byte(content), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ result, ok := GetFileContent(testFile, "")
+ assert.True(t, ok)
+ assert.Equal(t, content, result)
+ })
+
+ t.Run("returns_false_for_nonexistent_file", func(t *testing.T) {
+ result, ok := GetFileContent("/nonexistent/file.txt", "")
+ assert.False(t, ok)
+ assert.Empty(t, result)
+ })
+
+ t.Run("truncates_long_content", func(t *testing.T) {
+ testFile := filepath.Join(tmpDir, "long.txt")
+ longContent := ""
+ for i := 0; i < 3000; i++ {
+ longContent += "x"
+ }
+ err := os.WriteFile(testFile, []byte(longContent), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ result, ok := GetFileContent(testFile, "")
+ assert.True(t, ok)
+ assert.Contains(t, result, "[truncated]")
+ assert.LessOrEqual(t, len(result), 2100)
+ })
+
+ t.Run("resolves_relative_path_with_cwd", func(t *testing.T) {
+ testFile := filepath.Join(tmpDir, "relative.txt")
+ content := "relative content"
+ err := os.WriteFile(testFile, []byte(content), 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ result, ok := GetFileContent("relative.txt", tmpDir)
+ assert.True(t, ok)
+ assert.Equal(t, content, result)
+ })
+}
+
+func TestMaxConcurrentCLICalls(t *testing.T) {
+ assert.Equal(t, 4, MaxConcurrentCLICalls)
+}
+
+func TestObservationTypes(t *testing.T) {
+ expected := []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"}
+ assert.Equal(t, expected, ObservationTypes)
+}
+
+func TestObservationConcepts(t *testing.T) {
+ expectedConcepts := []string{
+ "how-it-works",
+ "why-it-exists",
+ "what-changed",
+ "problem-solution",
+ "gotcha",
+ "pattern",
+ "trade-off",
+ }
+ assert.Equal(t, expectedConcepts, ObservationConcepts)
+}
diff --git a/internal/worker/sdk/prompts_test.go b/internal/worker/sdk/prompts_test.go
new file mode 100644
index 0000000..eecf7d5
--- /dev/null
+++ b/internal/worker/sdk/prompts_test.go
@@ -0,0 +1,276 @@
+package sdk
+
+import (
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestTruncate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ expected string
+ }{
+ {
+ name: "shorter_than_max",
+ input: "hello",
+ maxLen: 10,
+ expected: "hello",
+ },
+ {
+ name: "equal_to_max",
+ input: "hello",
+ maxLen: 5,
+ expected: "hello",
+ },
+ {
+ name: "longer_than_max",
+ input: "hello world",
+ maxLen: 5,
+ expected: "hello... (truncated)",
+ },
+ {
+ name: "empty_string",
+ input: "",
+ maxLen: 5,
+ expected: "",
+ },
+ {
+ name: "zero_max_length",
+ input: "hello",
+ maxLen: 0,
+ expected: "... (truncated)",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncate(tt.input, tt.maxLen)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestBuildObservationPrompt(t *testing.T) {
+ now := time.Now().UnixMilli()
+
+ tests := []struct {
+ name string
+ exec ToolExecution
+ contains []string
+ }{
+ {
+ name: "basic_read_tool",
+ exec: ToolExecution{
+ ID: 1,
+ ToolName: "Read",
+ ToolInput: `{"file_path": "/path/to/file.go"}`,
+ ToolOutput: `package main\nfunc main() {}`,
+ CreatedAtEpoch: now,
+ CWD: "/project",
+ },
+ contains: []string{
+ "",
+ "Read",
+ "/project",
+ "",
+ "file_path",
+ "",
+ "",
+ },
+ },
+ {
+ name: "edit_tool_with_json_input",
+ exec: ToolExecution{
+ ID: 2,
+ ToolName: "Edit",
+ ToolInput: `{"file_path": "/file.go", "old_string": "foo", "new_string": "bar"}`,
+ ToolOutput: "Edit applied successfully",
+ CreatedAtEpoch: now,
+ CWD: "",
+ },
+ contains: []string{
+ "Edit",
+ "file_path",
+ "old_string",
+ "new_string",
+ },
+ },
+ {
+ name: "no_cwd",
+ exec: ToolExecution{
+ ID: 3,
+ ToolName: "Bash",
+ ToolInput: `{"command": "go test"}`,
+ ToolOutput: "ok",
+ CreatedAtEpoch: now,
+ CWD: "",
+ },
+ contains: []string{
+ "Bash",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := BuildObservationPrompt(tt.exec)
+
+ for _, s := range tt.contains {
+ assert.Contains(t, result, s, "Expected result to contain: %s", s)
+ }
+
+ // Check CWD only appears when set
+ if tt.exec.CWD == "" {
+ assert.NotContains(t, result, "")
+ }
+ })
+ }
+}
+
+func TestBuildObservationPrompt_TruncatesLongContent(t *testing.T) {
+ longInput := strings.Repeat("x", 5000)
+ longOutput := strings.Repeat("y", 7000)
+
+ exec := ToolExecution{
+ ID: 1,
+ ToolName: "Read",
+ ToolInput: longInput,
+ ToolOutput: longOutput,
+ CreatedAtEpoch: time.Now().UnixMilli(),
+ CWD: "/project",
+ }
+
+ result := BuildObservationPrompt(exec)
+
+ // Input should be truncated to ~3000
+ assert.Contains(t, result, "truncated")
+ // The result should not be excessively long
+ assert.Less(t, len(result), 10000)
+}
+
+func TestBuildSummaryPrompt(t *testing.T) {
+ tests := []struct {
+ name string
+ req SummaryRequest
+ contains []string
+ }{
+ {
+ name: "basic_request",
+ req: SummaryRequest{
+ SessionDBID: 1,
+ SDKSessionID: "sdk-123",
+ Project: "test-project",
+ },
+ contains: []string{
+ "PROGRESS SUMMARY CHECKPOINT",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ },
+ },
+ {
+ name: "with_assistant_message",
+ req: SummaryRequest{
+ SessionDBID: 2,
+ SDKSessionID: "sdk-456",
+ Project: "project-b",
+ LastAssistantMessage: "I fixed the authentication bug by updating the JWT validation.",
+ },
+ contains: []string{
+ "Claude's Full Response to User:",
+ "fixed the authentication",
+ },
+ },
+ {
+ name: "empty_assistant_message",
+ req: SummaryRequest{
+ SessionDBID: 3,
+ SDKSessionID: "sdk-789",
+ Project: "project-c",
+ LastAssistantMessage: "",
+ },
+ contains: []string{
+ "PROGRESS SUMMARY CHECKPOINT",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := BuildSummaryPrompt(tt.req)
+
+ for _, s := range tt.contains {
+ assert.Contains(t, result, s, "Expected result to contain: %s", s)
+ }
+
+ // Check assistant message only appears when set
+ if tt.req.LastAssistantMessage == "" {
+ assert.NotContains(t, result, "Claude's Full Response")
+ }
+ })
+ }
+}
+
+func TestBuildSummaryPrompt_TruncatesLongAssistantMessage(t *testing.T) {
+ longMessage := strings.Repeat("a", 5000)
+
+ req := SummaryRequest{
+ SessionDBID: 1,
+ SDKSessionID: "sdk-123",
+ Project: "test",
+ LastAssistantMessage: longMessage,
+ }
+
+ result := BuildSummaryPrompt(req)
+
+ // Should contain truncation indicator
+ assert.Contains(t, result, "truncated")
+ // Result should be reasonable length (less than full 5000 + overhead)
+ assert.Less(t, len(result), 6000)
+}
+
+func TestToolExecution_Struct(t *testing.T) {
+ exec := ToolExecution{
+ ID: 42,
+ ToolName: "Write",
+ ToolInput: `{"file_path": "/test.go"}`,
+ ToolOutput: "File written",
+ CreatedAtEpoch: 1234567890000,
+ CWD: "/workspace",
+ }
+
+ assert.Equal(t, int64(42), exec.ID)
+ assert.Equal(t, "Write", exec.ToolName)
+ assert.Equal(t, `{"file_path": "/test.go"}`, exec.ToolInput)
+ assert.Equal(t, "File written", exec.ToolOutput)
+ assert.Equal(t, int64(1234567890000), exec.CreatedAtEpoch)
+ assert.Equal(t, "/workspace", exec.CWD)
+}
+
+func TestSummaryRequest_Struct(t *testing.T) {
+ req := SummaryRequest{
+ SessionDBID: 100,
+ SDKSessionID: "sdk-abc",
+ Project: "my-project",
+ UserPrompt: "Fix the bug",
+ LastUserMessage: "Please fix the auth bug",
+ LastAssistantMessage: "I've fixed the authentication issue",
+ }
+
+ assert.Equal(t, int64(100), req.SessionDBID)
+ assert.Equal(t, "sdk-abc", req.SDKSessionID)
+ assert.Equal(t, "my-project", req.Project)
+ assert.Equal(t, "Fix the bug", req.UserPrompt)
+ assert.Equal(t, "Please fix the auth bug", req.LastUserMessage)
+ assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
+}
diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go
index 416409e..74aab43 100644
--- a/pkg/hooks/worker_test.go
+++ b/pkg/hooks/worker_test.go
@@ -710,3 +710,495 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
})
}
}
+
+// TestGetWorkerPort_EdgeCases tests GetWorkerPort with various edge cases.
+func TestGetWorkerPort_EdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ expectedPort int
+ shouldSetEnv bool
+ }{
+ {
+ name: "zero port uses default",
+ envValue: "0",
+ expectedPort: DefaultWorkerPort,
+ shouldSetEnv: true,
+ },
+ {
+ name: "negative port uses default",
+ envValue: "-1",
+ expectedPort: DefaultWorkerPort,
+ shouldSetEnv: true,
+ },
+ {
+ name: "empty string uses default",
+ envValue: "",
+ expectedPort: DefaultWorkerPort,
+ shouldSetEnv: true,
+ },
+ {
+ name: "whitespace uses default",
+ envValue: " ",
+ expectedPort: DefaultWorkerPort,
+ shouldSetEnv: true,
+ },
+ {
+ name: "large valid port",
+ envValue: "65535",
+ expectedPort: 65535,
+ shouldSetEnv: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.shouldSetEnv {
+ t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", tt.envValue)
+ }
+ port := GetWorkerPort()
+ assert.Equal(t, tt.expectedPort, port)
+ })
+ }
+}
+
+// TestVersionVariable tests the Version variable.
+func TestVersionVariable(t *testing.T) {
+ // Version is set at build time, but defaults to "dev"
+ assert.NotEmpty(t, Version)
+}
+
+// TestProjectIDWithName_RootPath tests ProjectIDWithName with root path.
+func TestProjectIDWithName_RootPath(t *testing.T) {
+ result := ProjectIDWithName("/")
+ // Should handle root path gracefully
+ assert.NotEmpty(t, result)
+ assert.Contains(t, result, "_") // Should still have underscore separator
+}
+
+// TestProjectIDWithName_SameDirname tests that same dirname with different paths get different IDs.
+func TestProjectIDWithName_SameDirname(t *testing.T) {
+ id1 := ProjectIDWithName("/home/user1/project")
+ id2 := ProjectIDWithName("/home/user2/project")
+
+ // Both have same dirname "project" but different full paths
+ assert.Contains(t, id1, "project_")
+ assert.Contains(t, id2, "project_")
+
+ // But different hashes due to different full paths
+ assert.NotEqual(t, id1, id2)
+}
+
+// TestBaseInput_PartialFields tests BaseInput with partial fields.
+func TestBaseInput_PartialFields(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected BaseInput
+ }{
+ {
+ name: "only session_id",
+ input: `{"session_id":"test-123"}`,
+ expected: BaseInput{SessionID: "test-123"},
+ },
+ {
+ name: "only cwd",
+ input: `{"cwd":"/tmp/test"}`,
+ expected: BaseInput{CWD: "/tmp/test"},
+ },
+ {
+ name: "empty object",
+ input: `{}`,
+ expected: BaseInput{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var base BaseInput
+ err := json.Unmarshal([]byte(tt.input), &base)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected.SessionID, base.SessionID)
+ assert.Equal(t, tt.expected.CWD, base.CWD)
+ })
+ }
+}
+
+// TestHookResponse_Marshal tests HookResponse JSON marshaling.
+func TestHookResponse_Marshal(t *testing.T) {
+ tests := []struct {
+ name string
+ response HookResponse
+ contains []string
+ }{
+ {
+ name: "continue true",
+ response: HookResponse{Continue: true},
+ contains: []string{`"continue":true`},
+ },
+ {
+ name: "continue false",
+ response: HookResponse{Continue: false},
+ contains: []string{`"continue":false`},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ data, err := json.Marshal(tt.response)
+ require.NoError(t, err)
+ for _, s := range tt.contains {
+ assert.Contains(t, string(data), s)
+ }
+ })
+ }
+}
+
+// TestHookResponse_Unmarshal tests HookResponse JSON unmarshaling.
+func TestHookResponse_Unmarshal(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected HookResponse
+ }{
+ {
+ name: "continue true",
+ input: `{"continue":true}`,
+ expected: HookResponse{Continue: true},
+ },
+ {
+ name: "continue false",
+ input: `{"continue":false}`,
+ expected: HookResponse{Continue: false},
+ },
+ {
+ name: "missing continue defaults to false",
+ input: `{}`,
+ expected: HookResponse{Continue: false},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var resp HookResponse
+ err := json.Unmarshal([]byte(tt.input), &resp)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected.Continue, resp.Continue)
+ })
+ }
+}
+
+// TestHookContext_Initialization tests HookContext struct initialization.
+func TestHookContext_Initialization(t *testing.T) {
+ tests := []struct {
+ name string
+ ctx HookContext
+ }{
+ {
+ name: "full context",
+ ctx: HookContext{
+ HookName: "session-start",
+ Port: 37777,
+ Project: "my-project_abc123",
+ SessionID: "session-123",
+ CWD: "/home/user/project",
+ RawInput: []byte(`{"key":"value"}`),
+ },
+ },
+ {
+ name: "minimal context",
+ ctx: HookContext{
+ HookName: "stop",
+ },
+ },
+ {
+ name: "empty context",
+ ctx: HookContext{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Just verify the struct can be created and accessed
+ assert.Equal(t, tt.ctx.HookName, tt.ctx.HookName)
+ assert.Equal(t, tt.ctx.Port, tt.ctx.Port)
+ assert.Equal(t, tt.ctx.Project, tt.ctx.Project)
+ })
+ }
+}
+
+// TestPOST_MarshalError tests POST with unmarshalable body.
+func TestPOST_MarshalError(t *testing.T) {
+ // Create a value that can't be marshaled
+ badValue := make(chan int)
+
+ _, err := POST(99999, "/test", badValue)
+ require.Error(t, err)
+}
+
+// TestPOST_Timeout tests POST with timeout.
+func TestPOST_Timeout(t *testing.T) {
+ // Try to connect to a port that's not listening
+ _, err := POST(99998, "/test", map[string]string{"key": "value"})
+ require.Error(t, err)
+}
+
+// TestGET_Timeout tests GET with timeout.
+func TestGET_Timeout(t *testing.T) {
+ // Try to connect to a port that's not listening
+ _, err := GET(99998, "/test")
+ require.Error(t, err)
+}
+
+// TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout.
+func TestIsWorkerRunning_Timeout(t *testing.T) {
+ // Non-existent port should quickly return false
+ start := time.Now()
+ result := IsWorkerRunning(99997)
+ elapsed := time.Since(start)
+
+ assert.False(t, result)
+ assert.Less(t, elapsed, 5*time.Second) // Should not hang
+}
+
+// TestIsPortInUse_Timeout tests IsPortInUse with timeout.
+func TestIsPortInUse_Timeout(t *testing.T) {
+ // Non-existent port should quickly return false
+ start := time.Now()
+ result := IsPortInUse(99996)
+ elapsed := time.Since(start)
+
+ assert.False(t, result)
+ assert.Less(t, elapsed, 2*time.Second) // Should not hang
+}
+
+// TestGetWorkerVersion_Timeout tests GetWorkerVersion with timeout.
+func TestGetWorkerVersion_Timeout(t *testing.T) {
+ // Non-existent port should quickly return empty
+ start := time.Now()
+ result := GetWorkerVersion(99995)
+ elapsed := time.Since(start)
+
+ assert.Empty(t, result)
+ assert.Less(t, elapsed, 5*time.Second) // Should not hang
+}
+
+// TestVersionsCompatible_EdgeCases tests versionsCompatible edge cases.
+func TestVersionsCompatible_EdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ v1 string
+ v2 string
+ expected bool
+ }{
+ {
+ name: "empty versions",
+ v1: "",
+ v2: "",
+ expected: true, // Same base (empty)
+ },
+ {
+ name: "one empty one dev",
+ v1: "",
+ v2: "dev",
+ expected: true, // dev is compatible with anything
+ },
+ {
+ name: "prerelease versions same base",
+ v1: "v1.0.0-alpha",
+ v2: "v1.0.0-beta",
+ expected: true, // Same base 1.0.0
+ },
+ {
+ name: "version with rc suffix",
+ v1: "v2.0.0-rc1",
+ v2: "v2.0.0",
+ expected: true, // Same base 2.0.0
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := versionsCompatible(tt.v1, tt.v2)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestExtractBaseVersion_EdgeCases tests extractBaseVersion edge cases.
+func TestExtractBaseVersion_EdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ version string
+ expected string
+ }{
+ {
+ name: "version starting with hyphen",
+ version: "-dirty",
+ expected: "-dirty", // hyphen at index 0 is not > 0, so no truncation
+ },
+ {
+ name: "just v",
+ version: "v",
+ expected: "",
+ },
+ {
+ name: "multiple hyphens",
+ version: "v1.0.0-alpha-beta-gamma",
+ expected: "1.0.0",
+ },
+ {
+ name: "no hyphen at all",
+ version: "v2.0.0",
+ expected: "2.0.0",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := extractBaseVersion(tt.version)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestProjectIDWithName_RelativePath tests ProjectIDWithName with relative paths.
+func TestProjectIDWithName_RelativePath(t *testing.T) {
+ // Relative paths should be converted to absolute
+ result := ProjectIDWithName(".")
+ assert.NotEmpty(t, result)
+ assert.Contains(t, result, "_")
+}
+
+// TestProjectIDWithName_DeepPath tests ProjectIDWithName with deep paths.
+func TestProjectIDWithName_DeepPath(t *testing.T) {
+ result := ProjectIDWithName("/a/very/deep/nested/path/to/project")
+ assert.Contains(t, result, "project_")
+ assert.NotEmpty(t, result)
+}
+
+// TestPOST_EmptyBody tests POST with empty body.
+func TestPOST_EmptyBody(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, http.MethodPost, r.Method)
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+ }))
+ defer server.Close()
+
+ var port int
+ _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+ require.NoError(t, err)
+
+ result, err := POST(port, "/test", map[string]string{})
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+}
+
+// TestGET_WithQueryParams tests GET with query parameters.
+func TestGET_WithQueryParams(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, http.MethodGet, r.Method)
+ assert.Equal(t, "/test?foo=bar", r.URL.String())
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+ }))
+ defer server.Close()
+
+ var port int
+ _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
+ require.NoError(t, err)
+
+ result, err := GET(port, "/test?foo=bar")
+ require.NoError(t, err)
+ assert.NotNil(t, result)
+}
+
+// TestHookResponse_RoundTrip tests JSON marshal/unmarshal round-trip.
+func TestHookResponse_RoundTrip(t *testing.T) {
+ original := HookResponse{Continue: true}
+
+ data, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ var decoded HookResponse
+ err = json.Unmarshal(data, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.Continue, decoded.Continue)
+}
+
+// TestBaseInput_RoundTrip tests BaseInput JSON round-trip.
+func TestBaseInput_RoundTrip(t *testing.T) {
+ original := BaseInput{
+ SessionID: "test-session",
+ CWD: "/home/user/project",
+ PermissionMode: "standard",
+ HookEventName: "session-start",
+ }
+
+ data, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ var decoded BaseInput
+ err = json.Unmarshal(data, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.SessionID, decoded.SessionID)
+ assert.Equal(t, original.CWD, decoded.CWD)
+ assert.Equal(t, original.PermissionMode, decoded.PermissionMode)
+ assert.Equal(t, original.HookEventName, decoded.HookEventName)
+}
+
+// TestHookContext_RawInput tests HookContext with different raw input types.
+func TestHookContext_RawInput(t *testing.T) {
+ tests := []struct {
+ name string
+ rawInput []byte
+ }{
+ {
+ name: "json object",
+ rawInput: []byte(`{"key":"value"}`),
+ },
+ {
+ name: "json array",
+ rawInput: []byte(`[1,2,3]`),
+ },
+ {
+ name: "empty object",
+ rawInput: []byte(`{}`),
+ },
+ {
+ name: "nil input",
+ rawInput: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := HookContext{
+ HookName: "test",
+ RawInput: tt.rawInput,
+ }
+ assert.Equal(t, tt.rawInput, ctx.RawInput)
+ })
+ }
+}
+
+// TestDefaultWorkerPort tests that the default port constant is valid.
+func TestDefaultWorkerPort(t *testing.T) {
+ assert.Greater(t, DefaultWorkerPort, 1024, "Default port should be above privileged port range")
+ assert.Less(t, DefaultWorkerPort, 65535, "Default port should be valid TCP port")
+}
+
+// TestHealthCheckTimeout tests the health check timeout is reasonable.
+func TestHealthCheckTimeout(t *testing.T) {
+ assert.Greater(t, HealthCheckTimeout, 100*time.Millisecond)
+ assert.Less(t, HealthCheckTimeout, 10*time.Second)
+}
+
+// TestStartupTimeout tests the startup timeout is reasonable.
+func TestStartupTimeout(t *testing.T) {
+ assert.Greater(t, StartupTimeout, 5*time.Second)
+ assert.LessOrEqual(t, StartupTimeout, time.Minute)
+}