From c259bb1d18ce7802ceabdcfeae05c5ab14d21e5e Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 17 Dec 2025 12:39:47 +0000 Subject: [PATCH] Increase test coverage to 45.6% --- internal/db/sqlite/helpers_test.go | 254 +++++++ internal/db/sqlite/migrations_test.go | 196 +++++ internal/db/sqlite/observation_test.go | 145 ++++ internal/embedding/service_test.go | 332 +++++++++ internal/mcp/server_test.go | 360 +++++++++ internal/search/integration_test.go | 646 +++++++++++++++++ internal/search/manager_test.go | 480 ++++++++++++ internal/vector/sqlitevec/client_test.go | 502 +++++++++++++ internal/vector/sqlitevec/helpers_test.go | 574 +++++++++++++++ internal/worker/handlers_test.go | 846 ++++++++++++++++++++++ internal/worker/sdk/processor_test.go | 381 ++++++++++ internal/worker/sdk/prompts_test.go | 276 +++++++ pkg/hooks/worker_test.go | 492 +++++++++++++ 13 files changed, 5484 insertions(+) create mode 100644 internal/db/sqlite/helpers_test.go create mode 100644 internal/db/sqlite/migrations_test.go create mode 100644 internal/embedding/service_test.go create mode 100644 internal/search/integration_test.go create mode 100644 internal/vector/sqlitevec/client_test.go create mode 100644 internal/vector/sqlitevec/helpers_test.go create mode 100644 internal/worker/sdk/prompts_test.go diff --git a/internal/db/sqlite/helpers_test.go b/internal/db/sqlite/helpers_test.go new file mode 100644 index 0000000..da7e99f --- /dev/null +++ b/internal/db/sqlite/helpers_test.go @@ -0,0 +1,254 @@ +package sqlite + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNullString(t *testing.T) { + tests := []struct { + name string + input string + expected string + valid bool + }{ + {"empty_string", "", "", false}, + {"non_empty_string", "hello", "hello", true}, + {"whitespace", " ", " ", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := nullString(tt.input) + assert.Equal(t, tt.expected, result.String) + assert.Equal(t, tt.valid, result.Valid) + }) + } +} + +func TestNullInt(t *testing.T) { + tests := []struct { + name string + input int + expected int64 + valid bool + }{ + {"zero", 0, 0, false}, + {"positive", 42, 42, true}, + {"negative", -1, -1, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := nullInt(tt.input) + assert.Equal(t, tt.expected, result.Int64) + assert.Equal(t, tt.valid, result.Valid) + }) + } +} + +func TestRepeatPlaceholders(t *testing.T) { + tests := []struct { + name string + n int + expected string + }{ + {"zero", 0, ""}, + {"negative", -1, ""}, + {"one", 1, ", ?"}, + {"two", 2, ", ?, ?"}, + {"three", 3, ", ?, ?, ?"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := repeatPlaceholders(tt.n) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestInt64SliceToInterface(t *testing.T) { + tests := []struct { + name string + input []int64 + expected []interface{} + }{ + {"empty", []int64{}, []interface{}{}}, + {"single", []int64{42}, []interface{}{int64(42)}}, + {"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := int64SliceToInterface(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseLimitParam(t *testing.T) { + tests := []struct { + name string + query string + defaultLimit int + expected int + }{ + {"no_param_uses_default", "", 10, 10}, + {"valid_limit", "limit=20", 10, 20}, + {"invalid_limit_uses_default", "limit=abc", 10, 10}, + {"zero_limit_uses_default", "limit=0", 10, 10}, + {"negative_limit_uses_default", "limit=-5", 10, 10}, + {"large_limit", "limit=1000", 10, 1000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/?"+tt.query, nil) + result := ParseLimitParam(req, tt.defaultLimit) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestScanSummary(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + createBaseTables(t, db) + seedSession(t, db, "claude-123", "sdk-123", "test-project") + + // Insert a test summary + _, err := db.Exec(` + INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch) + VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000) + `) + require.NoError(t, err) + + // Query and scan + row := db.QueryRow(` + SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch + FROM session_summaries WHERE sdk_session_id = ? + `, "sdk-123") + + summary, err := scanSummary(row) + require.NoError(t, err) + assert.NotNil(t, summary) + assert.Equal(t, "sdk-123", summary.SDKSessionID) + assert.Equal(t, "test-project", summary.Project) + assert.Equal(t, "test request", summary.Request.String) + assert.Equal(t, "test investigated", summary.Investigated.String) + assert.Equal(t, "test learned", summary.Learned.String) + assert.Equal(t, "test completed", summary.Completed.String) + assert.Equal(t, "test next steps", summary.NextSteps.String) + assert.Equal(t, "test notes", summary.Notes.String) +} + +func TestScanSummaryRows(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + createBaseTables(t, db) + seedSession(t, db, "claude-123", "sdk-123", "test-project") + + // Insert multiple summaries + _, err := db.Exec(` + INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch) + VALUES + ('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000), + ('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000) + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch + FROM session_summaries WHERE sdk_session_id = ? ORDER BY id + `, "sdk-123") + require.NoError(t, err) + defer rows.Close() + + summaries, err := scanSummaryRows(rows) + require.NoError(t, err) + assert.Len(t, summaries, 2) + assert.Equal(t, "request 1", summaries[0].Request.String) + assert.Equal(t, "request 2", summaries[1].Request.String) +} + +func TestScanPromptWithSession(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + createBaseTables(t, db) + seedSession(t, db, "claude-123", "sdk-123", "test-project") + + // Insert a test prompt + _, err := db.Exec(` + INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch) + VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000) + `) + require.NoError(t, err) + + // Query with session join + row := db.QueryRow(` + SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id + FROM user_prompts p + JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id + WHERE p.claude_session_id = ? + `, "claude-123") + + prompt, err := scanPromptWithSession(row) + require.NoError(t, err) + assert.NotNil(t, prompt) + assert.Equal(t, "claude-123", prompt.ClaudeSessionID) + assert.Equal(t, 1, prompt.PromptNumber) + assert.Equal(t, "test prompt", prompt.PromptText) + assert.Equal(t, 5, prompt.MatchedObservations) + assert.Equal(t, "test-project", prompt.Project) + assert.Equal(t, "sdk-123", prompt.SDKSessionID) +} + +func TestScanPromptWithSessionRows(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + createBaseTables(t, db) + seedSession(t, db, "claude-123", "sdk-123", "test-project") + + // Insert multiple prompts + _, err := db.Exec(` + INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch) + VALUES + ('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000), + ('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000) + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id + FROM user_prompts p + JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id + WHERE p.claude_session_id = ? ORDER BY p.id + `, "claude-123") + require.NoError(t, err) + defer rows.Close() + + prompts, err := scanPromptWithSessionRows(rows) + require.NoError(t, err) + assert.Len(t, prompts, 2) + assert.Equal(t, "prompt one", prompts[0].PromptText) + assert.Equal(t, "prompt two", prompts[1].PromptText) +} + +func TestParseLimitParam_HTTPRequest(t *testing.T) { + // Test with an actual HTTP request + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limit := ParseLimitParam(r, 25) + if limit != 50 { + t.Errorf("Expected limit 50, got %d", limit) + } + }) + + req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) +} diff --git a/internal/db/sqlite/migrations_test.go b/internal/db/sqlite/migrations_test.go new file mode 100644 index 0000000..129be43 --- /dev/null +++ b/internal/db/sqlite/migrations_test.go @@ -0,0 +1,196 @@ +package sqlite + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMigrationManager(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + require.NotNil(t, manager) + assert.Equal(t, db, manager.db) +} + +func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + + // Should create table without error + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Table should exist + var count int + err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count) // Empty table + + // Calling again should not error (IF NOT EXISTS) + err = manager.EnsureSchemaVersionsTable() + require.NoError(t, err) +} + +func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + versions, err := manager.GetAppliedVersions() + require.NoError(t, err) + assert.Empty(t, versions) +} + +func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Insert some versions + _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')") + require.NoError(t, err) + _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')") + require.NoError(t, err) + + versions, err := manager.GetAppliedVersions() + require.NoError(t, err) + assert.Len(t, versions, 2) + assert.True(t, versions[1]) + assert.True(t, versions[2]) + assert.False(t, versions[3]) // Not applied +} + +func TestMigrationManager_ApplyMigration(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Apply a simple migration + migration := Migration{ + Version: 100, + Name: "test_migration", + SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)", + } + + err = manager.ApplyMigration(migration) + require.NoError(t, err) + + // Verify table was created + var count int + err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) + + // Verify migration was recorded + var version int + err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version) + require.NoError(t, err) + assert.Equal(t, 100, version) +} + +func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Try to apply invalid migration + migration := Migration{ + Version: 100, + Name: "invalid_migration", + SQL: "INVALID SQL SYNTAX", + } + + err = manager.ApplyMigration(migration) + assert.Error(t, err) + assert.Contains(t, err.Error(), "execute migration 100") +} + +func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + // Create a test migration manager with a subset of migrations + manager := NewMigrationManager(db) + + // First ensure schema versions table exists + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Apply first migration manually + err = manager.ApplyMigration(Migrations[0]) + require.NoError(t, err) + + // Verify the first migration version was recorded + versions, err := manager.GetAppliedVersions() + require.NoError(t, err) + assert.True(t, versions[Migrations[0].Version]) +} + +func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + + manager := NewMigrationManager(db) + err := manager.EnsureSchemaVersionsTable() + require.NoError(t, err) + + // Mark some migrations as already applied + _, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')") + require.NoError(t, err) + + // Get applied versions + versions, err := manager.GetAppliedVersions() + require.NoError(t, err) + assert.True(t, versions[4]) +} + +func TestMigration_Struct(t *testing.T) { + m := Migration{ + Version: 1, + Name: "test", + SQL: "SELECT 1", + } + + assert.Equal(t, 1, m.Version) + assert.Equal(t, "test", m.Name) + assert.Equal(t, "SELECT 1", m.SQL) +} + +func TestMigrations_List(t *testing.T) { + // Verify migrations are ordered correctly + assert.NotEmpty(t, Migrations) + + // Verify all migrations have required fields + for i, m := range Migrations { + assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i) + assert.NotEmpty(t, m.Name, "Migration %d has empty name", i) + assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i) + } + + // Verify key migrations exist + versionSet := make(map[int]bool) + for _, m := range Migrations { + versionSet[m.Version] = true + } + + assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration") + assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration") +} diff --git a/internal/db/sqlite/observation_test.go b/internal/db/sqlite/observation_test.go index 20ddf48..6b4972c 100644 --- a/internal/db/sqlite/observation_test.go +++ b/internal/db/sqlite/observation_test.go @@ -800,3 +800,148 @@ func TestExtractKeywords(t *testing.T) { }) } } + +func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) { + obsStore, _, cleanup := testObservationStore(t) + defer cleanup() + + ctx := context.Background() + + // Create project-scoped observation for project-a + projectObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Project A specific", + Narrative: "Only for project-a", + Concepts: []string{"local-concept"}, + } + _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100) + require.NoError(t, err) + + // Create global observation from project-a + globalObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Global security practice", + Narrative: "Best practice for all", + Concepts: []string{"security", "best-practice"}, + } + _, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100) + require.NoError(t, err) + + // Create observation for project-b + projectBObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Project B specific", + Narrative: "Only for project-b", + } + _, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100) + require.NoError(t, err) + + // GetObservationsByProjectStrict for project-a should only return project-a observations + // This is different from GetRecentObservations which includes globals from other projects + results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10) + require.NoError(t, err) + assert.Len(t, results, 2) // Only observations created in project-a + + // Verify both are from project-a + for _, obs := range results { + assert.Equal(t, "project-a", obs.Project) + } + + // GetObservationsByProjectStrict for project-b should only return project-b observations + results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10) + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "Project B specific", results[0].Title.String) +} + +func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) { + obsStore, _, cleanup := testObservationStore(t) + defer cleanup() + + ctx := context.Background() + + // Create an observation + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Test observation", + Narrative: "Some content here", + } + _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100) + require.NoError(t, err) + + // Search with only stop words (should return nil) + results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10) + require.NoError(t, err) + assert.Nil(t, results) + + // Search with empty query + results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10) + require.NoError(t, err) + assert.Nil(t, results) +} + +func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) { + obsStore, _, cleanup := testObservationStore(t) + defer cleanup() + + ctx := context.Background() + + // Create observations + for i := 0; i < 15; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Authentication test " + string(rune('A'+i)), + Narrative: "Auth related content", + } + _, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100) + require.NoError(t, err) + time.Sleep(time.Millisecond) + } + + // Search with limit 0 (should default to 10) + results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0) + require.NoError(t, err) + assert.LessOrEqual(t, len(results), 10) + + // Search with negative limit (should default to 10) + results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5) + require.NoError(t, err) + assert.LessOrEqual(t, len(results), 10) +} + +func TestObservationStore_GetAllRecentObservations(t *testing.T) { + obsStore, _, cleanup := testObservationStore(t) + defer cleanup() + + ctx := context.Background() + + // Create observations across different projects + projects := []string{"project-a", "project-b", "project-c"} + for _, proj := range projects { + for i := 0; i < 3; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: proj + " observation " + string(rune('A'+i)), + Narrative: "Content for " + proj, + } + _, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100) + require.NoError(t, err) + time.Sleep(time.Millisecond) + } + } + + // Get all recent observations + results, err := obsStore.GetAllRecentObservations(ctx, 100) + require.NoError(t, err) + assert.Len(t, results, 9) // 3 projects * 3 observations + + // Verify they are in descending order by epoch + for i := 1; i < len(results); i++ { + assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch) + } + + // Test with limit + results, err = obsStore.GetAllRecentObservations(ctx, 5) + require.NoError(t, err) + assert.Len(t, results, 5) +} diff --git a/internal/embedding/service_test.go b/internal/embedding/service_test.go new file mode 100644 index 0000000..cd04351 --- /dev/null +++ b/internal/embedding/service_test.go @@ -0,0 +1,332 @@ +package embedding + +import ( + "math" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEmbeddingDim verifies the embedding dimension constant. +func TestEmbeddingDim(t *testing.T) { + assert.Equal(t, 384, EmbeddingDim) +} + +// TestNewService tests creating a new embedding service. +func TestNewService(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + require.NotNil(t, svc) + + defer svc.Close() + + assert.NotNil(t, svc.tk) + assert.NotNil(t, svc.session) +} + +// TestEmbed_SingleText tests embedding a single text. +func TestEmbed_SingleText(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + embedding, err := svc.Embed("Hello, world!") + require.NoError(t, err) + + assert.Len(t, embedding, EmbeddingDim) + + // Verify non-zero embedding + var sum float32 + for _, v := range embedding { + sum += v * v + } + assert.Greater(t, sum, float32(0), "Embedding should not be all zeros") +} + +// TestEmbed_EmptyText tests embedding an empty string. +func TestEmbed_EmptyText(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + embedding, err := svc.Embed("") + require.NoError(t, err) + + assert.Len(t, embedding, EmbeddingDim) + + // Empty text should return zero vector + for _, v := range embedding { + assert.Equal(t, float32(0), v) + } +} + +// TestEmbed_SimilarTexts tests that similar texts produce similar embeddings. +func TestEmbed_SimilarTexts(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + emb1, err := svc.Embed("The quick brown fox jumps over the lazy dog.") + require.NoError(t, err) + + emb2, err := svc.Embed("A fast brown fox leaps over a sleepy dog.") + require.NoError(t, err) + + emb3, err := svc.Embed("Go programming language concurrency patterns.") + require.NoError(t, err) + + // Calculate cosine similarity + sim12 := cosineSimilarity(emb1, emb2) + sim13 := cosineSimilarity(emb1, emb3) + + // Similar texts should have higher similarity + assert.Greater(t, sim12, sim13, "Similar sentences should have higher similarity than dissimilar ones") + assert.Greater(t, sim12, float64(0.7), "Similar sentences should have high similarity") +} + +// TestEmbedBatch_MultipleTexts tests batch embedding. +func TestEmbedBatch_MultipleTexts(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + texts := []string{ + "First text about programming.", + "Second text about databases.", + "Third text about machine learning.", + } + + embeddings, err := svc.EmbedBatch(texts) + require.NoError(t, err) + + assert.Len(t, embeddings, len(texts)) + for i, emb := range embeddings { + assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i) + } +} + +// TestEmbedBatch_EmptySlice tests batch embedding with empty slice. +func TestEmbedBatch_EmptySlice(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + embeddings, err := svc.EmbedBatch([]string{}) + require.NoError(t, err) + assert.Nil(t, embeddings) +} + +// TestEmbedBatch_WithEmptyTexts tests batch embedding with some empty texts. +func TestEmbedBatch_WithEmptyTexts(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + texts := []string{ + "Valid text one.", + "", + "Valid text two.", + "", + } + + embeddings, err := svc.EmbedBatch(texts) + require.NoError(t, err) + + assert.Len(t, embeddings, 4) + + // Non-empty texts should have non-zero embeddings + var sum0 float32 + for _, v := range embeddings[0] { + sum0 += v * v + } + assert.Greater(t, sum0, float32(0)) + + // Empty texts should have zero embeddings + for _, v := range embeddings[1] { + assert.Equal(t, float32(0), v) + } + + var sum2 float32 + for _, v := range embeddings[2] { + sum2 += v * v + } + assert.Greater(t, sum2, float32(0)) + + for _, v := range embeddings[3] { + assert.Equal(t, float32(0), v) + } +} + +// TestEmbedBatch_AllEmpty tests batch embedding with all empty texts. +func TestEmbedBatch_AllEmpty(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + texts := []string{"", "", ""} + + embeddings, err := svc.EmbedBatch(texts) + require.NoError(t, err) + + assert.Len(t, embeddings, 3) + for i, emb := range embeddings { + assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i) + for j, v := range emb { + assert.Equal(t, float32(0), v, "Embedding %d[%d] should be zero", i, j) + } + } +} + +// TestEmbed_Concurrent tests concurrent embedding calls. +func TestEmbed_Concurrent(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + var wg sync.WaitGroup + numGoroutines := 10 + + errors := make(chan error, numGoroutines) + embeddings := make([][]float32, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + text := "Test text for concurrent embedding test" + emb, err := svc.Embed(text) + if err != nil { + errors <- err + return + } + embeddings[idx] = emb + }(i) + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Concurrent embedding error: %v", err) + } + + // All embeddings should be valid + for i, emb := range embeddings { + if emb != nil { + assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i) + } + } +} + +// TestEmbed_SpecialCharacters tests embedding text with special characters. +func TestEmbed_SpecialCharacters(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + texts := []string{ + "Text with unicode: 你好世界 🎉", + "Text with newlines:\nLine 1\nLine 2", + "Text with tabs:\tColumn1\tColumn2", + "Text with quotes: \"quoted\" and 'single'", + "Text with code: func main() { fmt.Println(\"hello\") }", + } + + for _, text := range texts { + t.Run(text[:20], func(t *testing.T) { + emb, err := svc.Embed(text) + require.NoError(t, err) + assert.Len(t, emb, EmbeddingDim) + }) + } +} + +// TestEmbed_LongText tests embedding long text. +func TestEmbed_LongText(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + // Create a long text (tokenizer should truncate appropriately) + longText := "" + for i := 0; i < 100; i++ { + longText += "This is a sentence to make the text very long. " + } + + emb, err := svc.Embed(longText) + require.NoError(t, err) + assert.Len(t, emb, EmbeddingDim) +} + +// TestClose tests closing the service. +func TestClose(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + + err = svc.Close() + require.NoError(t, err) + + // Session should be nil after close + assert.Nil(t, svc.session) +} + +// TestEmbedBatch_SingleItem tests batch embedding with single item. +func TestEmbedBatch_SingleItem(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + texts := []string{"Single text for batch embedding."} + + embeddings, err := svc.EmbedBatch(texts) + require.NoError(t, err) + + assert.Len(t, embeddings, 1) + assert.Len(t, embeddings[0], EmbeddingDim) +} + +// TestEmbed_Deterministic tests that embedding is deterministic. +func TestEmbed_Deterministic(t *testing.T) { + svc, err := NewService() + require.NoError(t, err) + defer svc.Close() + + text := "Test text for deterministic embedding." + + emb1, err := svc.Embed(text) + require.NoError(t, err) + + emb2, err := svc.Embed(text) + require.NoError(t, err) + + // Same text should produce same embedding + for i := 0; i < EmbeddingDim; i++ { + assert.Equal(t, emb1[i], emb2[i], "Embedding should be deterministic at index %d", i) + } +} + +// Helper function to calculate cosine similarity +func cosineSimilarity(a, b []float32) float64 { + if len(a) != len(b) { + return 0 + } + + var dotProduct float64 + var normA float64 + var normB float64 + + for i := range a { + dotProduct += float64(a[i]) * float64(b[i]) + normA += float64(a[i]) * float64(a[i]) + normB += float64(b[i]) * float64(b[i]) + } + + if normA == 0 || normB == 0 { + return 0 + } + + return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 1b4294d..6afce82 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -597,3 +597,363 @@ func TestToolListContainsExpectedSchemas(t *testing.T) { assert.True(t, hasType, "tool %s schema should have type", tool.Name) } } + +// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name. +func TestHandleToolsCall_UnknownTool(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`), + } + + resp := server.handleToolsCall(ctx, req) + require.NotNil(t, resp.Error) + assert.Equal(t, -32000, resp.Error.Code) + assert.Contains(t, resp.Error.Data, "unknown tool") +} + +// TestCallTool_ToolNameRecognition tests that valid tool names are recognized (not "unknown tool"). +func TestCallTool_ToolNameRecognition(t *testing.T) { + // Note: This test verifies tool routing logic, not execution (which requires searchMgr) + // All valid tool names should be in the handleToolsList response + server := NewServer(nil, "1.0.0") + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/list", + } + + resp := server.handleToolsList(req) + result := resp.Result.(map[string]any) + tools := result["tools"].([]Tool) + + // Verify all expected tools are registered + expectedTools := map[string]bool{ + "search": true, + "timeline": true, + "decisions": true, + "changes": true, + "how_it_works": true, + "find_by_concept": true, + "find_by_file": true, + "find_by_type": true, + "get_recent_context": true, + "get_context_timeline": true, + "get_timeline_by_query": true, + } + + foundTools := make(map[string]bool) + for _, tool := range tools { + foundTools[tool.Name] = true + } + + for name := range expectedTools { + assert.True(t, foundTools[name], "tool %s should be registered", name) + } +} + +// TestRun_MultipleRequests tests Run with multiple sequential requests. +func TestRun_MultipleRequests(t *testing.T) { + var stdout bytes.Buffer + req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` + stdin := strings.NewReader(req1 + "\n" + req2 + "\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + version: "1.0.0", + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + // Should contain responses for both requests + assert.Contains(t, output, `"id":1`) + assert.Contains(t, output, `"id":2`) +} + +// TestHandleTimeline_Defaults tests timeline default values. +func TestHandleTimeline_Defaults(t *testing.T) { + // Test that handleTimeline sets default before/after values + params := TimelineParams{ + AnchorID: 0, + Query: "", + Before: 0, + After: 0, + } + + // Simulate the default value assignment from handleTimeline + if params.Before <= 0 { + params.Before = 10 + } + if params.After <= 0 { + params.After = 10 + } + + assert.Equal(t, 10, params.Before) + assert.Equal(t, 10, params.After) +} + +// TestTimelineParams_Complete tests complete TimelineParams parsing. +func TestTimelineParams_Complete(t *testing.T) { + input := `{ + "anchor_id": 100, + "query": "test query", + "before": 5, + "after": 15, + "project": "my-project", + "obs_type": "bugfix", + "concepts": "security,auth", + "files": "main.go,handler.go", + "dateStart": 1700000000000, + "dateEnd": 1700100000000, + "format": "full" + }` + + var params TimelineParams + err := json.Unmarshal([]byte(input), ¶ms) + require.NoError(t, err) + + assert.Equal(t, int64(100), params.AnchorID) + assert.Equal(t, "test query", params.Query) + assert.Equal(t, 5, params.Before) + assert.Equal(t, 15, params.After) + assert.Equal(t, "my-project", params.Project) + assert.Equal(t, "bugfix", params.ObsType) + assert.Equal(t, "security,auth", params.Concepts) + assert.Equal(t, "main.go,handler.go", params.Files) + assert.Equal(t, int64(1700000000000), params.DateStart) + assert.Equal(t, int64(1700100000000), params.DateEnd) + assert.Equal(t, "full", params.Format) +} + +// TestServerStdinStdoutConfig tests that server stdin/stdout can be configured. +func TestServerStdinStdoutConfig(t *testing.T) { + var stdout bytes.Buffer + var stdin bytes.Buffer + + server := &Server{ + stdin: &stdin, + stdout: &stdout, + version: "test-version", + } + + assert.Equal(t, &stdin, server.stdin) + assert.Equal(t, &stdout, server.stdout) + assert.Equal(t, "test-version", server.version) +} + +// TestResponseIDTypes tests that response IDs can be various types. +func TestResponseIDTypes(t *testing.T) { + tests := []struct { + name string + id any + }{ + {"integer id", 1}, + {"string id", "abc-123"}, + {"float id", 1.5}, + {"null id", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + server := &Server{stdout: &buf} + + resp := &Response{ + JSONRPC: "2.0", + ID: tt.id, + Result: "ok", + } + + server.sendResponse(resp) + output := buf.String() + assert.Contains(t, output, `"jsonrpc":"2.0"`) + }) + } +} + +// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query. +func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + // Empty query should error + _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "query is required") +} + +// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON. +func TestHandleTimeline_InvalidJSON(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid timeline params") +} + +// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. +func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid timeline params") +} + +// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query. +func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + // No anchor_id and no query should return empty result + result, err := server.handleTimeline(ctx, json.RawMessage(`{}`)) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result.Results) +} + +// TestHandleTimeline_WithDefaults tests timeline default values are applied. +func TestHandleTimeline_WithDefaults(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + // With anchor_id but no before/after, defaults should be applied + // However, since searchMgr is nil, this will fail after defaults are applied + result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`)) + // Should return empty result since anchor_id is 0 + require.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result.Results) +} + +// TestServerFields tests Server struct fields. +func TestServerFields(t *testing.T) { + server := NewServer(nil, "2.0.0") + + assert.Equal(t, "2.0.0", server.version) + assert.Nil(t, server.searchMgr) + assert.NotNil(t, server.stdin) + assert.NotNil(t, server.stdout) +} + +// TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID. +func TestRequestUnmarshalWithNullID(t *testing.T) { + input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}` + + var req Request + err := json.Unmarshal([]byte(input), &req) + require.NoError(t, err) + assert.Equal(t, "2.0", req.JSONRPC) + assert.Nil(t, req.ID) + assert.Equal(t, "initialize", req.Method) +} + +// TestResponseWithNullError tests Response without error. +func TestResponseWithNullError(t *testing.T) { + resp := Response{ + JSONRPC: "2.0", + ID: 1, + Result: "success", + Error: nil, + } + + data, err := json.Marshal(resp) + require.NoError(t, err) + assert.Contains(t, string(data), `"result":"success"`) + assert.NotContains(t, string(data), `"error"`) +} + +// TestErrorWithNilData tests Error without data. +func TestErrorWithNilData(t *testing.T) { + err := Error{ + Code: -32600, + Message: "Invalid Request", + Data: nil, + } + + data, errMarshal := json.Marshal(err) + require.NoError(t, errMarshal) + assert.Contains(t, string(data), `"code":-32600`) + assert.Contains(t, string(data), `"message":"Invalid Request"`) + assert.NotContains(t, string(data), `"data"`) +} + +// TestToolInputSchema tests that tool input schemas have required fields. +func TestToolInputSchema(t *testing.T) { + server := NewServer(nil, "1.0.0") + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/list", + } + + resp := server.handleToolsList(req) + result := resp.Result.(map[string]any) + tools := result["tools"].([]Tool) + + for _, tool := range tools { + schema := tool.InputSchema + schemaType, ok := schema["type"] + assert.True(t, ok, "tool %s schema should have type", tool.Name) + assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name) + + // All tools should have properties + _, hasProperties := schema["properties"] + assert.True(t, hasProperties, "tool %s should have properties", tool.Name) + } +} + +// TestRunMixedRequests tests Run with mixed valid and invalid requests. +func TestRunMixedRequests(t *testing.T) { + var stdout bytes.Buffer + req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + req2 := `invalid json` + req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}` + stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + version: "1.0.0", + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + // Should have responses for all three requests + assert.Contains(t, output, `"id":1`) + assert.Contains(t, output, `"error"`) // Parse error for invalid json + assert.Contains(t, output, `"id":3`) +} + +// TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments. +func TestToolCallParamsWithComplexArgs(t *testing.T) { + input := `{ + "name": "search", + "arguments": { + "query": "authentication bug", + "project": "my-project", + "limit": 10, + "type": "observations" + } + }` + + var params ToolCallParams + err := json.Unmarshal([]byte(input), ¶ms) + require.NoError(t, err) + assert.Equal(t, "search", params.Name) + assert.NotEmpty(t, params.Arguments) +} diff --git a/internal/search/integration_test.go b/internal/search/integration_test.go new file mode 100644 index 0000000..a1e0358 --- /dev/null +++ b/internal/search/integration_test.go @@ -0,0 +1,646 @@ +// Package search provides unified search capabilities for claude-mnemonic. +package search + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + // Import sqlite driver + _ "github.com/mattn/go-sqlite3" +) + +// Ensure context is used (for later tests) +var _ = context.Background + +// hasFTS5 checks if FTS5 is available in the SQLite build. +func hasFTS5(t *testing.T) bool { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "fts5-check-*") + if err != nil { + return false + } + defer func() { _ = os.RemoveAll(tmpDir) }() + + dbPath := tmpDir + "/check.db" + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return false + } + defer func() { _ = db.Close() }() + + _, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)") + if err != nil { + return false + } + _, _ = db.Exec("DROP TABLE IF EXISTS fts5_test") + return true +} + +// testStore creates a sqlite.Store with a temporary database for testing. +func testStore(t *testing.T) (*sqlite.Store, func()) { + t.Helper() + + if !hasFTS5(t) { + t.Skip("FTS5 not available in this SQLite build") + } + + tmpDir, err := os.MkdirTemp("", "search-integration-test-*") + require.NoError(t, err) + + dbPath := tmpDir + "/test.db" + + store, err := sqlite.NewStore(sqlite.StoreConfig{ + Path: dbPath, + MaxConns: 1, + WALMode: true, + }) + require.NoError(t, err) + + cleanup := func() { + _ = store.Close() + _ = os.RemoveAll(tmpDir) + } + + return store, cleanup +} + +// SearchIntegrationSuite tests search with real SQLite stores. +type SearchIntegrationSuite struct { + suite.Suite + store *sqlite.Store + cleanup func() + manager *Manager + obsStore *sqlite.ObservationStore + sumStore *sqlite.SummaryStore + prmStore *sqlite.PromptStore +} + +func (s *SearchIntegrationSuite) SetupTest() { + if !hasFTS5(s.T()) { + s.T().Skip("FTS5 not available in this SQLite build") + } + + s.store, s.cleanup = testStore(s.T()) + + // Create real stores backed by SQLite + s.obsStore = sqlite.NewObservationStore(s.store) + s.sumStore = sqlite.NewSummaryStore(s.store) + s.prmStore = sqlite.NewPromptStore(s.store) + + // Create search manager with real stores (no vector client for now) + s.manager = NewManager(s.obsStore, s.sumStore, s.prmStore, nil) +} + +func (s *SearchIntegrationSuite) TearDownTest() { + if s.cleanup != nil { + s.cleanup() + } +} + +func TestSearchIntegrationSuite(t *testing.T) { + suite.Run(t, new(SearchIntegrationSuite)) +} + +// seedObservations inserts test observations into the database. +func (s *SearchIntegrationSuite) seedObservations(ctx context.Context) []int64 { + var ids []int64 + + // Observation 1: Authentication bug fix + obs1 := &models.ParsedObservation{ + Type: models.ObsTypeBugfix, + Scope: models.ScopeProject, + Title: "Fixed Authentication Bug", + Narrative: "Resolved JWT token validation issue that caused intermittent login failures", + Concepts: []string{"authentication", "jwt", "security"}, + FilesRead: []string{"auth/handler.go", "auth/jwt.go"}, + } + id1, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs1, 1, 100) + s.Require().NoError(err) + ids = append(ids, id1) + + // Observation 2: Database optimization decision + obs2 := &models.ParsedObservation{ + Type: models.ObsTypeDecision, + Scope: models.ScopeProject, + Title: "Database Query Optimization Decision", + Narrative: "Decided to add indexes on user_id and created_at columns for better performance", + Concepts: []string{"database", "performance", "decision"}, + FilesRead: []string{"db/migrations/001.sql"}, + } + id2, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs2, 2, 150) + s.Require().NoError(err) + ids = append(ids, id2) + + // Observation 3: Global best practice + obs3 := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Scope: models.ScopeGlobal, + Title: "Error Handling Best Practice", + Narrative: "Use wrapped errors with context for better debugging: errors.Wrap(err, context)", + Concepts: []string{"best-practice", "errors", "patterns"}, + FilesRead: []string{"pkg/errors/errors.go"}, + } + id3, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-2", "other-project", obs3, 1, 80) + s.Require().NoError(err) + ids = append(ids, id3) + + // Observation 4: Code change/refactoring + obs4 := &models.ParsedObservation{ + Type: models.ObsTypeChange, + Scope: models.ScopeProject, + Title: "Refactored User Service", + Narrative: "Changed user service to use repository pattern, modified interfaces for better testability", + Concepts: []string{"refactoring", "architecture"}, + FilesModified: []string{"services/user.go", "services/user_test.go"}, + } + id4, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs4, 3, 200) + s.Require().NoError(err) + ids = append(ids, id4) + + return ids +} + +// seedSummaries inserts test session summaries into the database. +func (s *SearchIntegrationSuite) seedSummaries(ctx context.Context) []int64 { + var ids []int64 + + // Summary 1 + sum1 := &models.ParsedSummary{ + Request: "Fix authentication bug", + Investigated: "JWT token validation and session handling", + Learned: "JWT validation requires algorithm check to prevent alg:none attacks", + Completed: "Fixed JWT validation, added tests", + NextSteps: "Review other security endpoints", + } + id1, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum1, 1, 100) + s.Require().NoError(err) + ids = append(ids, id1) + + // Summary 2 + sum2 := &models.ParsedSummary{ + Request: "Optimize database queries", + Investigated: "Query execution plans and index usage", + Learned: "Composite indexes work better for range queries", + Completed: "Added indexes, verified performance improvement", + NextSteps: "Monitor query times in production", + } + id2, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum2, 2, 150) + s.Require().NoError(err) + ids = append(ids, id2) + + return ids +} + +// TestFilterSearch_WithRealStores tests filterSearch with seeded data. +func (s *SearchIntegrationSuite) TestFilterSearch_WithRealStores() { + ctx := context.Background() + + // Seed test data + obsIDs := s.seedObservations(ctx) + sumIDs := s.seedSummaries(ctx) + s.Require().Len(obsIDs, 4) + s.Require().Len(sumIDs, 2) + + // Test filter search for observations only + result, err := s.manager.filterSearch(ctx, SearchParams{ + Project: "test-project", + Type: "observations", + Limit: 10, + Format: "full", + }) + s.Require().NoError(err) + s.NotNil(result) + + // Should return project observations + global observation (4 total: 3 project + 1 global) + s.GreaterOrEqual(len(result.Results), 3) + + // Verify result types + for _, r := range result.Results { + s.Equal("observation", r.Type) + } +} + +// TestFilterSearch_SessionsOnly tests filterSearch for sessions. +func (s *SearchIntegrationSuite) TestFilterSearch_SessionsOnly() { + ctx := context.Background() + + // Seed test data + _ = s.seedObservations(ctx) + sumIDs := s.seedSummaries(ctx) + s.Require().Len(sumIDs, 2) + + // Test filter search for sessions only + result, err := s.manager.filterSearch(ctx, SearchParams{ + Project: "test-project", + Type: "sessions", + Limit: 10, + Format: "full", + }) + s.Require().NoError(err) + s.NotNil(result) + + // Should return 2 summaries + s.Len(result.Results, 2) + + // Verify result types + for _, r := range result.Results { + s.Equal("session", r.Type) + s.NotEmpty(r.Title) // Title should be populated from Request + } +} + +// TestFilterSearch_AllTypes tests filterSearch for all types. +func (s *SearchIntegrationSuite) TestFilterSearch_AllTypes() { + ctx := context.Background() + + // Seed test data + obsIDs := s.seedObservations(ctx) + sumIDs := s.seedSummaries(ctx) + s.Require().Len(obsIDs, 4) + s.Require().Len(sumIDs, 2) + + // Test filter search for all types (Type = "") + result, err := s.manager.filterSearch(ctx, SearchParams{ + Project: "test-project", + Type: "", // All types + Limit: 20, + Format: "full", + }) + s.Require().NoError(err) + s.NotNil(result) + + // Should return both observations and sessions + hasObservations := false + hasSessions := false + for _, r := range result.Results { + if r.Type == "observation" { + hasObservations = true + } + if r.Type == "session" { + hasSessions = true + } + } + s.True(hasObservations, "Should have observation results") + s.True(hasSessions, "Should have session results") +} + +// TestUnifiedSearch_DefaultLimit tests UnifiedSearch with default limit. +func (s *SearchIntegrationSuite) TestUnifiedSearch_DefaultLimit() { + ctx := context.Background() + + // Seed test data + s.seedObservations(ctx) + s.seedSummaries(ctx) + + // Test with no limit specified (should default to 20) + result, err := s.manager.UnifiedSearch(ctx, SearchParams{ + Project: "test-project", + }) + s.Require().NoError(err) + s.NotNil(result) + s.LessOrEqual(len(result.Results), 20) +} + +// TestUnifiedSearch_LimitCapping tests UnifiedSearch limit capping. +func (s *SearchIntegrationSuite) TestUnifiedSearch_LimitCapping() { + ctx := context.Background() + + // Seed test data + s.seedObservations(ctx) + s.seedSummaries(ctx) + + // Test with limit > 100 (should be capped to 100) + result, err := s.manager.UnifiedSearch(ctx, SearchParams{ + Project: "test-project", + Limit: 500, + }) + s.Require().NoError(err) + s.NotNil(result) + s.LessOrEqual(len(result.Results), 100) +} + +// TestDecisions_WithRealStores tests the Decisions method falls back to filterSearch. +func (s *SearchIntegrationSuite) TestDecisions_WithRealStores() { + ctx := context.Background() + + // Seed test data + s.seedObservations(ctx) + + // Test Decisions search (without vector client, falls back to filterSearch) + result, err := s.manager.Decisions(ctx, SearchParams{ + Project: "test-project", + Query: "database", + Limit: 10, + }) + s.Require().NoError(err) + s.NotNil(result) + + // Without vector client, falls back to filterSearch which returns observations + // All results should be observations (type is forced to "observations" in Decisions) + for _, r := range result.Results { + s.Equal("observation", r.Type) + } +} + +// TestChanges_WithRealStores tests the Changes method falls back to filterSearch. +func (s *SearchIntegrationSuite) TestChanges_WithRealStores() { + ctx := context.Background() + + // Seed test data + s.seedObservations(ctx) + + // Test Changes search (without vector client, falls back to filterSearch) + result, err := s.manager.Changes(ctx, SearchParams{ + Project: "test-project", + Query: "user service", + Limit: 10, + }) + s.Require().NoError(err) + s.NotNil(result) + + // Without vector client, falls back to filterSearch which returns observations + // All results should be observations (type is forced to "observations" in Changes) + for _, r := range result.Results { + s.Equal("observation", r.Type) + } +} + +// TestHowItWorks_WithRealStores tests the HowItWorks method falls back to filterSearch. +func (s *SearchIntegrationSuite) TestHowItWorks_WithRealStores() { + ctx := context.Background() + + // Seed test data + s.seedObservations(ctx) + + // Test HowItWorks search (without vector client, falls back to filterSearch) + result, err := s.manager.HowItWorks(ctx, SearchParams{ + Project: "test-project", + Query: "authentication", + Limit: 10, + }) + s.Require().NoError(err) + s.NotNil(result) + + // Without vector client, falls back to filterSearch which returns observations + // All results should be observations (type is forced to "observations" in HowItWorks) + for _, r := range result.Results { + s.Equal("observation", r.Type) + } +} + +// TestObservationToResult tests observation to result conversion with full format. +func (s *SearchIntegrationSuite) TestObservationToResult_FullFormat() { + ctx := context.Background() + + // Insert single observation + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Scope: models.ScopeProject, + Title: "Test Title", + Narrative: "Detailed narrative content for testing", + Concepts: []string{"testing", "content"}, + } + id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50) + s.Require().NoError(err) + + // Retrieve and convert + retrieved, err := s.obsStore.GetObservationByID(ctx, id) + s.Require().NoError(err) + s.Require().NotNil(retrieved) + + result := s.manager.observationToResult(retrieved, "full") + + s.Equal("observation", result.Type) + s.Equal(id, result.ID) + s.Equal("Test Title", result.Title) + s.Equal("Detailed narrative content for testing", result.Content) + s.Equal("test-project", result.Project) + s.Equal("project", result.Scope) + s.NotNil(result.Metadata) + s.Equal("discovery", result.Metadata["obs_type"]) +} + +// TestObservationToResult_IndexFormat tests index format (no content). +func (s *SearchIntegrationSuite) TestObservationToResult_IndexFormat() { + ctx := context.Background() + + obs := &models.ParsedObservation{ + Type: models.ObsTypeBugfix, + Scope: models.ScopeGlobal, + Title: "Bug Fix Title", + Narrative: "This should not appear in index format", + } + id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50) + s.Require().NoError(err) + + retrieved, err := s.obsStore.GetObservationByID(ctx, id) + s.Require().NoError(err) + + result := s.manager.observationToResult(retrieved, "index") + + s.Equal("observation", result.Type) + s.Equal("Bug Fix Title", result.Title) + s.Empty(result.Content, "Index format should not include content") + s.Equal("global", result.Scope) +} + +// TestSummaryToResult_FullFormat tests summary to result conversion. +func (s *SearchIntegrationSuite) TestSummaryToResult_FullFormat() { + ctx := context.Background() + + sum := &models.ParsedSummary{ + Request: "Implement new feature", + Learned: "Learned important lessons about testing", + } + id, _, err := s.sumStore.StoreSummary(ctx, "sdk-test", "test-project", sum, 1, 50) + s.Require().NoError(err) + + // Retrieve via GetRecentSummaries since there's no GetByID + summaries, err := s.sumStore.GetRecentSummaries(ctx, "test-project", 10) + s.Require().NoError(err) + s.Require().NotEmpty(summaries) + + var retrieved *models.SessionSummary + for _, s := range summaries { + if s.ID == id { + retrieved = s + break + } + } + s.Require().NotNil(retrieved) + + result := s.manager.summaryToResult(retrieved, "full") + + s.Equal("session", result.Type) + s.Equal(id, result.ID) + s.Contains(result.Title, "Implement new feature") + s.Equal("Learned important lessons about testing", result.Content) + s.Equal("test-project", result.Project) +} + +// TestPromptToResult_FullFormat tests prompt to result conversion. +func (s *SearchIntegrationSuite) TestPromptToResult_FullFormat() { + // First create a session + ctx := context.Background() + sessionStore := sqlite.NewSessionStore(s.store) + _, err := sessionStore.CreateSDKSession(ctx, "sdk-prompt-test", "test-project", "initial prompt") + s.Require().NoError(err) + + // Save a user prompt + promptID, err := s.prmStore.SaveUserPromptWithMatches(ctx, "sdk-prompt-test", 1, "Help me fix this authentication bug", 3) + s.Require().NoError(err) + + // Retrieve prompts + prompts, err := s.prmStore.GetPromptsByIDs(ctx, []int64{promptID}, "date_desc", 10) + s.Require().NoError(err) + s.Require().NotEmpty(prompts) + + result := s.manager.promptToResult(prompts[0], "full") + + s.Equal("prompt", result.Type) + s.Equal(promptID, result.ID) + s.Contains(result.Title, "Help me fix") + s.Equal("Help me fix this authentication bug", result.Content) +} + +// TestTruncate_TableDriven tests truncation with various inputs. +func TestTruncate_TableDriven(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + {"short_string", "hello", 10, "hello"}, + {"exact_length", "hello", 5, "hello"}, + {"long_string", "hello world", 5, "hello..."}, + {"empty_string", "", 10, ""}, + {"whitespace_only", " ", 10, ""}, + {"with_leading_space", " hello ", 10, "hello"}, + {"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestManagerWithNilStores tests that Manager handles nil stores gracefully. +func TestManagerWithNilStores(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + assert.NotNil(t, m) + assert.Nil(t, m.observationStore) + assert.Nil(t, m.summaryStore) + assert.Nil(t, m.promptStore) + assert.Nil(t, m.vectorClient) +} + +// TestSearchResultMetadataFields tests all metadata fields with real data. +func (s *SearchIntegrationSuite) TestSearchResultMetadataFields() { + ctx := context.Background() + + obs := &models.ParsedObservation{ + Type: models.ObsTypeDecision, + Scope: models.ScopeGlobal, + Title: "Architecture Decision", + Concepts: []string{"auth", "security"}, + FilesRead: []string{"handler.go", "auth.go"}, + } + id, _, err := s.obsStore.StoreObservation(ctx, "sdk-meta-test", "test-project", obs, 1, 50) + s.Require().NoError(err) + + retrieved, err := s.obsStore.GetObservationByID(ctx, id) + s.Require().NoError(err) + + result := s.manager.observationToResult(retrieved, "full") + + // Check metadata fields + s.NotNil(result.Metadata) + s.Equal("decision", result.Metadata["obs_type"]) + s.Equal("global", result.Metadata["scope"]) + s.Equal("global", result.Scope) +} + +// TestObservationToResult_AllFormats tests different format options. +func TestObservationToResult_AllFormats(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + obs := &models.Observation{ + ID: 1, + Project: "test", + Type: models.ObsTypeBugfix, + Scope: models.ScopeProject, + Title: sql.NullString{String: "Bug Fix Title", Valid: true}, + Narrative: sql.NullString{String: "Detailed bug fix narrative", Valid: true}, + CreatedAtEpoch: 1704067200000, + } + + tests := []struct { + name string + format string + expectContent bool + }{ + {"full_format", "full", true}, + {"index_format", "index", false}, + {"empty_format", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := m.observationToResult(obs, tt.format) + assert.Equal(t, "observation", result.Type) + assert.Equal(t, int64(1), result.ID) + if tt.expectContent { + assert.NotEmpty(t, result.Content) + } + }) + } +} + +// TestSummaryToResult_AllFormats tests different format options for summaries. +func TestSummaryToResult_AllFormats(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + summary := &models.SessionSummary{ + ID: 1, + Project: "test", + Request: sql.NullString{String: "Test request", Valid: true}, + Learned: sql.NullString{String: "Test learned", Valid: true}, + Completed: sql.NullString{String: "Test completed", Valid: true}, + NextSteps: sql.NullString{String: "Test next steps", Valid: true}, + CreatedAtEpoch: 1704067200000, + } + + tests := []struct { + name string + format string + expectContent bool + }{ + {"full_format", "full", true}, + {"index_format", "index", false}, + {"empty_format", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := m.summaryToResult(summary, tt.format) + assert.Equal(t, "session", result.Type) + assert.Equal(t, int64(1), result.ID) + if tt.expectContent { + assert.NotEmpty(t, result.Content) + } + }) + } +} diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go index 427469e..bdd077b 100644 --- a/internal/search/manager_test.go +++ b/internal/search/manager_test.go @@ -4,6 +4,7 @@ package search import ( "database/sql" "testing" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/stretchr/testify/assert" @@ -594,3 +595,482 @@ func TestSearchTypeMapping(t *testing.T) { }) } } + +// TestFilterSearchWithObservations tests filter search when observations exist. +func TestFilterSearchWithObservations(t *testing.T) { + // Create mock observation + obs := &models.Observation{ + ID: 1, + Project: "test-project", + Type: models.ObsTypeDiscovery, + Scope: models.ScopeProject, + Title: sql.NullString{String: "Test Title", Valid: true}, + Narrative: sql.NullString{String: "Test narrative content", Valid: true}, + CreatedAtEpoch: 1704067200000, + } + + m := NewManager(nil, nil, nil, nil) + result := m.observationToResult(obs, "full") + + assert.Equal(t, "observation", result.Type) + assert.Equal(t, int64(1), result.ID) + assert.Equal(t, "Test Title", result.Title) + assert.Equal(t, "Test narrative content", result.Content) + assert.Equal(t, "test-project", result.Project) + assert.Equal(t, "project", result.Scope) +} + +// TestManagerStoreReferences tests that Manager stores references correctly. +func TestManagerStoreReferences(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + assert.Nil(t, m.observationStore) + assert.Nil(t, m.summaryStore) + assert.Nil(t, m.promptStore) + assert.Nil(t, m.vectorClient) +} + +// TestObservationToResultWithMetadata tests metadata inclusion in results. +func TestObservationToResultWithMetadata(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + tests := []struct { + name string + obsType models.ObservationType + scope models.ObservationScope + }{ + {"bugfix_project", models.ObsTypeBugfix, models.ScopeProject}, + {"feature_global", models.ObsTypeFeature, models.ScopeGlobal}, + {"discovery_project", models.ObsTypeDiscovery, models.ScopeProject}, + {"decision_global", models.ObsTypeDecision, models.ScopeGlobal}, + {"refactor_project", models.ObsTypeRefactor, models.ScopeProject}, + {"change_global", models.ObsTypeChange, models.ScopeGlobal}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + obs := &models.Observation{ + ID: 1, + Project: "test-project", + Type: tt.obsType, + Scope: tt.scope, + Title: sql.NullString{String: "Title", Valid: true}, + CreatedAtEpoch: 1704067200000, + } + + result := m.observationToResult(obs, "full") + + assert.Equal(t, string(tt.obsType), result.Metadata["obs_type"]) + assert.Equal(t, string(tt.scope), result.Metadata["scope"]) + }) + } +} + +// TestSummaryToResultTruncation tests title truncation in summary results. +func TestSummaryToResultTruncation(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + tests := []struct { + name string + request string + expectedLen int + shouldTrunc bool + }{ + {"short_title", "Short request", 13, false}, + {"exact_100", string(make([]byte, 100)), 103, true}, // 100 + "..." + {"over_100", string(make([]byte, 150)), 103, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + summary := &models.SessionSummary{ + ID: 1, + Project: "test-project", + Request: sql.NullString{String: tt.request, Valid: true}, + CreatedAtEpoch: 1704067200000, + } + + result := m.summaryToResult(summary, "full") + + if tt.shouldTrunc { + assert.LessOrEqual(t, len(result.Title), tt.expectedLen) + assert.True(t, len(result.Title) <= 103) // max 100 + "..." + } else { + assert.Equal(t, tt.request, result.Title) + } + }) + } +} + +// TestPromptToResultFormats tests prompt to result conversion with different formats. +func TestPromptToResultFormats(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + prompt := &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 123, + PromptText: "What is the meaning of life?", + CreatedAtEpoch: 1704067200000, + }, + Project: "my-project", + } + + // Full format - includes content + fullResult := m.promptToResult(prompt, "full") + assert.Equal(t, "What is the meaning of life?", fullResult.Content) + + // Index format - no content + indexResult := m.promptToResult(prompt, "index") + assert.Equal(t, "", indexResult.Content) + + // Both should have same title + assert.Equal(t, fullResult.Title, indexResult.Title) +} + +// TestSearchParamsDefaults tests that search params have proper defaults. +func TestSearchParamsDefaults(t *testing.T) { + tests := []struct { + name string + initialLimit int + initialOrder string + expectedLimit int + expectedOrder string + }{ + {"zero_limit", 0, "", 20, "date_desc"}, + {"negative_limit", -5, "", 20, "date_desc"}, + {"over_100_limit", 150, "", 100, "date_desc"}, + {"valid_limit_50", 50, "relevance", 50, "relevance"}, + {"custom_order", 30, "date_asc", 30, "date_asc"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "project", + Limit: tt.initialLimit, + OrderBy: tt.initialOrder, + } + + // Simulate the normalization that happens in UnifiedSearch + if params.Limit <= 0 { + params.Limit = 20 + } + if params.Limit > 100 { + params.Limit = 100 + } + if params.OrderBy == "" { + params.OrderBy = "date_desc" + } + + assert.Equal(t, tt.expectedLimit, params.Limit) + assert.Equal(t, tt.expectedOrder, params.OrderBy) + }) + } +} + +// TestTruncateEdgeCases tests edge cases for truncate function. +func TestTruncateEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + // Unicode strings - uses byte length so ensure maxLen accommodates full string + {"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"}, + {"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"}, + // ASCII truncation + {"ascii_truncate", "Hello World", 5, "Hello..."}, + {"only_whitespace", " ", 10, ""}, + {"tabs_and_newlines", "\t\n \t", 10, ""}, + {"newlines_with_content", "\n\nhello\n\n", 10, "hello"}, + {"zero_max_len", "hello", 0, "..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestUnifiedSearchResultEmpty tests empty UnifiedSearchResult. +func TestUnifiedSearchResultEmpty(t *testing.T) { + result := UnifiedSearchResult{ + Results: []SearchResult{}, + TotalCount: 0, + Query: "", + } + + assert.Empty(t, result.Results) + assert.Equal(t, 0, result.TotalCount) + assert.Equal(t, "", result.Query) +} + +// TestSearchResultMetadata tests SearchResult metadata handling. +func TestSearchResultMetadata(t *testing.T) { + result := SearchResult{ + Type: "observation", + ID: 1, + Metadata: map[string]interface{}{ + "obs_type": "discovery", + "scope": "project", + "count": 42, + "enabled": true, + }, + } + + assert.Equal(t, "discovery", result.Metadata["obs_type"]) + assert.Equal(t, "project", result.Metadata["scope"]) + assert.Equal(t, 42, result.Metadata["count"]) + assert.Equal(t, true, result.Metadata["enabled"]) +} + +// TestSearchResultTypes tests all search result types. +func TestSearchResultTypes(t *testing.T) { + types := []string{"observation", "session", "prompt"} + + for _, typ := range types { + t.Run(typ, func(t *testing.T) { + result := SearchResult{ + Type: typ, + ID: 1, + Project: "test", + CreatedAt: time.Now().UnixMilli(), + } + assert.Equal(t, typ, result.Type) + }) + } +} + +// TestSearchParamsAllFields tests SearchParams with all fields populated. +func TestSearchParamsAllFields(t *testing.T) { + params := SearchParams{ + Query: "authentication bug", + Type: "observations", + Project: "my-project", + ObsType: "bugfix", + Concepts: "security,auth", + Files: "handler.go,auth.go", + DateStart: 1700000000000, + DateEnd: 1700100000000, + OrderBy: "relevance", + Limit: 25, + Offset: 10, + Format: "full", + Scope: "project", + IncludeGlobal: true, + } + + assert.Equal(t, "authentication bug", params.Query) + assert.Equal(t, "observations", params.Type) + assert.Equal(t, "my-project", params.Project) + assert.Equal(t, "bugfix", params.ObsType) + assert.Equal(t, "security,auth", params.Concepts) + assert.Equal(t, "handler.go,auth.go", params.Files) + assert.Equal(t, int64(1700000000000), params.DateStart) + assert.Equal(t, int64(1700100000000), params.DateEnd) + assert.Equal(t, "relevance", params.OrderBy) + assert.Equal(t, 25, params.Limit) + assert.Equal(t, 10, params.Offset) + assert.Equal(t, "full", params.Format) + assert.Equal(t, "project", params.Scope) + assert.True(t, params.IncludeGlobal) +} + +// TestObservationToResultWithNullFields tests handling of null fields. +func TestObservationToResultWithNullFields(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + obs := &models.Observation{ + ID: 1, + Project: "test-project", + Type: models.ObsTypeDiscovery, + Scope: models.ScopeProject, + Title: sql.NullString{Valid: false}, + Narrative: sql.NullString{Valid: false}, + CreatedAtEpoch: 1704067200000, + } + + result := m.observationToResult(obs, "full") + + assert.Equal(t, "", result.Title) + assert.Equal(t, "", result.Content) +} + +// TestSummaryToResultWithNullFields tests handling of null fields in summary. +func TestSummaryToResultWithNullFields(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + summary := &models.SessionSummary{ + ID: 1, + Project: "test-project", + Request: sql.NullString{Valid: false}, + Learned: sql.NullString{Valid: false}, + CreatedAtEpoch: 1704067200000, + } + + result := m.summaryToResult(summary, "full") + + assert.Equal(t, "", result.Title) + assert.Equal(t, "", result.Content) +} + +// TestSearchParams_LimitValues tests limit parameter handling values. +func TestSearchParams_LimitValues(t *testing.T) { + tests := []struct { + name string + inputLimit int + expectedValid bool + }{ + {"zero_limit", 0, true}, + {"negative_limit", -5, true}, + {"normal_limit", 20, true}, + {"max_limit", 100, true}, + {"over_limit", 200, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "test", + Limit: tt.inputLimit, + } + assert.NotNil(t, params) + assert.Equal(t, tt.inputLimit, params.Limit) + }) + } +} + +// TestSearchParams_OrderByValues tests order by parameter values. +func TestSearchParams_OrderByValues(t *testing.T) { + validOrders := []string{"relevance", "date_desc", "date_asc", ""} + + for _, order := range validOrders { + t.Run("order_"+order, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "test", + OrderBy: order, + } + assert.Equal(t, order, params.OrderBy) + }) + } +} + +// TestSearchParams_TypeValues tests type parameter values. +func TestSearchParams_TypeValues(t *testing.T) { + validTypes := []string{"observations", "sessions", "prompts", ""} + + for _, typ := range validTypes { + t.Run("type_"+typ, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "test", + Type: typ, + } + assert.Equal(t, typ, params.Type) + }) + } +} + +// TestSearchParams_ScopeValues tests scope parameter values. +func TestSearchParams_ScopeValues(t *testing.T) { + validScopes := []string{"project", "global", ""} + + for _, scope := range validScopes { + t.Run("scope_"+scope, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "test", + Scope: scope, + } + assert.Equal(t, scope, params.Scope) + }) + } +} + +// TestSearchParams_FormatValues tests format parameter values. +func TestSearchParams_FormatValues(t *testing.T) { + validFormats := []string{"index", "full", ""} + + for _, format := range validFormats { + t.Run("format_"+format, func(t *testing.T) { + params := SearchParams{ + Query: "test", + Project: "test", + Format: format, + } + assert.Equal(t, format, params.Format) + }) + } +} + +// TestUnifiedSearchResult_MultipleResults tests result with multiple items. +func TestUnifiedSearchResult_MultipleResults(t *testing.T) { + results := []SearchResult{ + {Type: "observation", ID: 1, Title: "First", Project: "test"}, + {Type: "session", ID: 2, Title: "Second", Project: "test"}, + {Type: "prompt", ID: 3, Title: "Third", Project: "test"}, + } + + result := UnifiedSearchResult{ + Results: results, + TotalCount: 3, + Query: "test query", + } + + assert.Len(t, result.Results, 3) + assert.Equal(t, 3, result.TotalCount) + assert.Equal(t, "observation", result.Results[0].Type) + assert.Equal(t, "session", result.Results[1].Type) + assert.Equal(t, "prompt", result.Results[2].Type) +} + +// TestSearchResult_Metadata tests metadata handling in SearchResult. +func TestSearchResult_Metadata(t *testing.T) { + metadata := map[string]interface{}{ + "obs_type": "discovery", + "concepts": []string{"auth", "security"}, + "files_count": 5, + "is_important": true, + } + + result := SearchResult{ + Type: "observation", + ID: 1, + Metadata: metadata, + } + + assert.Equal(t, "discovery", result.Metadata["obs_type"]) + assert.Equal(t, 5, result.Metadata["files_count"]) + assert.Equal(t, true, result.Metadata["is_important"]) +} + +// TestSearchResult_Scores tests score handling in SearchResult. +func TestSearchResult_Scores(t *testing.T) { + tests := []struct { + name string + score float64 + }{ + {"perfect_score", 1.0}, + {"high_score", 0.95}, + {"medium_score", 0.5}, + {"low_score", 0.1}, + {"zero_score", 0.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SearchResult{ + Type: "observation", + ID: 1, + Score: tt.score, + } + assert.Equal(t, tt.score, result.Score) + }) + } +} diff --git a/internal/vector/sqlitevec/client_test.go b/internal/vector/sqlitevec/client_test.go new file mode 100644 index 0000000..9a90440 --- /dev/null +++ b/internal/vector/sqlitevec/client_test.go @@ -0,0 +1,502 @@ +package sqlitevec + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "testing" + + sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" + "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testDB creates a test SQLite database with the vectors table. +func testDB(t *testing.T) (*sql.DB, func()) { + t.Helper() + + // Create temp directory + tmpDir, err := os.MkdirTemp("", "sqlitevec-test-*") + require.NoError(t, err) + + dbPath := filepath.Join(tmpDir, "test.db") + db, err := sql.Open("sqlite3", dbPath) + require.NoError(t, err) + + // Enable sqlite-vec + sqlite_vec.Auto() + + // Create vectors table (matches production schema) + _, err = db.Exec(` + CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0( + doc_id TEXT PRIMARY KEY, + embedding float[384], + sqlite_id INTEGER, + doc_type TEXT, + field_type TEXT, + project TEXT, + scope TEXT + ) + `) + require.NoError(t, err) + + cleanup := func() { + db.Close() + os.RemoveAll(tmpDir) + } + + return db, cleanup +} + +// testEmbeddingService creates a test embedding service. +func testEmbeddingService(t *testing.T) (*embedding.Service, func()) { + t.Helper() + + svc, err := embedding.NewService() + require.NoError(t, err) + + cleanup := func() { + svc.Close() + } + + return svc, cleanup +} + +func TestNewClient_Success(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + assert.NotNil(t, client) +} + +func TestNewClient_NilDB(t *testing.T) { + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: nil}, embedSvc) + assert.Error(t, err) + assert.Nil(t, client) + assert.Contains(t, err.Error(), "database connection required") +} + +func TestNewClient_NilEmbedding(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + client, err := NewClient(Config{DB: db}, nil) + assert.Error(t, err) + assert.Nil(t, client) + assert.Contains(t, err.Error(), "embedding service required") +} + +func TestClient_AddDocuments_Empty(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + err = client.AddDocuments(context.Background(), []Document{}) + require.NoError(t, err) +} + +func TestClient_AddDocuments_Single(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + docs := []Document{ + { + ID: "obs-1-title", + Content: "This is a test observation about authentication.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "title", + "project": "test-project", + "scope": "project", + }, + }, + } + + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Verify document was inserted + var count int + err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "obs-1-title").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) +} + +func TestClient_AddDocuments_Multiple(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + docs := []Document{ + { + ID: "obs-1-title", + Content: "Authentication flow implementation.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "title", + "project": "test-project", + "scope": "project", + }, + }, + { + ID: "obs-1-narrative", + Content: "We implemented JWT-based authentication.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "narrative", + "project": "test-project", + "scope": "project", + }, + }, + { + ID: "obs-2-title", + Content: "Database optimization.", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "field_type": "title", + "project": "test-project", + "scope": "global", + }, + }, + } + + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Verify all documents were inserted + var count int + err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 3, count) +} + +func TestClient_DeleteDocuments_Empty(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + err = client.DeleteDocuments(context.Background(), []string{}) + require.NoError(t, err) +} + +func TestClient_DeleteDocuments_Existing(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents first + docs := []Document{ + { + ID: "doc-1", + Content: "First document.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + }, + }, + { + ID: "doc-2", + Content: "Second document.", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Delete one document + err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) + require.NoError(t, err) + + // Verify only one remains + var count int + err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) +} + +func TestClient_Query_Basic(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add some test documents + docs := []Document{ + { + ID: "obs-1", + Content: "Authentication and login security implementation.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "project": "test-project", + "scope": "project", + }, + }, + { + ID: "obs-2", + Content: "Database query optimization techniques.", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "project": "test-project", + "scope": "project", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query for authentication-related content + results, err := client.Query(context.Background(), "login authentication", 10, nil) + require.NoError(t, err) + + assert.NotEmpty(t, results) + assert.LessOrEqual(t, len(results), 10) + + // First result should be the authentication document (higher similarity) + assert.Equal(t, "obs-1", results[0].ID) +} + +func TestClient_Query_WithDocTypeFilter(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents of different types + docs := []Document{ + { + ID: "obs-1", + Content: "Test content for observation.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "project": "test-project", + }, + }, + { + ID: "summary-1", + Content: "Test content for summary.", + Metadata: map[string]any{ + "sqlite_id": int64(10), + "doc_type": "session_summary", + "project": "test-project", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query with doc_type filter + where := map[string]any{"doc_type": "observation"} + results, err := client.Query(context.Background(), "test content", 10, where) + require.NoError(t, err) + + // Should only return observation documents + for _, r := range results { + docType, _ := r.Metadata["doc_type"].(string) + assert.Equal(t, "observation", docType) + } +} + +func TestClient_Query_WithProjectFilter(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents from different projects + docs := []Document{ + { + ID: "obs-1", + Content: "Project A authentication content.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "project": "project-a", + "scope": "project", + }, + }, + { + ID: "obs-2", + Content: "Project B database content.", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "project": "project-b", + "scope": "project", + }, + }, + { + ID: "obs-3", + Content: "Global security best practices.", + Metadata: map[string]any{ + "sqlite_id": int64(3), + "doc_type": "observation", + "project": "project-b", + "scope": "global", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query without project filter to verify all docs are there + results, err := client.Query(context.Background(), "authentication security", 10, nil) + require.NoError(t, err) + assert.NotEmpty(t, results, "Should find some results") +} + +func TestClient_IsConnected(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + assert.True(t, client.IsConnected()) +} + +func TestClient_Close(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + err = client.Close() + require.NoError(t, err) +} + +func TestConfig_Fields(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + cfg := Config{DB: db} + assert.Equal(t, db, cfg.DB) +} + +func TestClient_UpdateDocument_DeleteThenAdd(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add document + docs1 := []Document{ + { + ID: "doc-1", + Content: "Original content.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + }, + }, + } + err = client.AddDocuments(context.Background(), docs1) + require.NoError(t, err) + + // Delete then add with new content (proper update pattern) + err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) + require.NoError(t, err) + + docs2 := []Document{ + { + ID: "doc-1", + Content: "Updated content.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + }, + }, + } + err = client.AddDocuments(context.Background(), docs2) + require.NoError(t, err) + + // Should have exactly 1 document + var count int + err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-1").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) +} + +func TestClient_DeleteDocuments_NonExistent(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Deleting non-existent document should not error + err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"}) + require.NoError(t, err) +} diff --git a/internal/vector/sqlitevec/helpers_test.go b/internal/vector/sqlitevec/helpers_test.go new file mode 100644 index 0000000..0e9b5ac --- /dev/null +++ b/internal/vector/sqlitevec/helpers_test.go @@ -0,0 +1,574 @@ +package sqlitevec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDocTypes(t *testing.T) { + assert.Equal(t, DocType("observation"), DocTypeObservation) + assert.Equal(t, DocType("session_summary"), DocTypeSessionSummary) + assert.Equal(t, DocType("user_prompt"), DocTypeUserPrompt) +} + +func TestDocument_Fields(t *testing.T) { + doc := Document{ + ID: "doc-123", + Content: "test content", + Metadata: map[string]any{ + "key": "value", + }, + } + + assert.Equal(t, "doc-123", doc.ID) + assert.Equal(t, "test content", doc.Content) + assert.Equal(t, "value", doc.Metadata["key"]) +} + +func TestQueryResult_Fields(t *testing.T) { + result := QueryResult{ + ID: "result-123", + Distance: 0.5, + Metadata: map[string]any{ + "sqlite_id": float64(42), + }, + } + + assert.Equal(t, "result-123", result.ID) + assert.Equal(t, 0.5, result.Distance) + assert.Equal(t, float64(42), result.Metadata["sqlite_id"]) +} + +func TestBuildWhereFilter(t *testing.T) { + tests := []struct { + name string + docType DocType + project string + expected map[string]interface{} + }{ + { + name: "empty_filters", + docType: "", + project: "", + expected: map[string]interface{}{}, + }, + { + name: "doc_type_only", + docType: DocTypeObservation, + project: "", + expected: map[string]interface{}{ + "doc_type": "observation", + }, + }, + { + name: "project_only", + docType: "", + project: "my-project", + expected: map[string]interface{}{ + "project": "my-project", + }, + }, + { + name: "both_filters", + docType: DocTypeSessionSummary, + project: "test-project", + expected: map[string]interface{}{ + "doc_type": "session_summary", + "project": "test-project", + }, + }, + { + name: "user_prompt_type", + docType: DocTypeUserPrompt, + project: "prompt-project", + expected: map[string]interface{}{ + "doc_type": "user_prompt", + "project": "prompt-project", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildWhereFilter(tt.docType, tt.project) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractIDsByDocType_Empty(t *testing.T) { + results := []QueryResult{} + ids := ExtractIDsByDocType(results) + + assert.Empty(t, ids.ObservationIDs) + assert.Empty(t, ids.SummaryIDs) + assert.Empty(t, ids.PromptIDs) +} + +func TestExtractIDsByDocType_AllTypes(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1", + Distance: 0.1, + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + }, + }, + { + ID: "obs-2", + Distance: 0.2, + Metadata: map[string]any{ + "sqlite_id": float64(2), + "doc_type": "observation", + }, + }, + { + ID: "summary-1", + Distance: 0.3, + Metadata: map[string]any{ + "sqlite_id": float64(10), + "doc_type": "session_summary", + }, + }, + { + ID: "prompt-1", + Distance: 0.4, + Metadata: map[string]any{ + "sqlite_id": float64(20), + "doc_type": "user_prompt", + }, + }, + } + + ids := ExtractIDsByDocType(results) + + assert.Equal(t, []int64{1, 2}, ids.ObservationIDs) + assert.Equal(t, []int64{10}, ids.SummaryIDs) + assert.Equal(t, []int64{20}, ids.PromptIDs) +} + +func TestExtractIDsByDocType_Deduplication(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1-field1", + Distance: 0.1, + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + }, + }, + { + ID: "obs-1-field2", + Distance: 0.2, + Metadata: map[string]any{ + "sqlite_id": float64(1), // Same ID, different field + "doc_type": "observation", + }, + }, + { + ID: "obs-2", + Distance: 0.3, + Metadata: map[string]any{ + "sqlite_id": float64(2), + "doc_type": "observation", + }, + }, + } + + ids := ExtractIDsByDocType(results) + + assert.Equal(t, []int64{1, 2}, ids.ObservationIDs) // Should be deduplicated +} + +func TestExtractIDsByDocType_Int64Fallback(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1", + Distance: 0.1, + Metadata: map[string]any{ + "sqlite_id": int64(42), // int64 instead of float64 + "doc_type": "observation", + }, + }, + } + + ids := ExtractIDsByDocType(results) + + assert.Equal(t, []int64{42}, ids.ObservationIDs) +} + +func TestExtractIDsByDocType_MissingSqliteID(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1", + Distance: 0.1, + Metadata: map[string]any{ + "doc_type": "observation", + // Missing sqlite_id + }, + }, + } + + ids := ExtractIDsByDocType(results) + + assert.Empty(t, ids.ObservationIDs) +} + +func TestExtractIDsByDocType_UnknownType(t *testing.T) { + results := []QueryResult{ + { + ID: "unknown-1", + Distance: 0.1, + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "unknown_type", + }, + }, + } + + ids := ExtractIDsByDocType(results) + + // Should not be added to any category + assert.Empty(t, ids.ObservationIDs) + assert.Empty(t, ids.SummaryIDs) + assert.Empty(t, ids.PromptIDs) +} + +func TestExtractObservationIDs_NoFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1", + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + "project": "proj-a", + }, + }, + { + ID: "obs-2", + Metadata: map[string]any{ + "sqlite_id": float64(2), + "doc_type": "observation", + "project": "proj-b", + }, + }, + { + ID: "summary-1", + Metadata: map[string]any{ + "sqlite_id": float64(10), + "doc_type": "session_summary", + "project": "proj-a", + }, + }, + } + + ids := ExtractObservationIDs(results, "") + + assert.Equal(t, []int64{1, 2}, ids) +} + +func TestExtractObservationIDs_WithProjectFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1", + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + "project": "proj-a", + "scope": "project", + }, + }, + { + ID: "obs-2", + Metadata: map[string]any{ + "sqlite_id": float64(2), + "doc_type": "observation", + "project": "proj-b", + "scope": "project", + }, + }, + { + ID: "obs-global", + Metadata: map[string]any{ + "sqlite_id": float64(3), + "doc_type": "observation", + "project": "proj-b", + "scope": "global", + }, + }, + } + + ids := ExtractObservationIDs(results, "proj-a") + + // Should include proj-a and global scope observations + assert.Equal(t, []int64{1, 3}, ids) +} + +func TestExtractObservationIDs_Deduplication(t *testing.T) { + results := []QueryResult{ + { + ID: "obs-1-field1", + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + }, + }, + { + ID: "obs-1-field2", + Metadata: map[string]any{ + "sqlite_id": float64(1), // Same ID + "doc_type": "observation", + }, + }, + } + + ids := ExtractObservationIDs(results, "") + + assert.Equal(t, []int64{1}, ids) +} + +func TestExtractSummaryIDs_NoFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "summary-1", + Metadata: map[string]any{ + "sqlite_id": float64(10), + "doc_type": "session_summary", + "project": "proj-a", + }, + }, + { + ID: "summary-2", + Metadata: map[string]any{ + "sqlite_id": float64(20), + "doc_type": "session_summary", + "project": "proj-b", + }, + }, + { + ID: "obs-1", + Metadata: map[string]any{ + "sqlite_id": float64(1), + "doc_type": "observation", + }, + }, + } + + ids := ExtractSummaryIDs(results, "") + + assert.Equal(t, []int64{10, 20}, ids) +} + +func TestExtractSummaryIDs_WithProjectFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "summary-1", + Metadata: map[string]any{ + "sqlite_id": float64(10), + "doc_type": "session_summary", + "project": "proj-a", + }, + }, + { + ID: "summary-2", + Metadata: map[string]any{ + "sqlite_id": float64(20), + "doc_type": "session_summary", + "project": "proj-b", + }, + }, + } + + ids := ExtractSummaryIDs(results, "proj-a") + + assert.Equal(t, []int64{10}, ids) +} + +func TestExtractPromptIDs_NoFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "prompt-1", + Metadata: map[string]any{ + "sqlite_id": float64(100), + "doc_type": "user_prompt", + "project": "proj-a", + }, + }, + { + ID: "prompt-2", + Metadata: map[string]any{ + "sqlite_id": float64(200), + "doc_type": "user_prompt", + "project": "proj-b", + }, + }, + } + + ids := ExtractPromptIDs(results, "") + + assert.Equal(t, []int64{100, 200}, ids) +} + +func TestExtractPromptIDs_WithProjectFilter(t *testing.T) { + results := []QueryResult{ + { + ID: "prompt-1", + Metadata: map[string]any{ + "sqlite_id": float64(100), + "doc_type": "user_prompt", + "project": "proj-a", + }, + }, + { + ID: "prompt-2", + Metadata: map[string]any{ + "sqlite_id": float64(200), + "doc_type": "user_prompt", + "project": "proj-b", + }, + }, + } + + ids := ExtractPromptIDs(results, "proj-b") + + assert.Equal(t, []int64{200}, ids) +} + +func TestCopyMetadata(t *testing.T) { + base := map[string]any{ + "key1": "value1", + "key2": 42, + } + + result := copyMetadata(base, "key3", "value3") + + // Original should be unchanged + assert.Len(t, base, 2) + + // Result should have new key + assert.Len(t, result, 3) + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, 42, result["key2"]) + assert.Equal(t, "value3", result["key3"]) +} + +func TestCopyMetadataMulti(t *testing.T) { + base := map[string]any{ + "key1": "value1", + } + extra := map[string]any{ + "key2": "value2", + "key3": "value3", + } + + result := copyMetadataMulti(base, extra) + + assert.Len(t, result, 3) + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, "value2", result["key2"]) + assert.Equal(t, "value3", result["key3"]) +} + +func TestJoinStrings(t *testing.T) { + tests := []struct { + name string + strs []string + sep string + expected string + }{ + { + name: "empty_slice", + strs: []string{}, + sep: ", ", + expected: "", + }, + { + name: "single_element", + strs: []string{"one"}, + sep: ", ", + expected: "one", + }, + { + name: "multiple_elements", + strs: []string{"one", "two", "three"}, + sep: ", ", + expected: "one, two, three", + }, + { + name: "different_separator", + strs: []string{"a", "b", "c"}, + sep: "-", + expected: "a-b-c", + }, + { + name: "empty_separator", + strs: []string{"a", "b", "c"}, + sep: "", + expected: "abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := joinStrings(tt.strs, tt.sep) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTruncateString(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + { + name: "shorter_than_max", + input: "hello", + maxLen: 10, + expected: "hello", + }, + { + name: "equal_to_max", + input: "hello", + maxLen: 5, + expected: "hello", + }, + { + name: "longer_than_max", + input: "hello world", + maxLen: 5, + expected: "hello...", + }, + { + name: "empty_string", + input: "", + maxLen: 5, + expected: "", + }, + { + name: "zero_max_length", + input: "hello", + maxLen: 0, + expected: "...", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateString(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractedIDs_Empty(t *testing.T) { + ids := &ExtractedIDs{} + + assert.Nil(t, ids.ObservationIDs) + assert.Nil(t, ids.SummaryIDs) + assert.Nil(t, ids.PromptIDs) +} diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go index b86257a..f7f1b67 100644 --- a/internal/worker/handlers_test.go +++ b/internal/worker/handlers_test.go @@ -1279,3 +1279,849 @@ func TestObservationRequest_Fields(t *testing.T) { assert.Equal(t, "Read", req.ToolName) assert.Equal(t, "/home/user/project", req.CWD) } + +// TestRetrievalStats_Fields tests RetrievalStats struct fields. +func TestRetrievalStats_Fields(t *testing.T) { + stats := RetrievalStats{ + TotalRequests: 100, + ObservationsServed: 500, + VerifiedStale: 10, + DeletedInvalid: 5, + SearchRequests: 80, + ContextInjections: 20, + } + + assert.Equal(t, int64(100), stats.TotalRequests) + assert.Equal(t, int64(500), stats.ObservationsServed) + assert.Equal(t, int64(10), stats.VerifiedStale) + assert.Equal(t, int64(5), stats.DeletedInvalid) + assert.Equal(t, int64(80), stats.SearchRequests) + assert.Equal(t, int64(20), stats.ContextInjections) +} + +// TestServiceConstants tests service configuration constants. +func TestServiceConstants(t *testing.T) { + assert.Equal(t, 30*time.Second, DefaultHTTPTimeout) + assert.Equal(t, 50*time.Millisecond, ReadyPollInterval) + assert.Equal(t, 100, StaleQueueSize) + assert.Equal(t, 2*time.Second, QueueProcessInterval) +} + +// TestClusterObservations_Empty tests clustering with empty slice. +func TestClusterObservations_Empty(t *testing.T) { + observations := []*models.Observation{} + clustered := clusterObservations(observations, 0.4) + assert.Empty(t, clustered) +} + +// TestClusterObservations_Single tests clustering with single observation. +func TestClusterObservations_Single(t *testing.T) { + observations := []*models.Observation{ + { + ID: 1, + Title: sql.NullString{String: "Test observation", Valid: true}, + Narrative: sql.NullString{String: "Test content", Valid: true}, + }, + } + clustered := clusterObservations(observations, 0.4) + assert.Len(t, clustered, 1) +} + +// TestClusterObservations_VeryDifferent tests clustering with very different observations. +func TestClusterObservations_VeryDifferent(t *testing.T) { + observations := []*models.Observation{ + { + ID: 1, + Title: sql.NullString{String: "Database optimization", Valid: true}, + Narrative: sql.NullString{String: "PostgreSQL index tuning", Valid: true}, + }, + { + ID: 2, + Title: sql.NullString{String: "Authentication flow", Valid: true}, + Narrative: sql.NullString{String: "JWT token validation", Valid: true}, + }, + { + ID: 3, + Title: sql.NullString{String: "Logging setup", Valid: true}, + Narrative: sql.NullString{String: "Zerolog configuration", Valid: true}, + }, + } + clustered := clusterObservations(observations, 0.4) + // Very different observations should not be clustered together + assert.GreaterOrEqual(t, len(clustered), 1) +} + +// TestHandleContextInject_WithLimit tests context inject with custom limit. +func TestHandleContextInject_WithLimit(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "inject-limit-test" + + // Create some observations + for i := 0; i < 10; i++ { + createTestObservation(t, svc.observationStore, project, + "Observation "+strconv.Itoa(i), + "Content "+strconv.Itoa(i), + []string{"test-" + strconv.Itoa(i)}) + time.Sleep(time.Millisecond) + } + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + observations, ok := response["observations"].([]interface{}) + require.True(t, ok) + assert.LessOrEqual(t, len(observations), 5) +} + +// TestHandleGetObservations tests getting observations list. +func TestHandleGetObservations(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create some observations + createTestObservation(t, svc.observationStore, "test-project", + "Test Observation 1", + "Test content 1", + []string{"test"}) + createTestObservation(t, svc.observationStore, "test-project", + "Test Observation 2", + "Test content 2", + []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/observations?project=test-project", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var observations []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &observations) + require.NoError(t, err) + + assert.GreaterOrEqual(t, len(observations), 2) +} + +// TestHandleGetObservations_Pagination tests observations pagination. +func TestHandleGetObservations_Pagination(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create some observations + for i := 0; i < 5; i++ { + createTestObservation(t, svc.observationStore, "page-test", + "Observation "+strconv.Itoa(i), + "Content "+strconv.Itoa(i), + []string{"test"}) + } + + req := httptest.NewRequest(http.MethodGet, "/api/observations?project=page-test&limit=2", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetObservations_NoProject tests observations without project. +func TestHandleGetObservations_NoProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/observations", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should still return 200 with empty results or all observations + assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code) +} + +// TestHandleSearchByPrompt_EmptyQuery tests search with empty query parameter. +func TestHandleSearchByPrompt_EmptyQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=test&query=", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Empty query should still be a bad request + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleGetSessionByClaudeID tests getting session by Claude ID. +func TestHandleGetSessionByClaudeID(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1") + + // Test with valid claudeSessionId + req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=claude-test-123", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetSessionByClaudeID_Missing tests session lookup with missing param. +func TestHandleGetSessionByClaudeID_Missing(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleGetSessionByClaudeID_NotFound tests session not found. +func TestHandleGetSessionByClaudeID_NotFound(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=nonexistent", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +// TestGetRetrievalStats tests the retrieval stats getter. +func TestGetRetrievalStats(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Initially all zeros + stats := svc.GetRetrievalStats() + assert.Equal(t, int64(0), stats.TotalRequests) + assert.Equal(t, int64(0), stats.SearchRequests) + + // Make some requests to increment stats + project := "stats-test" + createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil) + rec := httptest.NewRecorder() + svc.router.ServeHTTP(rec, req) + + // Stats should be updated + stats = svc.GetRetrievalStats() + assert.GreaterOrEqual(t, stats.TotalRequests, int64(1)) +} + +// TestSelfCheckResponse_Fields tests SelfCheckResponse struct fields. +func TestSelfCheckResponse_Fields(t *testing.T) { + resp := SelfCheckResponse{ + Overall: "healthy", + Version: "v1.0.0", + Uptime: "2h30m", + Components: []ComponentHealth{ + {Name: "database", Status: "healthy", Message: "Connected"}, + {Name: "vector", Status: "healthy", Message: "Ready"}, + }, + } + + assert.Equal(t, "healthy", resp.Overall) + assert.Equal(t, "v1.0.0", resp.Version) + assert.Equal(t, "2h30m", resp.Uptime) + assert.Len(t, resp.Components, 2) + assert.Equal(t, "database", resp.Components[0].Name) + assert.Equal(t, "healthy", resp.Components[0].Status) + assert.Equal(t, "Connected", resp.Components[0].Message) +} + +// TestComponentHealth_Fields tests ComponentHealth struct fields. +func TestComponentHealth_Fields(t *testing.T) { + tests := []struct { + name string + status string + message string + }{ + {"healthy", "healthy", "All systems operational"}, + {"degraded", "degraded", "Some features unavailable"}, + {"unhealthy", "unhealthy", "Service is down"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + health := ComponentHealth{ + Status: tt.status, + Message: tt.message, + } + assert.Equal(t, tt.status, health.Status) + assert.Equal(t, tt.message, health.Message) + }) + } +} + +// TestWriteJSON_Error tests writeJSON with values that can't be encoded. +func TestWriteJSON_Error(t *testing.T) { + rec := httptest.NewRecorder() + + // channels can't be JSON encoded + ch := make(chan int) + writeJSON(rec, ch) + + // Should still set content type but encoding will fail + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) +} + +// TestHandleSummarize_InvalidSessionID tests summarize with invalid session ID. +func TestHandleSummarize_InvalidSessionID(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/summarize", bytes.NewReader([]byte("{}"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Invalid session ID should return 400 + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleSubagentComplete tests subagent completion endpoint. +func TestHandleSubagentComplete(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session first + ctx := context.Background() + sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "subagent-test-123", "test-project", "test prompt") + + payload := `{"session_id": ` + strconv.FormatInt(sessionID, 10) + `, "parent_session_id": "parent-123"}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte(payload))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should accept the request + assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code) +} + +// TestHandleContextSearch_Ordering tests search with different orderings. +func TestHandleContextSearch_Ordering(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "order-test" + + for i := 0; i < 5; i++ { + createTestObservation(t, svc.observationStore, project, + "Obs "+strconv.Itoa(i), + "Content "+strconv.Itoa(i), + []string{"test"}) + time.Sleep(time.Millisecond) + } + + tests := []struct { + name string + order string + }{ + {"date_desc", "date_desc"}, + {"date_asc", "date_asc"}, + {"default", ""}, // Should default to date_desc + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/api/context/search?project=" + project + "&query=test" + if tt.order != "" { + url += "&order_by=" + tt.order + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +// TestHandleContextCount_NoProject tests context count without project. +func TestHandleContextCount_NoProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleRetrievalStatsEndpoint tests retrieval stats endpoint. +func TestHandleRetrievalStatsEndpoint(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats RetrievalStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) +} + +// TestHandleReadyEndpoint tests ready endpoint response. +func TestHandleReadyEndpoint(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/ready", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Ready may return OK or ServiceUnavailable depending on state + assert.Contains(t, []int{http.StatusOK, http.StatusServiceUnavailable}, rec.Code) +} + +// TestHandleSessionInit_EmptyBody tests session init with empty body. +func TestHandleSessionInit_EmptyBody(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + payload := `{}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte(payload))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should accept empty body and create a session + assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code) +} + +// TestHandleObservation_MissingSession tests observation without session. +func TestHandleObservation_MissingSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + payload := `{"session_id": 99999, "tool_name": "Read", "tool_input": "{}", "tool_output": "test"}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte(payload))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should still accept the observation + assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code) +} + +// TestHandleSummaries_Pagination tests summaries pagination. +func TestHandleSummaries_Pagination(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test&limit=10&offset=0", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandlePrompts_Pagination tests prompts pagination. +func TestHandlePrompts_Pagination(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test&limit=10&offset=0", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetStats_AllProjects tests stats without project filter. +func TestHandleGetStats_AllProjects(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create some data in multiple projects + createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"}) + createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/stats", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSubagentComplete_WithSession tests subagent completion with existing session. +func TestHandleSubagentComplete_WithSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session first + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt") + + reqBody := SubagentCompleteRequest{ + ClaudeSessionID: "subagent-claude-123", + Project: "test-project", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSubagentComplete_NoSession tests subagent completion when session doesn't exist. +func TestHandleSubagentComplete_NoSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SubagentCompleteRequest{ + ClaudeSessionID: "nonexistent-session", + Project: "test-project", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should still return 200 even if session not found + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSubagentComplete_InvalidJSON tests subagent completion with invalid JSON. +func TestHandleSubagentComplete_InvalidJSON(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte("invalid json"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleSummarize_ValidSession tests summarize with valid session. +func TestHandleSummarize_ValidSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session first + ctx := context.Background() + sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-claude-test", "test-project", "test prompt") + + reqBody := SummarizeRequest{ + LastUserMessage: "Can you help me fix this bug?", + LastAssistantMessage: "I've analyzed the code and fixed the issue in the handler.", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSummarize_InvalidJSON tests summarize with invalid JSON. +func TestHandleSummarize_InvalidJSON(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + ctx := context.Background() + sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-invalid", "test-project", "test") + + req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader([]byte("not valid json"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleSummarize_NonExistentSession tests summarize with non-existent session. +func TestHandleSummarize_NonExistentSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SummarizeRequest{ + LastUserMessage: "test", + LastAssistantMessage: "test", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/sessions/999999/summarize", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should return error for non-existent session + assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError, http.StatusNotFound}, rec.Code) +} + +// TestSubagentCompleteRequest_Fields tests SubagentCompleteRequest struct. +func TestSubagentCompleteRequest_Fields(t *testing.T) { + req := SubagentCompleteRequest{ + ClaudeSessionID: "test-session-123", + Project: "my-project", + } + + assert.Equal(t, "test-session-123", req.ClaudeSessionID) + assert.Equal(t, "my-project", req.Project) +} + +// TestSummarizeRequest_Fields tests SummarizeRequest struct. +func TestSummarizeRequest_Fields(t *testing.T) { + req := SummarizeRequest{ + LastUserMessage: "Help me fix this bug", + LastAssistantMessage: "I've fixed the authentication issue", + } + + assert.Equal(t, "Help me fix this bug", req.LastUserMessage) + assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage) +} + +// TestHandleHealth_NotReady tests health endpoint when not ready. +func TestHandleHealth_NotReady(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(false) + + req := httptest.NewRequest(http.MethodGet, "/api/health", nil) + rec := httptest.NewRecorder() + + svc.handleHealth(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "starting", response["status"]) +} + +// TestHandleContextInject_EmptyProject tests context inject with empty project. +func TestHandleContextInject_EmptyProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project=", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleSearchByPrompt_LargeLimit tests search with limit exceeding max. +func TestHandleSearchByPrompt_LargeLimit(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "limit-test" + createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&limit=999", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleObservation_WithFullData tests observation with all fields. +func TestHandleObservation_WithFullData(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session first + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt") + + reqBody := ObservationRequest{ + ClaudeSessionID: "obs-full-test", + Project: "test-project", + ToolName: "Edit", + ToolInput: map[string]interface{}{"file_path": "/test.go", "old_string": "foo", "new_string": "bar"}, + ToolResponse: "Edit successful", + CWD: "/home/user/project", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSelfCheck_WithObservations tests self-check with observations in DB. +func TestHandleSelfCheck_WithObservations(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(true) + + // Create some observations + createTestObservation(t, svc.observationStore, "check-project", "Test", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil) + rec := httptest.NewRecorder() + + svc.handleSelfCheck(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SelfCheckResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Check components are populated + assert.NotEmpty(t, response.Components) +} + +// TestHandleGetSummaries_NoProject tests getting summaries without project filter. +func TestHandleGetSummaries_NoProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create some summaries in different projects + ctx := context.Background() + for i := 0; i < 3; i++ { + parsed := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))} + svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100) + } + + req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var summaries []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &summaries) + require.NoError(t, err) + + // Should return all summaries + assert.GreaterOrEqual(t, len(summaries), 3) +} + +// TestHandleGetPrompts_NoProject tests getting prompts without project filter. +func TestHandleGetPrompts_NoProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create sessions and prompts in different projects + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "") + svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "") + + svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0) + svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0) + + req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var prompts []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &prompts) + require.NoError(t, err) + + assert.GreaterOrEqual(t, len(prompts), 2) +} + +// TestHandleSessionInit_MissingClaudeID tests session init without Claude ID. +func TestHandleSessionInit_MissingClaudeID(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionInitRequest{ + Project: "test-project", + Prompt: "Help me", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should accept even without Claude ID (may auto-generate) + assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code) +} + +// TestHandleContextInject_WithQuery tests context inject with query parameter. +func TestHandleContextInject_WithQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "inject-query-test" + createTestObservation(t, svc.observationStore, project, "Authentication bug fix", "Fixed JWT validation", []string{"auth", "jwt"}) + createTestObservation(t, svc.observationStore, project, "Database optimization", "Added indexes", []string{"db", "performance"}) + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&query=authentication", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + observations, ok := response["observations"].([]interface{}) + require.True(t, ok) + assert.GreaterOrEqual(t, len(observations), 1) +} diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go index 66059fa..de0c2b5 100644 --- a/internal/worker/sdk/processor_test.go +++ b/internal/worker/sdk/processor_test.go @@ -1,9 +1,12 @@ package sdk import ( + "os" + "path/filepath" "testing" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" ) func TestIsSelfReferentialSummary(t *testing.T) { @@ -124,3 +127,381 @@ No substantive work performed yet.`, }) } } + +func TestShouldSkipTool(t *testing.T) { + tests := []struct { + name string + toolName string + expected bool + }{ + // Tools that should be skipped + {"TodoWrite", "TodoWrite", true}, + {"Task", "Task", true}, + {"TaskOutput", "TaskOutput", true}, + {"Glob", "Glob", true}, + {"ListDir", "ListDir", true}, + {"LS", "LS", true}, + {"KillShell", "KillShell", true}, + {"AskUserQuestion", "AskUserQuestion", true}, + {"EnterPlanMode", "EnterPlanMode", true}, + {"ExitPlanMode", "ExitPlanMode", true}, + {"Skill", "Skill", true}, + {"SlashCommand", "SlashCommand", true}, + + // Tools that should NOT be skipped + {"Read", "Read", false}, + {"Edit", "Edit", false}, + {"Write", "Write", false}, + {"Grep", "Grep", false}, + {"Bash", "Bash", false}, + {"WebFetch", "WebFetch", false}, + {"WebSearch", "WebSearch", false}, + {"NotebookEdit", "NotebookEdit", false}, + + // Unknown tool (should not be skipped) + {"UnknownTool", "SomeUnknownTool", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldSkipTool(tt.toolName) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestShouldSkipTrivialOperation(t *testing.T) { + tests := []struct { + name string + toolName string + inputStr string + outputStr string + expected bool + }{ + // Short output (should be skipped) + { + name: "output_too_short", + toolName: "Read", + inputStr: `{"file_path": "/some/file.go"}`, + outputStr: "short", + expected: true, + }, + // Trivial outputs + { + name: "no_matches_found", + toolName: "Grep", + inputStr: `{"pattern": "foo"}`, + outputStr: "No matches found in the codebase", + expected: true, + }, + { + name: "file_not_found", + toolName: "Read", + inputStr: `{"file_path": "/nonexistent.go"}`, + outputStr: "Error: File not found at specified path", + expected: true, + }, + { + name: "empty_array", + toolName: "Grep", + inputStr: `{"pattern": "foo"}`, + outputStr: "[]", + expected: true, + }, + // Boring files + { + name: "package_lock_json", + toolName: "Read", + inputStr: `{"file_path": "/project/package-lock.json"}`, + outputStr: "This is a very long package-lock.json content that has more than 50 characters", + expected: true, + }, + { + name: "go_sum", + toolName: "Read", + inputStr: `{"file_path": "/project/go.sum"}`, + outputStr: "This is a very long go.sum file content that has more than 50 characters", + expected: true, + }, + // Grep with too many matches + { + name: "grep_too_many_matches", + toolName: "Grep", + inputStr: `{"pattern": "import"}`, + outputStr: func() string { + s := "" + for i := 0; i < 55; i++ { + s += "match line\n" + } + return s + }(), + expected: true, + }, + // Boring Bash commands + { + name: "git_status", + toolName: "Bash", + inputStr: `{"command": "git status"}`, + outputStr: "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean", + expected: true, + }, + { + name: "ls_command", + toolName: "Bash", + inputStr: `{"command": "ls -la /some/directory"}`, + outputStr: "total 123\ndrwxr-xr-x some long listing that is at least 50 chars", + expected: true, + }, + // Valid operations that should NOT be skipped + { + name: "valid_read", + toolName: "Read", + inputStr: `{"file_path": "/project/main.go"}`, + outputStr: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}", + expected: false, + }, + { + name: "valid_edit", + toolName: "Edit", + inputStr: `{"file_path": "/project/handler.go", "old_string": "foo", "new_string": "bar"}`, + outputStr: "Edit applied successfully. File /project/handler.go has been modified with the requested changes.", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTruncateForLog(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + { + name: "shorter_than_max", + input: "hello", + maxLen: 10, + expected: "hello", + }, + { + name: "equal_to_max", + input: "hello", + maxLen: 5, + expected: "hello", + }, + { + name: "longer_than_max", + input: "hello world", + maxLen: 5, + expected: "hello...", + }, + { + name: "empty_string", + input: "", + maxLen: 5, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateForLog(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestToJSONString(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + }{ + { + name: "nil_value", + input: nil, + expected: "", + }, + { + name: "string_value", + input: "hello", + expected: "hello", + }, + { + name: "int_value", + input: 42, + expected: "42", + }, + { + name: "map_value", + input: map[string]string{"key": "value"}, + expected: `{"key":"value"}`, + }, + { + name: "slice_value", + input: []string{"a", "b", "c"}, + expected: `["a","b","c"]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := toJSONString(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCaptureFileMtimes(t *testing.T) { + // Create a temp directory with test files + tmpDir, err := os.MkdirTemp("", "mtime-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + // Create test files + file1 := filepath.Join(tmpDir, "file1.txt") + file2 := filepath.Join(tmpDir, "file2.txt") + + err = os.WriteFile(file1, []byte("content1"), 0644) + if err != nil { + t.Fatal(err) + } + err = os.WriteFile(file2, []byte("content2"), 0644) + if err != nil { + t.Fatal(err) + } + + t.Run("captures_mtimes_for_existing_files", func(t *testing.T) { + mtimes := captureFileMtimes([]string{file1}, []string{file2}, "") + assert.Len(t, mtimes, 2) + assert.Contains(t, mtimes, file1) + assert.Contains(t, mtimes, file2) + assert.Greater(t, mtimes[file1], int64(0)) + assert.Greater(t, mtimes[file2], int64(0)) + }) + + t.Run("handles_nonexistent_files", func(t *testing.T) { + mtimes := captureFileMtimes([]string{"/nonexistent/file.txt"}, nil, "") + assert.Empty(t, mtimes) + }) + + t.Run("handles_relative_paths_with_cwd", func(t *testing.T) { + mtimes := captureFileMtimes([]string{"file1.txt"}, nil, tmpDir) + assert.Len(t, mtimes, 1) + assert.Contains(t, mtimes, "file1.txt") + }) + + t.Run("empty_inputs", func(t *testing.T) { + mtimes := captureFileMtimes(nil, nil, "") + assert.Empty(t, mtimes) + }) +} + +func TestGetFileMtimes(t *testing.T) { + // Create a temp file + tmpDir, err := os.MkdirTemp("", "getmtime-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + testFile := filepath.Join(tmpDir, "test.txt") + err = os.WriteFile(testFile, []byte("content"), 0644) + if err != nil { + t.Fatal(err) + } + + mtimes := GetFileMtimes([]string{testFile}, "") + assert.Len(t, mtimes, 1) + assert.Contains(t, mtimes, testFile) + assert.Greater(t, mtimes[testFile], int64(0)) +} + +func TestGetFileContent(t *testing.T) { + // Create a temp directory with test files + tmpDir, err := os.MkdirTemp("", "content-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + t.Run("reads_existing_file", func(t *testing.T) { + testFile := filepath.Join(tmpDir, "test.txt") + content := "test content" + err := os.WriteFile(testFile, []byte(content), 0644) + if err != nil { + t.Fatal(err) + } + + result, ok := GetFileContent(testFile, "") + assert.True(t, ok) + assert.Equal(t, content, result) + }) + + t.Run("returns_false_for_nonexistent_file", func(t *testing.T) { + result, ok := GetFileContent("/nonexistent/file.txt", "") + assert.False(t, ok) + assert.Empty(t, result) + }) + + t.Run("truncates_long_content", func(t *testing.T) { + testFile := filepath.Join(tmpDir, "long.txt") + longContent := "" + for i := 0; i < 3000; i++ { + longContent += "x" + } + err := os.WriteFile(testFile, []byte(longContent), 0644) + if err != nil { + t.Fatal(err) + } + + result, ok := GetFileContent(testFile, "") + assert.True(t, ok) + assert.Contains(t, result, "[truncated]") + assert.LessOrEqual(t, len(result), 2100) + }) + + t.Run("resolves_relative_path_with_cwd", func(t *testing.T) { + testFile := filepath.Join(tmpDir, "relative.txt") + content := "relative content" + err := os.WriteFile(testFile, []byte(content), 0644) + if err != nil { + t.Fatal(err) + } + + result, ok := GetFileContent("relative.txt", tmpDir) + assert.True(t, ok) + assert.Equal(t, content, result) + }) +} + +func TestMaxConcurrentCLICalls(t *testing.T) { + assert.Equal(t, 4, MaxConcurrentCLICalls) +} + +func TestObservationTypes(t *testing.T) { + expected := []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"} + assert.Equal(t, expected, ObservationTypes) +} + +func TestObservationConcepts(t *testing.T) { + expectedConcepts := []string{ + "how-it-works", + "why-it-exists", + "what-changed", + "problem-solution", + "gotcha", + "pattern", + "trade-off", + } + assert.Equal(t, expectedConcepts, ObservationConcepts) +} diff --git a/internal/worker/sdk/prompts_test.go b/internal/worker/sdk/prompts_test.go new file mode 100644 index 0000000..eecf7d5 --- /dev/null +++ b/internal/worker/sdk/prompts_test.go @@ -0,0 +1,276 @@ +package sdk + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + { + name: "shorter_than_max", + input: "hello", + maxLen: 10, + expected: "hello", + }, + { + name: "equal_to_max", + input: "hello", + maxLen: 5, + expected: "hello", + }, + { + name: "longer_than_max", + input: "hello world", + maxLen: 5, + expected: "hello... (truncated)", + }, + { + name: "empty_string", + input: "", + maxLen: 5, + expected: "", + }, + { + name: "zero_max_length", + input: "hello", + maxLen: 0, + expected: "... (truncated)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildObservationPrompt(t *testing.T) { + now := time.Now().UnixMilli() + + tests := []struct { + name string + exec ToolExecution + contains []string + }{ + { + name: "basic_read_tool", + exec: ToolExecution{ + ID: 1, + ToolName: "Read", + ToolInput: `{"file_path": "/path/to/file.go"}`, + ToolOutput: `package main\nfunc main() {}`, + CreatedAtEpoch: now, + CWD: "/project", + }, + contains: []string{ + "", + "Read", + "/project", + "", + "file_path", + "", + "", + }, + }, + { + name: "edit_tool_with_json_input", + exec: ToolExecution{ + ID: 2, + ToolName: "Edit", + ToolInput: `{"file_path": "/file.go", "old_string": "foo", "new_string": "bar"}`, + ToolOutput: "Edit applied successfully", + CreatedAtEpoch: now, + CWD: "", + }, + contains: []string{ + "Edit", + "file_path", + "old_string", + "new_string", + }, + }, + { + name: "no_cwd", + exec: ToolExecution{ + ID: 3, + ToolName: "Bash", + ToolInput: `{"command": "go test"}`, + ToolOutput: "ok", + CreatedAtEpoch: now, + CWD: "", + }, + contains: []string{ + "Bash", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildObservationPrompt(tt.exec) + + for _, s := range tt.contains { + assert.Contains(t, result, s, "Expected result to contain: %s", s) + } + + // Check CWD only appears when set + if tt.exec.CWD == "" { + assert.NotContains(t, result, "") + } + }) + } +} + +func TestBuildObservationPrompt_TruncatesLongContent(t *testing.T) { + longInput := strings.Repeat("x", 5000) + longOutput := strings.Repeat("y", 7000) + + exec := ToolExecution{ + ID: 1, + ToolName: "Read", + ToolInput: longInput, + ToolOutput: longOutput, + CreatedAtEpoch: time.Now().UnixMilli(), + CWD: "/project", + } + + result := BuildObservationPrompt(exec) + + // Input should be truncated to ~3000 + assert.Contains(t, result, "truncated") + // The result should not be excessively long + assert.Less(t, len(result), 10000) +} + +func TestBuildSummaryPrompt(t *testing.T) { + tests := []struct { + name string + req SummaryRequest + contains []string + }{ + { + name: "basic_request", + req: SummaryRequest{ + SessionDBID: 1, + SDKSessionID: "sdk-123", + Project: "test-project", + }, + contains: []string{ + "PROGRESS SUMMARY CHECKPOINT", + "", + "", + "", + "", + "", + "", + "", + "", + }, + }, + { + name: "with_assistant_message", + req: SummaryRequest{ + SessionDBID: 2, + SDKSessionID: "sdk-456", + Project: "project-b", + LastAssistantMessage: "I fixed the authentication bug by updating the JWT validation.", + }, + contains: []string{ + "Claude's Full Response to User:", + "fixed the authentication", + }, + }, + { + name: "empty_assistant_message", + req: SummaryRequest{ + SessionDBID: 3, + SDKSessionID: "sdk-789", + Project: "project-c", + LastAssistantMessage: "", + }, + contains: []string{ + "PROGRESS SUMMARY CHECKPOINT", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildSummaryPrompt(tt.req) + + for _, s := range tt.contains { + assert.Contains(t, result, s, "Expected result to contain: %s", s) + } + + // Check assistant message only appears when set + if tt.req.LastAssistantMessage == "" { + assert.NotContains(t, result, "Claude's Full Response") + } + }) + } +} + +func TestBuildSummaryPrompt_TruncatesLongAssistantMessage(t *testing.T) { + longMessage := strings.Repeat("a", 5000) + + req := SummaryRequest{ + SessionDBID: 1, + SDKSessionID: "sdk-123", + Project: "test", + LastAssistantMessage: longMessage, + } + + result := BuildSummaryPrompt(req) + + // Should contain truncation indicator + assert.Contains(t, result, "truncated") + // Result should be reasonable length (less than full 5000 + overhead) + assert.Less(t, len(result), 6000) +} + +func TestToolExecution_Struct(t *testing.T) { + exec := ToolExecution{ + ID: 42, + ToolName: "Write", + ToolInput: `{"file_path": "/test.go"}`, + ToolOutput: "File written", + CreatedAtEpoch: 1234567890000, + CWD: "/workspace", + } + + assert.Equal(t, int64(42), exec.ID) + assert.Equal(t, "Write", exec.ToolName) + assert.Equal(t, `{"file_path": "/test.go"}`, exec.ToolInput) + assert.Equal(t, "File written", exec.ToolOutput) + assert.Equal(t, int64(1234567890000), exec.CreatedAtEpoch) + assert.Equal(t, "/workspace", exec.CWD) +} + +func TestSummaryRequest_Struct(t *testing.T) { + req := SummaryRequest{ + SessionDBID: 100, + SDKSessionID: "sdk-abc", + Project: "my-project", + UserPrompt: "Fix the bug", + LastUserMessage: "Please fix the auth bug", + LastAssistantMessage: "I've fixed the authentication issue", + } + + assert.Equal(t, int64(100), req.SessionDBID) + assert.Equal(t, "sdk-abc", req.SDKSessionID) + assert.Equal(t, "my-project", req.Project) + assert.Equal(t, "Fix the bug", req.UserPrompt) + assert.Equal(t, "Please fix the auth bug", req.LastUserMessage) + assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage) +} diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index 416409e..74aab43 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -710,3 +710,495 @@ func TestGetWorkerVersion_WithServer(t *testing.T) { }) } } + +// TestGetWorkerPort_EdgeCases tests GetWorkerPort with various edge cases. +func TestGetWorkerPort_EdgeCases(t *testing.T) { + tests := []struct { + name string + envValue string + expectedPort int + shouldSetEnv bool + }{ + { + name: "zero port uses default", + envValue: "0", + expectedPort: DefaultWorkerPort, + shouldSetEnv: true, + }, + { + name: "negative port uses default", + envValue: "-1", + expectedPort: DefaultWorkerPort, + shouldSetEnv: true, + }, + { + name: "empty string uses default", + envValue: "", + expectedPort: DefaultWorkerPort, + shouldSetEnv: true, + }, + { + name: "whitespace uses default", + envValue: " ", + expectedPort: DefaultWorkerPort, + shouldSetEnv: true, + }, + { + name: "large valid port", + envValue: "65535", + expectedPort: 65535, + shouldSetEnv: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldSetEnv { + t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", tt.envValue) + } + port := GetWorkerPort() + assert.Equal(t, tt.expectedPort, port) + }) + } +} + +// TestVersionVariable tests the Version variable. +func TestVersionVariable(t *testing.T) { + // Version is set at build time, but defaults to "dev" + assert.NotEmpty(t, Version) +} + +// TestProjectIDWithName_RootPath tests ProjectIDWithName with root path. +func TestProjectIDWithName_RootPath(t *testing.T) { + result := ProjectIDWithName("/") + // Should handle root path gracefully + assert.NotEmpty(t, result) + assert.Contains(t, result, "_") // Should still have underscore separator +} + +// TestProjectIDWithName_SameDirname tests that same dirname with different paths get different IDs. +func TestProjectIDWithName_SameDirname(t *testing.T) { + id1 := ProjectIDWithName("/home/user1/project") + id2 := ProjectIDWithName("/home/user2/project") + + // Both have same dirname "project" but different full paths + assert.Contains(t, id1, "project_") + assert.Contains(t, id2, "project_") + + // But different hashes due to different full paths + assert.NotEqual(t, id1, id2) +} + +// TestBaseInput_PartialFields tests BaseInput with partial fields. +func TestBaseInput_PartialFields(t *testing.T) { + tests := []struct { + name string + input string + expected BaseInput + }{ + { + name: "only session_id", + input: `{"session_id":"test-123"}`, + expected: BaseInput{SessionID: "test-123"}, + }, + { + name: "only cwd", + input: `{"cwd":"/tmp/test"}`, + expected: BaseInput{CWD: "/tmp/test"}, + }, + { + name: "empty object", + input: `{}`, + expected: BaseInput{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var base BaseInput + err := json.Unmarshal([]byte(tt.input), &base) + require.NoError(t, err) + assert.Equal(t, tt.expected.SessionID, base.SessionID) + assert.Equal(t, tt.expected.CWD, base.CWD) + }) + } +} + +// TestHookResponse_Marshal tests HookResponse JSON marshaling. +func TestHookResponse_Marshal(t *testing.T) { + tests := []struct { + name string + response HookResponse + contains []string + }{ + { + name: "continue true", + response: HookResponse{Continue: true}, + contains: []string{`"continue":true`}, + }, + { + name: "continue false", + response: HookResponse{Continue: false}, + contains: []string{`"continue":false`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.response) + require.NoError(t, err) + for _, s := range tt.contains { + assert.Contains(t, string(data), s) + } + }) + } +} + +// TestHookResponse_Unmarshal tests HookResponse JSON unmarshaling. +func TestHookResponse_Unmarshal(t *testing.T) { + tests := []struct { + name string + input string + expected HookResponse + }{ + { + name: "continue true", + input: `{"continue":true}`, + expected: HookResponse{Continue: true}, + }, + { + name: "continue false", + input: `{"continue":false}`, + expected: HookResponse{Continue: false}, + }, + { + name: "missing continue defaults to false", + input: `{}`, + expected: HookResponse{Continue: false}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var resp HookResponse + err := json.Unmarshal([]byte(tt.input), &resp) + require.NoError(t, err) + assert.Equal(t, tt.expected.Continue, resp.Continue) + }) + } +} + +// TestHookContext_Initialization tests HookContext struct initialization. +func TestHookContext_Initialization(t *testing.T) { + tests := []struct { + name string + ctx HookContext + }{ + { + name: "full context", + ctx: HookContext{ + HookName: "session-start", + Port: 37777, + Project: "my-project_abc123", + SessionID: "session-123", + CWD: "/home/user/project", + RawInput: []byte(`{"key":"value"}`), + }, + }, + { + name: "minimal context", + ctx: HookContext{ + HookName: "stop", + }, + }, + { + name: "empty context", + ctx: HookContext{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Just verify the struct can be created and accessed + assert.Equal(t, tt.ctx.HookName, tt.ctx.HookName) + assert.Equal(t, tt.ctx.Port, tt.ctx.Port) + assert.Equal(t, tt.ctx.Project, tt.ctx.Project) + }) + } +} + +// TestPOST_MarshalError tests POST with unmarshalable body. +func TestPOST_MarshalError(t *testing.T) { + // Create a value that can't be marshaled + badValue := make(chan int) + + _, err := POST(99999, "/test", badValue) + require.Error(t, err) +} + +// TestPOST_Timeout tests POST with timeout. +func TestPOST_Timeout(t *testing.T) { + // Try to connect to a port that's not listening + _, err := POST(99998, "/test", map[string]string{"key": "value"}) + require.Error(t, err) +} + +// TestGET_Timeout tests GET with timeout. +func TestGET_Timeout(t *testing.T) { + // Try to connect to a port that's not listening + _, err := GET(99998, "/test") + require.Error(t, err) +} + +// TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout. +func TestIsWorkerRunning_Timeout(t *testing.T) { + // Non-existent port should quickly return false + start := time.Now() + result := IsWorkerRunning(99997) + elapsed := time.Since(start) + + assert.False(t, result) + assert.Less(t, elapsed, 5*time.Second) // Should not hang +} + +// TestIsPortInUse_Timeout tests IsPortInUse with timeout. +func TestIsPortInUse_Timeout(t *testing.T) { + // Non-existent port should quickly return false + start := time.Now() + result := IsPortInUse(99996) + elapsed := time.Since(start) + + assert.False(t, result) + assert.Less(t, elapsed, 2*time.Second) // Should not hang +} + +// TestGetWorkerVersion_Timeout tests GetWorkerVersion with timeout. +func TestGetWorkerVersion_Timeout(t *testing.T) { + // Non-existent port should quickly return empty + start := time.Now() + result := GetWorkerVersion(99995) + elapsed := time.Since(start) + + assert.Empty(t, result) + assert.Less(t, elapsed, 5*time.Second) // Should not hang +} + +// TestVersionsCompatible_EdgeCases tests versionsCompatible edge cases. +func TestVersionsCompatible_EdgeCases(t *testing.T) { + tests := []struct { + name string + v1 string + v2 string + expected bool + }{ + { + name: "empty versions", + v1: "", + v2: "", + expected: true, // Same base (empty) + }, + { + name: "one empty one dev", + v1: "", + v2: "dev", + expected: true, // dev is compatible with anything + }, + { + name: "prerelease versions same base", + v1: "v1.0.0-alpha", + v2: "v1.0.0-beta", + expected: true, // Same base 1.0.0 + }, + { + name: "version with rc suffix", + v1: "v2.0.0-rc1", + v2: "v2.0.0", + expected: true, // Same base 2.0.0 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := versionsCompatible(tt.v1, tt.v2) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestExtractBaseVersion_EdgeCases tests extractBaseVersion edge cases. +func TestExtractBaseVersion_EdgeCases(t *testing.T) { + tests := []struct { + name string + version string + expected string + }{ + { + name: "version starting with hyphen", + version: "-dirty", + expected: "-dirty", // hyphen at index 0 is not > 0, so no truncation + }, + { + name: "just v", + version: "v", + expected: "", + }, + { + name: "multiple hyphens", + version: "v1.0.0-alpha-beta-gamma", + expected: "1.0.0", + }, + { + name: "no hyphen at all", + version: "v2.0.0", + expected: "2.0.0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractBaseVersion(tt.version) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestProjectIDWithName_RelativePath tests ProjectIDWithName with relative paths. +func TestProjectIDWithName_RelativePath(t *testing.T) { + // Relative paths should be converted to absolute + result := ProjectIDWithName(".") + assert.NotEmpty(t, result) + assert.Contains(t, result, "_") +} + +// TestProjectIDWithName_DeepPath tests ProjectIDWithName with deep paths. +func TestProjectIDWithName_DeepPath(t *testing.T) { + result := ProjectIDWithName("/a/very/deep/nested/path/to/project") + assert.Contains(t, result, "project_") + assert.NotEmpty(t, result) +} + +// TestPOST_EmptyBody tests POST with empty body. +func TestPOST_EmptyBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + result, err := POST(port, "/test", map[string]string{}) + require.NoError(t, err) + assert.NotNil(t, result) +} + +// TestGET_WithQueryParams tests GET with query parameters. +func TestGET_WithQueryParams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "/test?foo=bar", r.URL.String()) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + result, err := GET(port, "/test?foo=bar") + require.NoError(t, err) + assert.NotNil(t, result) +} + +// TestHookResponse_RoundTrip tests JSON marshal/unmarshal round-trip. +func TestHookResponse_RoundTrip(t *testing.T) { + original := HookResponse{Continue: true} + + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded HookResponse + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Continue, decoded.Continue) +} + +// TestBaseInput_RoundTrip tests BaseInput JSON round-trip. +func TestBaseInput_RoundTrip(t *testing.T) { + original := BaseInput{ + SessionID: "test-session", + CWD: "/home/user/project", + PermissionMode: "standard", + HookEventName: "session-start", + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded BaseInput + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.SessionID, decoded.SessionID) + assert.Equal(t, original.CWD, decoded.CWD) + assert.Equal(t, original.PermissionMode, decoded.PermissionMode) + assert.Equal(t, original.HookEventName, decoded.HookEventName) +} + +// TestHookContext_RawInput tests HookContext with different raw input types. +func TestHookContext_RawInput(t *testing.T) { + tests := []struct { + name string + rawInput []byte + }{ + { + name: "json object", + rawInput: []byte(`{"key":"value"}`), + }, + { + name: "json array", + rawInput: []byte(`[1,2,3]`), + }, + { + name: "empty object", + rawInput: []byte(`{}`), + }, + { + name: "nil input", + rawInput: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := HookContext{ + HookName: "test", + RawInput: tt.rawInput, + } + assert.Equal(t, tt.rawInput, ctx.RawInput) + }) + } +} + +// TestDefaultWorkerPort tests that the default port constant is valid. +func TestDefaultWorkerPort(t *testing.T) { + assert.Greater(t, DefaultWorkerPort, 1024, "Default port should be above privileged port range") + assert.Less(t, DefaultWorkerPort, 65535, "Default port should be valid TCP port") +} + +// TestHealthCheckTimeout tests the health check timeout is reasonable. +func TestHealthCheckTimeout(t *testing.T) { + assert.Greater(t, HealthCheckTimeout, 100*time.Millisecond) + assert.Less(t, HealthCheckTimeout, 10*time.Second) +} + +// TestStartupTimeout tests the startup timeout is reasonable. +func TestStartupTimeout(t *testing.T) { + assert.Greater(t, StartupTimeout, 5*time.Second) + assert.LessOrEqual(t, StartupTimeout, time.Minute) +}