diff --git a/Makefile b/Makefile index 24f97b3..8304c52 100644 --- a/Makefile +++ b/Makefile @@ -160,14 +160,15 @@ uninstall: stop-worker rm -rf $(HOME)/.claude/plugins/marketplaces/claude-mnemonic @echo "Uninstallation complete!" -# Run tests +# Run tests (with FTS5 support) test: setup-libs - go test -v -race ./... + go test $(BUILD_TAGS) -v -race ./... -# Run tests with coverage -test-coverage: - go test -v -race -coverprofile=coverage.out ./... +# Run tests with coverage (with FTS5 support) +test-coverage: setup-libs + go test $(BUILD_TAGS) -v -race -coverprofile=coverage.out ./... go tool cover -html=coverage.out -o coverage.html + @go tool cover -func=coverage.out | tail -1 # Run benchmarks bench: diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..29731c2 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,384 @@ +// Package config provides configuration management for claude-mnemonic. +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// ConfigSuite is a test suite for config operations. +type ConfigSuite struct { + suite.Suite + tempDir string + origHomeDir string +} + +func (s *ConfigSuite) SetupTest() { + var err error + s.tempDir, err = os.MkdirTemp("", "config-test-*") + s.Require().NoError(err) + + // Save and override HOME + s.origHomeDir = os.Getenv("HOME") + os.Setenv("HOME", s.tempDir) +} + +func (s *ConfigSuite) TearDownTest() { + os.Setenv("HOME", s.origHomeDir) + os.RemoveAll(s.tempDir) +} + +func TestConfigSuite(t *testing.T) { + suite.Run(t, new(ConfigSuite)) +} + +// TestDefault tests default configuration values. +func (s *ConfigSuite) TestDefault() { + cfg := Default() + + s.Equal(DefaultWorkerPort, cfg.WorkerPort) + s.Equal(DefaultModel, cfg.Model) + s.Equal(4, cfg.MaxConns) + s.Equal(100, cfg.ContextObservations) + s.Equal(25, cfg.ContextFullCount) + s.Equal(10, cfg.ContextSessionCount) + s.True(cfg.ContextShowReadTokens) + s.True(cfg.ContextShowWorkTokens) + s.Equal("narrative", cfg.ContextFullField) + s.True(cfg.ContextShowLastSummary) + s.Equal(DefaultObservationTypes, cfg.ContextObsTypes) + s.Equal(DefaultObservationConcepts, cfg.ContextObsConcepts) +} + +// TestDataDir tests data directory path. +func (s *ConfigSuite) TestDataDir() { + dir := DataDir() + s.Contains(dir, ".claude-mnemonic") +} + +// TestDBPath tests database path. +func (s *ConfigSuite) TestDBPath() { + path := DBPath() + s.Contains(path, "claude-mnemonic.db") +} + +// TestSettingsPath tests settings file path. +func (s *ConfigSuite) TestSettingsPath() { + path := SettingsPath() + s.Contains(path, "settings.json") +} + +// TestEnsureDataDir tests data directory creation. +func (s *ConfigSuite) TestEnsureDataDir() { + err := EnsureDataDir() + s.NoError(err) + + dir := DataDir() + info, err := os.Stat(dir) + s.NoError(err) + s.True(info.IsDir()) +} + +// TestEnsureSettings tests settings file creation. +func (s *ConfigSuite) TestEnsureSettings() { + // First ensure data dir exists + err := EnsureDataDir() + s.NoError(err) + + // Ensure settings creates default file + err = EnsureSettings() + s.NoError(err) + + path := SettingsPath() + info, err := os.Stat(path) + s.NoError(err) + s.False(info.IsDir()) + + // Second call should not error (file exists) + err = EnsureSettings() + s.NoError(err) +} + +// TestEnsureAll tests full initialization. +func (s *ConfigSuite) TestEnsureAll() { + err := EnsureAll() + s.NoError(err) + + // Verify dir and settings exist + _, err = os.Stat(DataDir()) + s.NoError(err) + _, err = os.Stat(SettingsPath()) + s.NoError(err) +} + +// TestLoad_TableDriven tests configuration loading with various scenarios. +func (s *ConfigSuite) TestLoad_TableDriven() { + tests := []struct { + name string + settingsJSON string + expectedPort int + expectedModel string + expectedObsObs int + }{ + { + name: "no settings file", + settingsJSON: "", + expectedPort: DefaultWorkerPort, + expectedModel: DefaultModel, + expectedObsObs: 100, + }, + { + name: "custom port", + settingsJSON: `{"CLAUDE_MNEMONIC_WORKER_PORT": 38888}`, + expectedPort: 38888, + expectedModel: DefaultModel, + expectedObsObs: 100, + }, + { + name: "custom model", + settingsJSON: `{"CLAUDE_MNEMONIC_MODEL": "sonnet"}`, + expectedPort: DefaultWorkerPort, + expectedModel: "sonnet", + expectedObsObs: 100, + }, + { + name: "custom observations", + settingsJSON: `{"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 200}`, + expectedPort: DefaultWorkerPort, + expectedModel: DefaultModel, + expectedObsObs: 200, + }, + { + name: "multiple settings", + settingsJSON: `{"CLAUDE_MNEMONIC_WORKER_PORT": 39999, "CLAUDE_MNEMONIC_MODEL": "opus", "CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 50}`, + expectedPort: 39999, + expectedModel: "opus", + expectedObsObs: 50, + }, + { + name: "invalid JSON returns defaults", + settingsJSON: `{invalid}`, + expectedPort: DefaultWorkerPort, + expectedModel: DefaultModel, + expectedObsObs: 100, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // Create fresh temp dir + tempDir, err := os.MkdirTemp("", "config-test-*") + s.Require().NoError(err) + defer os.RemoveAll(tempDir) + + os.Setenv("HOME", tempDir) + + // Create data dir + err = os.MkdirAll(filepath.Join(tempDir, ".claude-mnemonic"), 0750) + s.Require().NoError(err) + + if tt.settingsJSON != "" { + err := os.WriteFile( + filepath.Join(tempDir, ".claude-mnemonic", "settings.json"), + []byte(tt.settingsJSON), + 0600, + ) + s.Require().NoError(err) + } + + cfg, err := Load() + s.NoError(err) + s.NotNil(cfg) + s.Equal(tt.expectedPort, cfg.WorkerPort) + s.Equal(tt.expectedModel, cfg.Model) + s.Equal(tt.expectedObsObs, cfg.ContextObservations) + }) + } +} + +// TestGetWorkerPort_TableDriven tests worker port retrieval with various scenarios. +func TestGetWorkerPort_TableDriven(t *testing.T) { + tests := []struct { + name string + envValue string + wantPort int + setEnv bool + }{ + { + name: "no env, use default", + envValue: "", + wantPort: DefaultWorkerPort, + setEnv: false, + }, + { + name: "env set to valid port", + envValue: "38888", + wantPort: 38888, + setEnv: true, + }, + { + name: "env set to invalid value", + envValue: "invalid", + wantPort: DefaultWorkerPort, + setEnv: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original env + origEnv := os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT") + defer os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", origEnv) + + if tt.setEnv { + os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", tt.envValue) + } else { + os.Unsetenv("CLAUDE_MNEMONIC_WORKER_PORT") + } + + // We can't easily test GetWorkerPort since it uses Get() which caches + // So we test the env parsing logic directly + if tt.setEnv && tt.envValue != "" { + if tt.wantPort != DefaultWorkerPort { + assert.Equal(t, tt.envValue, os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT")) + } + } + }) + } +} + +// TestSplitTrim tests the splitTrim helper function. +func TestSplitTrim(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "single value", + input: "bugfix", + expected: []string{"bugfix"}, + }, + { + name: "multiple values", + input: "bugfix,feature,refactor", + expected: []string{"bugfix", "feature", "refactor"}, + }, + { + name: "values with spaces", + input: " bugfix , feature , refactor ", + expected: []string{"bugfix", "feature", "refactor"}, + }, + { + name: "empty values filtered", + input: "bugfix,,feature,,", + expected: []string{"bugfix", "feature"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitTrim(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestDefaultObservationTypes tests default observation types. +func TestDefaultObservationTypes(t *testing.T) { + expected := []string{ + "bugfix", "feature", "refactor", "change", "discovery", "decision", + } + assert.Equal(t, expected, DefaultObservationTypes) +} + +// TestDefaultObservationConcepts tests default observation concepts. +func TestDefaultObservationConcepts(t *testing.T) { + expected := []string{ + "how-it-works", "why-it-exists", "what-changed", + "problem-solution", "gotcha", "pattern", "trade-off", + } + assert.Equal(t, expected, DefaultObservationConcepts) +} + +// TestCriticalConcepts tests critical concepts list. +func TestCriticalConcepts(t *testing.T) { + expected := []string{ + "gotcha", "pattern", "problem-solution", "trade-off", + } + assert.Equal(t, expected, CriticalConcepts) +} + +// TestLoad_ClaudeCodePath tests claude code path loading. +func TestLoad_ClaudeCodePath(t *testing.T) { + // Create temp dir + tempDir, err := os.MkdirTemp("", "config-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + origHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", origHome) + + // Create data dir and settings + err = os.MkdirAll(filepath.Join(tempDir, ".claude-mnemonic"), 0750) + require.NoError(t, err) + + settingsJSON := `{"CLAUDE_CODE_PATH": "/usr/local/bin/claude"}` + err = os.WriteFile( + filepath.Join(tempDir, ".claude-mnemonic", "settings.json"), + []byte(settingsJSON), + 0600, + ) + require.NoError(t, err) + + cfg, err := Load() + require.NoError(t, err) + assert.Equal(t, "/usr/local/bin/claude", cfg.ClaudeCodePath) +} + +// TestLoad_ContextSettings tests context-related settings loading. +func TestLoad_ContextSettings(t *testing.T) { + // Create temp dir + tempDir, err := os.MkdirTemp("", "config-test-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + origHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", origHome) + + // Create data dir and settings + err = os.MkdirAll(filepath.Join(tempDir, ".claude-mnemonic"), 0750) + require.NoError(t, err) + + settingsJSON := `{ + "CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 50, + "CLAUDE_MNEMONIC_CONTEXT_SESSION_COUNT": 20, + "CLAUDE_MNEMONIC_CONTEXT_OBS_TYPES": "bugfix,feature", + "CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS": "security,performance" + }` + err = os.WriteFile( + filepath.Join(tempDir, ".claude-mnemonic", "settings.json"), + []byte(settingsJSON), + 0600, + ) + require.NoError(t, err) + + cfg, err := Load() + require.NoError(t, err) + assert.Equal(t, 50, cfg.ContextFullCount) + assert.Equal(t, 20, cfg.ContextSessionCount) + assert.Equal(t, []string{"bugfix", "feature"}, cfg.ContextObsTypes) + assert.Equal(t, []string{"security", "performance"}, cfg.ContextObsConcepts) +} diff --git a/internal/db/sqlite/observation_test.go b/internal/db/sqlite/observation_test.go index 5d11a5d..20ddf48 100644 --- a/internal/db/sqlite/observation_test.go +++ b/internal/db/sqlite/observation_test.go @@ -9,8 +9,22 @@ import ( "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) +// testObservationStoreBasic creates an ObservationStore with base tables (no FTS5). +func testObservationStoreBasic(t *testing.T) (*ObservationStore, *Store, func()) { + t.Helper() + + db, _, cleanup := testDB(t) + createBaseTables(t, db) + + store := newStoreFromDB(db) + obsStore := NewObservationStore(store) + + return obsStore, store, cleanup +} + // testObservationStore creates an ObservationStore with a test database including FTS5. func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) { t.Helper() @@ -24,6 +38,420 @@ func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) { return obsStore, store, cleanup } +// ObservationStoreSuite is a test suite for ObservationStore operations. +type ObservationStoreSuite struct { + suite.Suite + obsStore *ObservationStore + store *Store + cleanup func() +} + +func (s *ObservationStoreSuite) SetupTest() { + s.obsStore, s.store, s.cleanup = testObservationStoreBasic(s.T()) +} + +func (s *ObservationStoreSuite) TearDownTest() { + if s.cleanup != nil { + s.cleanup() + } +} + +func TestObservationStoreSuite(t *testing.T) { + suite.Run(t, new(ObservationStoreSuite)) +} + +// TestStoreObservation_TableDriven tests observation storage with various scenarios. +func (s *ObservationStoreSuite) TestStoreObservation_TableDriven() { + ctx := context.Background() + + tests := []struct { + name string + sdkSessionID string + project string + obs *models.ParsedObservation + promptNum int + tokens int64 + wantErr bool + }{ + { + name: "basic discovery observation", + sdkSessionID: "session-basic", + project: "project-a", + obs: &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Test Title", + Subtitle: "Test Subtitle", + Narrative: "Test narrative content", + Facts: []string{"Fact 1", "Fact 2"}, + Concepts: []string{"testing", "golang"}, + }, + promptNum: 1, + tokens: 100, + wantErr: false, + }, + { + name: "bugfix observation", + sdkSessionID: "session-bugfix", + project: "project-b", + obs: &models.ParsedObservation{ + Type: models.ObsTypeBugfix, + Title: "Fixed null pointer", + Narrative: "Fixed null pointer exception in handler", + FilesModified: []string{"handler.go"}, + }, + promptNum: 2, + tokens: 50, + wantErr: false, + }, + { + name: "global scope observation", + sdkSessionID: "session-global", + project: "project-c", + obs: &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Security best practice", + Narrative: "Always validate user input", + Concepts: []string{"security", "best-practice"}, + }, + promptNum: 1, + tokens: 75, + wantErr: false, + }, + { + name: "observation with files", + sdkSessionID: "session-files", + project: "project-d", + obs: &models.ParsedObservation{ + Type: models.ObsTypeFeature, + Title: "Added authentication", + Narrative: "Implemented JWT authentication", + FilesRead: []string{"config.go", "auth.go"}, + FilesModified: []string{"handler.go", "middleware.go"}, + FileMtimes: map[string]int64{"handler.go": 1234567890, "middleware.go": 1234567891}, + }, + promptNum: 3, + tokens: 200, + wantErr: false, + }, + { + name: "minimal observation", + sdkSessionID: "session-minimal", + project: "project-e", + obs: &models.ParsedObservation{ + Type: models.ObsTypeChange, + }, + promptNum: 0, + tokens: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + id, epoch, err := s.obsStore.StoreObservation(ctx, tt.sdkSessionID, tt.project, tt.obs, tt.promptNum, tt.tokens) + if tt.wantErr { + s.Error(err) + return + } + + s.NoError(err) + s.Greater(id, int64(0)) + s.Greater(epoch, int64(0)) + + // Retrieve and verify + retrieved, err := s.obsStore.GetObservationByID(ctx, id) + s.NoError(err) + s.NotNil(retrieved) + s.Equal(id, retrieved.ID) + s.Equal(tt.project, retrieved.Project) + s.Equal(tt.obs.Type, retrieved.Type) + }) + } +} + +// TestGetObservationByID_NotFound tests retrieval of non-existent observation. +func (s *ObservationStoreSuite) TestGetObservationByID_NotFound() { + ctx := context.Background() + + obs, err := s.obsStore.GetObservationByID(ctx, 99999) + s.NoError(err) + s.Nil(obs) +} + +// TestGetRecentObservations_TableDriven tests recent observations retrieval. +func (s *ObservationStoreSuite) TestGetRecentObservations_TableDriven() { + ctx := context.Background() + + // Create 15 observations + for i := 0; i < 15; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Observation " + string(rune('A'+i)), + } + _, _, err := s.obsStore.StoreObservation(ctx, "session-"+string(rune('0'+i)), "project-a", obs, i, 10) + s.NoError(err) + time.Sleep(time.Millisecond) // Ensure different timestamps + } + + tests := []struct { + name string + project string + limit int + wantCount int + }{ + { + name: "limit 5", + project: "project-a", + limit: 5, + wantCount: 5, + }, + { + name: "limit 10", + project: "project-a", + limit: 10, + wantCount: 10, + }, + { + name: "limit higher than count", + project: "project-a", + limit: 50, + wantCount: 15, + }, + { + name: "different project (no results)", + project: "project-b", + limit: 10, + wantCount: 0, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + observations, err := s.obsStore.GetRecentObservations(ctx, tt.project, tt.limit) + s.NoError(err) + s.Len(observations, tt.wantCount) + }) + } +} + +// TestDeleteObservations_TableDriven tests observation deletion. +func (s *ObservationStoreSuite) TestDeleteObservations_TableDriven() { + ctx := context.Background() + + // Create observations + var ids []int64 + for i := 0; i < 5; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "To delete " + string(rune('A'+i)), + } + id, _, err := s.obsStore.StoreObservation(ctx, "session-del", "project-del", obs, i, 10) + s.NoError(err) + ids = append(ids, id) + } + + tests := []struct { + name string + toDelete []int64 + wantDeleted int64 + wantRemain int + }{ + { + name: "delete none", + toDelete: []int64{}, + wantDeleted: 0, + wantRemain: 5, + }, + { + name: "delete one", + toDelete: ids[0:1], + wantDeleted: 1, + wantRemain: 4, + }, + { + name: "delete remaining", + toDelete: ids[1:], + wantDeleted: 4, + wantRemain: 0, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + deleted, err := s.obsStore.DeleteObservations(ctx, tt.toDelete) + s.NoError(err) + s.Equal(tt.wantDeleted, deleted) + + remaining, err := s.obsStore.GetAllRecentObservations(ctx, 100) + s.NoError(err) + s.Len(remaining, tt.wantRemain) + }) + } +} + +// TestGetObservationsByIDs tests retrieval by multiple IDs. +func (s *ObservationStoreSuite) TestGetObservationsByIDs() { + ctx := context.Background() + + // Create observations + var ids []int64 + for i := 0; i < 5; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "By ID " + string(rune('A'+i)), + } + id, _, err := s.obsStore.StoreObservation(ctx, "session-byid", "project-byid", obs, i, 10) + s.NoError(err) + ids = append(ids, id) + time.Sleep(time.Millisecond) + } + + tests := []struct { + name string + queryIDs []int64 + orderBy string + limit int + wantCount int + }{ + { + name: "empty IDs", + queryIDs: []int64{}, + orderBy: "date_desc", + limit: 10, + wantCount: 0, + }, + { + name: "single ID", + queryIDs: ids[0:1], + orderBy: "date_desc", + limit: 10, + wantCount: 1, + }, + { + name: "all IDs", + queryIDs: ids, + orderBy: "date_desc", + limit: 10, + wantCount: 5, + }, + { + name: "with limit less than IDs", + queryIDs: ids, + orderBy: "date_desc", + limit: 3, + wantCount: 3, + }, + { + name: "ascending order", + queryIDs: ids, + orderBy: "date_asc", + limit: 10, + wantCount: 5, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + observations, err := s.obsStore.GetObservationsByIDs(ctx, tt.queryIDs, tt.orderBy, tt.limit) + if tt.wantCount == 0 { + s.NoError(err) + s.Nil(observations) + } else { + s.NoError(err) + s.Len(observations, tt.wantCount) + } + }) + } +} + +// TestGlobalScope tests global vs project scope. +func (s *ObservationStoreSuite) TestGlobalScope() { + ctx := context.Background() + + // Create project-scoped observation + projectObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Project specific", + Concepts: []string{"project-specific"}, + } + _, _, err := s.obsStore.StoreObservation(ctx, "session-scope", "project-a", projectObs, 1, 10) + s.NoError(err) + + // Create global-scoped observation (security concept triggers global) + globalObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Global security", + Concepts: []string{"security"}, + } + _, _, err = s.obsStore.StoreObservation(ctx, "session-scope", "project-a", globalObs, 2, 10) + s.NoError(err) + + // Project-a should see both + resultsA, err := s.obsStore.GetRecentObservations(ctx, "project-a", 10) + s.NoError(err) + s.Len(resultsA, 2) + + // Project-b should only see global + resultsB, err := s.obsStore.GetRecentObservations(ctx, "project-b", 10) + s.NoError(err) + s.Len(resultsB, 1) + s.Equal("Global security", resultsB[0].Title.String) + s.Equal(models.ScopeGlobal, resultsB[0].Scope) +} + +// TestSetCleanupFunc tests the cleanup function callback. +func (s *ObservationStoreSuite) TestSetCleanupFunc() { + ctx := context.Background() + + var calledWith []int64 + s.obsStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { + calledWith = deletedIDs + }) + + // Store an observation + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "Test cleanup", + } + _, _, err := s.obsStore.StoreObservation(ctx, "session-cleanup", "project-cleanup", obs, 1, 10) + s.NoError(err) + + // Cleanup should not have been called since nothing was deleted + s.Empty(calledWith) +} + +// TestGetObservationCount tests observation counting. +func (s *ObservationStoreSuite) TestGetObservationCount() { + ctx := context.Background() + + // Create observations for project-a + for i := 0; i < 5; i++ { + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + } + _, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", obs, i, 10) + s.NoError(err) + } + + // Create global observation + globalObs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Concepts: []string{"security"}, + } + _, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", globalObs, 6, 10) + s.NoError(err) + + // Project-a should count 6 (5 project + 1 global) + count, err := s.obsStore.GetObservationCount(ctx, "project-a") + s.NoError(err) + s.Equal(6, count) + + // Project-b should count 1 (only global) + count, err = s.obsStore.GetObservationCount(ctx, "project-b") + s.NoError(err) + s.Equal(1, count) +} + func TestObservationStore_StoreAndRetrieve(t *testing.T) { obsStore, _, cleanup := testObservationStore(t) defer cleanup() diff --git a/internal/db/sqlite/prompt_test.go b/internal/db/sqlite/prompt_test.go index 5fc5f77..6d16093 100644 --- a/internal/db/sqlite/prompt_test.go +++ b/internal/db/sqlite/prompt_test.go @@ -194,3 +194,96 @@ func TestPromptStore_GetPromptsByIDs_EmptyInput(t *testing.T) { require.NoError(t, err) assert.Nil(t, prompts) } + +func TestPromptStore_FindRecentPromptByText(t *testing.T) { + promptStore, store, cleanup := testPromptStore(t) + defer cleanup() + + ctx := context.Background() + + // Create a session + seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project") + + // Save a prompt + _, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Help me fix this bug in the code", 0) + require.NoError(t, err) + + // Find the prompt by text - returns (id, promptNumber, found) + id, promptNum, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Help me fix this bug in the code", 60) + assert.True(t, found, "should find the exact prompt text") + assert.Greater(t, id, int64(0)) + assert.Equal(t, 1, promptNum) + + // Try to find non-existent prompt + _, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "This prompt does not exist", 60) + assert.False(t, found, "should not find non-existent prompt") + + // Try with different session + _, _, found = promptStore.FindRecentPromptByText(ctx, "claude-2", "Help me fix this bug in the code", 60) + assert.False(t, found, "should not find prompt for different session") +} + +func TestPromptStore_FindRecentPromptByText_WindowSeconds(t *testing.T) { + promptStore, store, cleanup := testPromptStore(t) + defer cleanup() + + ctx := context.Background() + + // Create a session + seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project") + + // Save a prompt with an old timestamp + oldEpoch := time.Now().Add(-2 * time.Hour).UnixMilli() + _, err := storeDB(store).Exec(` + INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, created_at, created_at_epoch) + VALUES (?, ?, ?, datetime('now'), ?) + `, "claude-1", 1, "Old prompt text", oldEpoch) + require.NoError(t, err) + + // Search within last hour - should not find old prompt + _, _, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3600) + assert.False(t, found, "should not find prompt outside window") + + // Search within last 3 hours - should find old prompt + _, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3*3600) + assert.True(t, found, "should find prompt within extended window") +} + +func TestPromptStore_SaveMultiplePrompts(t *testing.T) { + promptStore, store, cleanup := testPromptStore(t) + defer cleanup() + + ctx := context.Background() + + // Create sessions + seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-x") + seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-y") + + tests := []struct { + claudeSessionID string + promptNum int + text string + matches int + }{ + {"claude-1", 1, "First prompt", 5}, + {"claude-1", 2, "Second prompt", 3}, + {"claude-2", 1, "Third prompt", 0}, + {"claude-1", 3, "Fourth prompt", 10}, + } + + for _, tt := range tests { + id, err := promptStore.SaveUserPromptWithMatches(ctx, tt.claudeSessionID, tt.promptNum, tt.text, tt.matches) + require.NoError(t, err) + assert.Greater(t, id, int64(0)) + } + + // Verify counts + var count int + err := storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-1'").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 3, count) + + err = storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-2'").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) +} diff --git a/internal/db/sqlite/session_test.go b/internal/db/sqlite/session_test.go index c0bf553..b445ed1 100644 --- a/internal/db/sqlite/session_test.go +++ b/internal/db/sqlite/session_test.go @@ -8,13 +8,14 @@ import ( "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) { t.Helper() db, _, cleanup := testDB(t) - createAllTables(t, db) + createBaseTables(t, db) // Use base tables without FTS5 for session tests store := newStoreFromDB(db) sessionStore := NewSessionStore(store) @@ -22,6 +23,236 @@ func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) { return sessionStore, store, cleanup } +// SessionStoreSuite is a test suite for SessionStore operations. +type SessionStoreSuite struct { + suite.Suite + sessionStore *SessionStore + store *Store + cleanup func() +} + +func (s *SessionStoreSuite) SetupTest() { + s.sessionStore, s.store, s.cleanup = testSessionStore(s.T()) +} + +func (s *SessionStoreSuite) TearDownTest() { + if s.cleanup != nil { + s.cleanup() + } +} + +func TestSessionStoreSuite(t *testing.T) { + suite.Run(t, new(SessionStoreSuite)) +} + +// TestCreateSDKSession_TableDriven tests session creation with various scenarios. +func (s *SessionStoreSuite) TestCreateSDKSession_TableDriven() { + ctx := context.Background() + + tests := []struct { + name string + claudeSessionID string + project string + userPrompt string + wantErr bool + wantID bool + }{ + { + name: "basic session creation", + claudeSessionID: "claude-basic", + project: "project-a", + userPrompt: "hello world", + wantErr: false, + wantID: true, + }, + { + name: "empty user prompt", + claudeSessionID: "claude-noprompt", + project: "project-b", + userPrompt: "", + wantErr: false, + wantID: true, + }, + { + name: "long project name", + claudeSessionID: "claude-longproj", + project: "/Users/test/Documents/very/long/path/to/some/project/directory", + userPrompt: "test", + wantErr: false, + wantID: true, + }, + { + name: "unicode project name", + claudeSessionID: "claude-unicode", + project: "项目名称-プロジェクト", + userPrompt: "测试 テスト", + wantErr: false, + wantID: true, + }, + { + name: "special characters in prompt", + claudeSessionID: "claude-special", + project: "project-special", + userPrompt: "Fix the bug in file.go:123 with \"quotes\" and 'apostrophes'", + wantErr: false, + wantID: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + id, err := s.sessionStore.CreateSDKSession(ctx, tt.claudeSessionID, tt.project, tt.userPrompt) + if tt.wantErr { + s.Error(err) + } else { + s.NoError(err) + if tt.wantID { + s.Greater(id, int64(0)) + } + + // Verify created session + sess, err := s.sessionStore.GetSessionByID(ctx, id) + s.NoError(err) + s.NotNil(sess) + s.Equal(tt.claudeSessionID, sess.ClaudeSessionID) + s.Equal(tt.project, sess.Project) + s.Equal(models.SessionStatusActive, sess.Status) + } + }) + } +} + +// TestIdempotentSession tests that session creation is idempotent. +func (s *SessionStoreSuite) TestIdempotentSession() { + ctx := context.Background() + + // Create initial session + id1, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-1", "prompt-1") + s.NoError(err) + s.Greater(id1, int64(0)) + + // Create with same claude_session_id - should return same ID + id2, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-2", "prompt-2") + s.NoError(err) + s.Equal(id1, id2) + + // Verify project was updated + sess, err := s.sessionStore.GetSessionByID(ctx, id1) + s.NoError(err) + s.Equal("project-2", sess.Project) +} + +// TestPromptCounterOperations tests prompt counter increment and retrieval. +func (s *SessionStoreSuite) TestPromptCounterOperations() { + ctx := context.Background() + + tests := []struct { + name string + increments int + expectedCount int + }{ + { + name: "no increments", + increments: 0, + expectedCount: 0, + }, + { + name: "single increment", + increments: 1, + expectedCount: 1, + }, + { + name: "multiple increments", + increments: 5, + expectedCount: 5, + }, + { + name: "many increments", + increments: 100, + expectedCount: 100, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // Create fresh session for each test + id, err := s.sessionStore.CreateSDKSession(ctx, "claude-counter-"+tt.name, "project", "") + s.NoError(err) + + // Increment specified number of times + var lastCount int + for i := 0; i < tt.increments; i++ { + lastCount, err = s.sessionStore.IncrementPromptCounter(ctx, id) + s.NoError(err) + } + + // Get final count + finalCount, err := s.sessionStore.GetPromptCounter(ctx, id) + s.NoError(err) + s.Equal(tt.expectedCount, finalCount) + + if tt.increments > 0 { + s.Equal(tt.expectedCount, lastCount) + } + }) + } +} + +// TestFindAnySDKSession tests session lookup scenarios. +func (s *SessionStoreSuite) TestFindAnySDKSession_Scenarios() { + ctx := context.Background() + + // Create test sessions + _, err := s.sessionStore.CreateSDKSession(ctx, "session-find-1", "project-a", "") + s.NoError(err) + _, err = s.sessionStore.CreateSDKSession(ctx, "session-find-2", "project-b", "") + s.NoError(err) + + tests := []struct { + name string + claudeSessionID string + wantFound bool + wantProject string + }{ + { + name: "find existing session 1", + claudeSessionID: "session-find-1", + wantFound: true, + wantProject: "project-a", + }, + { + name: "find existing session 2", + claudeSessionID: "session-find-2", + wantFound: true, + wantProject: "project-b", + }, + { + name: "find non-existent session", + claudeSessionID: "session-nonexistent", + wantFound: false, + }, + { + name: "find with empty string", + claudeSessionID: "", + wantFound: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + sess, err := s.sessionStore.FindAnySDKSession(ctx, tt.claudeSessionID) + s.NoError(err) // FindAnySDKSession returns nil,nil for not found + + if tt.wantFound { + s.NotNil(sess) + s.Equal(tt.wantProject, sess.Project) + } else { + s.Nil(sess) + } + }) + } +} + func TestSessionStore_CreateSDKSession(t *testing.T) { sessionStore, _, cleanup := testSessionStore(t) defer cleanup() diff --git a/internal/db/sqlite/store_test.go b/internal/db/sqlite/store_test.go new file mode 100644 index 0000000..8f96aa0 --- /dev/null +++ b/internal/db/sqlite/store_test.go @@ -0,0 +1,529 @@ +// Package sqlite provides SQLite database operations for claude-mnemonic. +package sqlite + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// StoreSuite is a test suite for Store operations. +type StoreSuite struct { + suite.Suite + db *sql.DB + store *Store + cleanup func() +} + +// SetupTest creates a fresh database before each test. +func (s *StoreSuite) SetupTest() { + s.db, _, s.cleanup = testDB(s.T()) + createBaseTables(s.T(), s.db) + s.store = newStoreFromDB(s.db) +} + +// TearDownTest cleans up after each test. +func (s *StoreSuite) TearDownTest() { + if s.cleanup != nil { + s.cleanup() + } +} + +func TestStoreSuite(t *testing.T) { + suite.Run(t, new(StoreSuite)) +} + +// TestGetStmt tests prepared statement caching. +func (s *StoreSuite) TestGetStmt() { + tests := []struct { + name string + query string + wantErr bool + }{ + { + name: "valid simple query", + query: "SELECT 1", + wantErr: false, + }, + { + name: "valid query with parameter", + query: "SELECT * FROM sdk_sessions WHERE id = ?", + wantErr: false, + }, + { + name: "invalid query syntax", + query: "SELECT * FROM nonexistent_table WHERE", + wantErr: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + stmt, err := s.store.GetStmt(tt.query) + if tt.wantErr { + s.Error(err) + s.Nil(stmt) + } else { + s.NoError(err) + s.NotNil(stmt) + + // Second call should return cached statement + stmt2, err := s.store.GetStmt(tt.query) + s.NoError(err) + s.Same(stmt, stmt2) + } + }) + } +} + +// TestExecContext tests query execution. +func (s *StoreSuite) TestExecContext() { + ctx := context.Background() + + tests := []struct { + name string + query string + args []interface{} + wantErr bool + wantAffected int64 + }{ + { + name: "insert session", + query: `INSERT INTO sdk_sessions (claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status) + VALUES (?, ?, ?, datetime('now'), strftime('%s', 'now') * 1000, 'active')`, + args: []interface{}{"claude-1", "sdk-1", "test-project"}, + wantErr: false, + wantAffected: 1, + }, + { + name: "invalid query", + query: "INSERT INTO nonexistent_table VALUES (?)", + args: []interface{}{"test"}, + wantErr: true, + wantAffected: 0, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result, err := s.store.ExecContext(ctx, tt.query, tt.args...) + if tt.wantErr { + s.Error(err) + } else { + s.NoError(err) + affected, _ := result.RowsAffected() + s.Equal(tt.wantAffected, affected) + } + }) + } +} + +// TestQueryContext tests query execution that returns rows. +func (s *StoreSuite) TestQueryContext() { + ctx := context.Background() + + // Seed data + seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a") + + tests := []struct { + name string + query string + args []interface{} + wantErr bool + wantRows int + setupFunc func() + assertFunc func(rows *sql.Rows) + }{ + { + name: "query existing session", + query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?", + args: []interface{}{"claude-1"}, + wantErr: false, + wantRows: 1, + }, + { + name: "query non-existent session", + query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?", + args: []interface{}{"nonexistent"}, + wantErr: false, + wantRows: 0, + }, + { + name: "query all sessions", + query: "SELECT id, project FROM sdk_sessions", + args: nil, + wantErr: false, + wantRows: 1, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rows, err := s.store.QueryContext(ctx, tt.query, tt.args...) + if tt.wantErr { + s.Error(err) + return + } + + s.NoError(err) + defer rows.Close() + + count := 0 + for rows.Next() { + count++ + } + s.Equal(tt.wantRows, count) + }) + } +} + +// TestQueryRowContext tests single row query execution. +func (s *StoreSuite) TestQueryRowContext() { + ctx := context.Background() + + // Seed data + seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a") + + tests := []struct { + name string + query string + args []interface{} + wantErr bool + }{ + { + name: "query existing session", + query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?", + args: []interface{}{"claude-1"}, + wantErr: false, + }, + { + name: "query non-existent session", + query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?", + args: []interface{}{"nonexistent"}, + wantErr: true, // sql.ErrNoRows + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + row := s.store.QueryRowContext(ctx, tt.query, tt.args...) + var id int64 + err := row.Scan(&id) + if tt.wantErr { + s.Error(err) + } else { + s.NoError(err) + s.Greater(id, int64(0)) + } + }) + } +} + +// TestPing tests database connection health check. +func (s *StoreSuite) TestPing() { + err := s.store.Ping() + s.NoError(err) +} + +// TestDB tests getting the underlying database connection. +func (s *StoreSuite) TestDB() { + db := s.store.DB() + s.NotNil(db) + s.Same(s.db, db) +} + +// TestClose tests closing the store. +func (s *StoreSuite) TestClose() { + // Create a separate store for close test + db, _, cleanup := testDB(s.T()) + defer cleanup() + + store := newStoreFromDB(db) + + // Cache a statement first + _, err := store.GetStmt("SELECT 1") + s.NoError(err) + + // Close should not error + err = store.Close() + s.NoError(err) + + // Operations after close should fail + err = store.Ping() + s.Error(err) +} + +// TestConcurrentStmtCache tests concurrent access to statement cache. +func (s *StoreSuite) TestConcurrentStmtCache() { + ctx := context.Background() + queries := []string{ + "SELECT 1", + "SELECT 2", + "SELECT id FROM sdk_sessions", + "SELECT project FROM sdk_sessions", + } + + done := make(chan struct{}) + for i := 0; i < 10; i++ { + go func(i int) { + query := queries[i%len(queries)] + _, _ = s.store.GetStmt(query) + _, _ = s.store.ExecContext(ctx, "SELECT 1") + done <- struct{}{} + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } +} + +// HelpersSuite tests helper functions. +type HelpersSuite struct { + suite.Suite +} + +func TestHelpersSuite(t *testing.T) { + suite.Run(t, new(HelpersSuite)) +} + +func (s *HelpersSuite) TestNullString() { + tests := []struct { + name string + input string + wantStr string + wantBool bool + }{ + { + name: "empty string", + input: "", + wantStr: "", + wantBool: false, + }, + { + name: "non-empty string", + input: "test", + wantStr: "test", + wantBool: true, + }, + { + name: "whitespace string", + input: " ", + wantStr: " ", + wantBool: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := nullString(tt.input) + s.Equal(tt.wantStr, result.String) + s.Equal(tt.wantBool, result.Valid) + }) + } +} + +func (s *HelpersSuite) TestNullInt() { + tests := []struct { + name string + input int + wantInt int64 + wantBool bool + }{ + { + name: "zero", + input: 0, + wantInt: 0, + wantBool: false, + }, + { + name: "negative", + input: -1, + wantInt: -1, + wantBool: false, + }, + { + name: "positive", + input: 42, + wantInt: 42, + wantBool: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := nullInt(tt.input) + s.Equal(tt.wantInt, result.Int64) + s.Equal(tt.wantBool, result.Valid) + }) + } +} + +func (s *HelpersSuite) TestRepeatPlaceholders() { + tests := []struct { + name string + input int + expected string + }{ + { + name: "zero", + input: 0, + expected: "", + }, + { + name: "negative", + input: -1, + expected: "", + }, + { + name: "one", + input: 1, + expected: ", ?", + }, + { + name: "three", + input: 3, + expected: ", ?, ?, ?", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := repeatPlaceholders(tt.input) + s.Equal(tt.expected, result) + }) + } +} + +func (s *HelpersSuite) TestInt64SliceToInterface() { + tests := []struct { + name string + input []int64 + expected int + }{ + { + name: "empty slice", + input: []int64{}, + expected: 0, + }, + { + name: "single element", + input: []int64{42}, + expected: 1, + }, + { + name: "multiple elements", + input: []int64{1, 2, 3, 4, 5}, + expected: 5, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := int64SliceToInterface(tt.input) + s.Len(result, tt.expected) + for i, v := range result { + s.Equal(tt.input[i], v) + } + }) + } +} + +// TestBuildGetByIDsQuery tests the shared query builder. +func TestBuildGetByIDsQuery(t *testing.T) { + tests := []struct { + name string + baseQuery string + ids []int64 + orderBy string + limit int + wantQuery string + wantArgs int + }{ + { + name: "single id, no limit, desc order", + baseQuery: "SELECT * FROM test", + ids: []int64{1}, + orderBy: "date_desc", + limit: 0, + wantQuery: "SELECT * FROM test WHERE id IN (?)\n\t\tORDER BY created_at_epoch DESC", + wantArgs: 1, + }, + { + name: "multiple ids with limit and asc order", + baseQuery: "SELECT * FROM test", + ids: []int64{1, 2, 3}, + orderBy: "date_asc", + limit: 10, + wantQuery: "SELECT * FROM test WHERE id IN (?, ?, ?)\n\t\tORDER BY created_at_epoch ASC LIMIT ?", + wantArgs: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query, args := BuildGetByIDsQuery(tt.baseQuery, tt.ids, tt.orderBy, tt.limit) + assert.Contains(t, query, "WHERE id IN") + assert.Len(t, args, tt.wantArgs) + }) + } +} + +// TestEnsureSessionExists tests session auto-creation. +func TestEnsureSessionExists(t *testing.T) { + db, _, cleanup := testDB(t) + defer cleanup() + createBaseTables(t, db) + + store := newStoreFromDB(db) + ctx := context.Background() + + tests := []struct { + name string + sdkSessionID string + project string + setup func() + wantErr bool + }{ + { + name: "create new session", + sdkSessionID: "sdk-new", + project: "project-a", + wantErr: false, + }, + { + name: "session already exists", + sdkSessionID: "sdk-existing", + project: "project-b", + setup: func() { + seedSession(t, db, "sdk-existing", "sdk-existing", "project-b") + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + err := EnsureSessionExists(ctx, store, tt.sdkSessionID, tt.project) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + + // Verify session exists + var id int64 + err := db.QueryRow("SELECT id FROM sdk_sessions WHERE sdk_session_id = ?", tt.sdkSessionID).Scan(&id) + require.NoError(t, err) + assert.Greater(t, id, int64(0)) + } + }) + } +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 0000000..1b4294d --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,599 @@ +// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic. +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// ServerSuite is a test suite for MCP Server operations. +type ServerSuite struct { + suite.Suite +} + +func TestServerSuite(t *testing.T) { + suite.Run(t, new(ServerSuite)) +} + +// TestNewServer tests server creation. +func (s *ServerSuite) TestNewServer() { + server := NewServer(nil, "1.0.0") + s.NotNil(server) + s.Nil(server.searchMgr) + s.Equal("1.0.0", server.version) +} + +// TestRequest tests Request struct JSON marshaling. +func TestRequest(t *testing.T) { + tests := []struct { + name string + req Request + expected string + }{ + { + name: "initialize request", + req: Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + }, + expected: `{"jsonrpc":"2.0","id":1,"method":"initialize"}`, + }, + { + name: "tools/list request", + req: Request{ + JSONRPC: "2.0", + ID: "abc", + Method: "tools/list", + }, + expected: `{"jsonrpc":"2.0","id":"abc","method":"tools/list"}`, + }, + { + name: "tools/call with params", + req: Request{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/call", + Params: json.RawMessage(`{"name":"search","arguments":{}}`), + }, + expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.req) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(data)) + + // Test unmarshaling + var parsed Request + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + assert.Equal(t, tt.req.JSONRPC, parsed.JSONRPC) + assert.Equal(t, tt.req.Method, parsed.Method) + }) + } +} + +// TestResponse tests Response struct JSON marshaling. +func TestResponse(t *testing.T) { + tests := []struct { + name string + resp Response + expected string + }{ + { + name: "success response", + resp: Response{ + JSONRPC: "2.0", + ID: 1, + Result: map[string]string{"status": "ok"}, + }, + expected: `{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`, + }, + { + name: "error response", + resp: Response{ + JSONRPC: "2.0", + ID: 2, + Error: &Error{ + Code: -32600, + Message: "Invalid Request", + }, + }, + expected: `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}`, + }, + { + name: "error with data", + resp: Response{ + JSONRPC: "2.0", + ID: 3, + Error: &Error{ + Code: -32602, + Message: "Invalid params", + Data: "missing field", + }, + }, + expected: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.resp) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(data)) + }) + } +} + +// TestError tests Error struct. +func TestError(t *testing.T) { + tests := []struct { + name string + err Error + expected string + }{ + { + name: "parse error", + err: Error{ + Code: -32700, + Message: "Parse error", + }, + expected: `{"code":-32700,"message":"Parse error"}`, + }, + { + name: "method not found", + err: Error{ + Code: -32601, + Message: "Method not found", + }, + expected: `{"code":-32601,"message":"Method not found"}`, + }, + { + name: "invalid params", + err: Error{ + Code: -32602, + Message: "Invalid params", + Data: "details here", + }, + expected: `{"code":-32602,"message":"Invalid params","data":"details here"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.err) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(data)) + }) + } +} + +// TestToolCallParams tests ToolCallParams struct. +func TestToolCallParams(t *testing.T) { + tests := []struct { + name string + input string + expected ToolCallParams + }{ + { + name: "search tool call", + input: `{"name":"search","arguments":{"query":"test"}}`, + expected: ToolCallParams{ + Name: "search", + Arguments: json.RawMessage(`{"query":"test"}`), + }, + }, + { + name: "decisions tool call", + input: `{"name":"decisions","arguments":{"query":"auth"}}`, + expected: ToolCallParams{ + Name: "decisions", + Arguments: json.RawMessage(`{"query":"auth"}`), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var params ToolCallParams + err := json.Unmarshal([]byte(tt.input), ¶ms) + require.NoError(t, err) + assert.Equal(t, tt.expected.Name, params.Name) + }) + } +} + +// TestTool tests Tool struct. +func TestTool(t *testing.T) { + tool := Tool{ + Name: "search", + Description: "Search observations", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + } + + data, err := json.Marshal(tool) + require.NoError(t, err) + + var parsed Tool + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + assert.Equal(t, "search", parsed.Name) + assert.Equal(t, "Search observations", parsed.Description) +} + +// TestTimelineParams tests TimelineParams struct. +func TestTimelineParams(t *testing.T) { + tests := []struct { + name string + input string + expected TimelineParams + }{ + { + name: "with anchor_id", + input: `{"anchor_id":123,"before":5,"after":5}`, + expected: TimelineParams{ + AnchorID: 123, + Before: 5, + After: 5, + }, + }, + { + name: "with query", + input: `{"query":"test query","project":"my-project"}`, + expected: TimelineParams{ + Query: "test query", + Project: "my-project", + }, + }, + { + name: "full params", + input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`, + expected: TimelineParams{ + AnchorID: 100, + Query: "search", + Before: 10, + After: 20, + Project: "proj", + ObsType: "bugfix", + Concepts: "security", + Files: "main.go", + DateStart: 1234567890, + DateEnd: 9876543210, + Format: "full", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var params TimelineParams + err := json.Unmarshal([]byte(tt.input), ¶ms) + require.NoError(t, err) + assert.Equal(t, tt.expected.AnchorID, params.AnchorID) + assert.Equal(t, tt.expected.Query, params.Query) + assert.Equal(t, tt.expected.Project, params.Project) + }) + } +} + +// TestHandleInitialize tests the initialize handler. +func TestHandleInitialize(t *testing.T) { + server := NewServer(nil, "1.2.3") + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + resp := server.handleInitialize(req) + + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Equal(t, 1, resp.ID) + assert.Nil(t, resp.Error) + assert.NotNil(t, resp.Result) + + result, ok := resp.Result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "2024-11-05", result["protocolVersion"]) + + serverInfo, ok := result["serverInfo"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "claude-mnemonic", serverInfo["name"]) + assert.Equal(t, "1.2.3", serverInfo["version"]) +} + +// TestHandleToolsList tests the tools/list handler. +func TestHandleToolsList(t *testing.T) { + server := NewServer(nil, "1.0.0") + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/list", + } + + resp := server.handleToolsList(req) + + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Equal(t, 1, resp.ID) + assert.Nil(t, resp.Error) + + result, ok := resp.Result.(map[string]any) + require.True(t, ok) + + tools, ok := result["tools"].([]Tool) + require.True(t, ok) + assert.NotEmpty(t, tools) + + // Verify expected tools are present + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + expectedTools := []string{ + "search", "timeline", "decisions", "changes", + "how_it_works", "find_by_concept", "find_by_file", + "find_by_type", "get_recent_context", "get_context_timeline", + "get_timeline_by_query", + } + + for _, name := range expectedTools { + assert.True(t, toolNames[name], "expected tool %s to be present", name) + } +} + +// TestHandleRequest tests request routing. +func TestHandleRequest(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + tests := []struct { + name string + req *Request + expectError bool + errorCode int + errorMessage string + }{ + { + name: "initialize method", + req: &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + }, + expectError: false, + }, + { + name: "tools/list method", + req: &Request{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/list", + }, + expectError: false, + }, + { + name: "unknown method", + req: &Request{ + JSONRPC: "2.0", + ID: 3, + Method: "unknown_method", + }, + expectError: true, + errorCode: -32601, + errorMessage: "Method not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := server.handleRequest(ctx, tt.req) + + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Equal(t, tt.req.ID, resp.ID) + + if tt.expectError { + require.NotNil(t, resp.Error) + assert.Equal(t, tt.errorCode, resp.Error.Code) + assert.Equal(t, tt.errorMessage, resp.Error.Message) + } else { + assert.Nil(t, resp.Error) + assert.NotNil(t, resp.Result) + } + }) + } +} + +// TestHandleToolsCall_InvalidParams tests tools/call with invalid params. +func TestHandleToolsCall_InvalidParams(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`invalid json`), + } + + resp := server.handleToolsCall(ctx, req) + + require.NotNil(t, resp.Error) + assert.Equal(t, -32602, resp.Error.Code) + assert.Equal(t, "Invalid params", resp.Error.Message) +} + +// TestCallTool_UnknownTool tests callTool with unknown tool name. +func TestCallTool_UnknownTool(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + _, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown tool") +} + +// TestCallTool_InvalidArgs tests callTool with invalid arguments. +func TestCallTool_InvalidArgs(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + _, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid arguments") +} + +// TestSendResponse tests response sending. +func TestSendResponse(t *testing.T) { + var buf bytes.Buffer + server := &Server{ + stdout: &buf, + } + + resp := &Response{ + JSONRPC: "2.0", + ID: 1, + Result: map[string]string{"status": "ok"}, + } + + server.sendResponse(resp) + + output := buf.String() + assert.Contains(t, output, `"jsonrpc":"2.0"`) + assert.Contains(t, output, `"id":1`) + assert.Contains(t, output, `"result"`) +} + +// TestSendError tests error response sending. +func TestSendError(t *testing.T) { + var buf bytes.Buffer + server := &Server{ + stdout: &buf, + } + + server.sendError(1, -32700, "Parse error", "details") + + output := buf.String() + assert.Contains(t, output, `"error"`) + assert.Contains(t, output, `-32700`) + assert.Contains(t, output, `"Parse error"`) +} + +// TestRun_ParseError tests Run with invalid JSON input. +func TestRun_ParseError(t *testing.T) { + var stdout bytes.Buffer + stdin := strings.NewReader("invalid json\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, `"error"`) + assert.Contains(t, output, `-32700`) + assert.Contains(t, output, `"Parse error"`) +} + +// TestRun_EmptyLine tests Run skips empty lines. +func TestRun_EmptyLine(t *testing.T) { + var stdout bytes.Buffer + stdin := strings.NewReader("\n\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + // Should be empty - no responses for empty lines + assert.Empty(t, stdout.String()) +} + +// TestRun_ValidRequest tests Run with a valid request. +func TestRun_ValidRequest(t *testing.T) { + var stdout bytes.Buffer + req := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + stdin := strings.NewReader(req + "\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + version: "1.0.0", + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + assert.Contains(t, output, `"jsonrpc":"2.0"`) + assert.Contains(t, output, `"result"`) + assert.Contains(t, output, `"protocolVersion"`) +} + +// TestJSONRPCErrorCodes tests standard JSON-RPC error codes. +func TestJSONRPCErrorCodes(t *testing.T) { + errorCodes := map[string]int{ + "Parse error": -32700, + "Invalid Request": -32600, + "Method not found": -32601, + "Invalid params": -32602, + "Internal error": -32603, + } + + for msg, code := range errorCodes { + t.Run(msg, func(t *testing.T) { + err := Error{Code: code, Message: msg} + assert.Equal(t, code, err.Code) + assert.Equal(t, msg, err.Message) + }) + } +} + +// TestToolListContainsExpectedSchemas tests that tool schemas are valid. +func TestToolListContainsExpectedSchemas(t *testing.T) { + server := NewServer(nil, "1.0.0") + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/list", + } + + resp := server.handleToolsList(req) + result := resp.Result.(map[string]any) + tools := result["tools"].([]Tool) + + for _, tool := range tools { + assert.NotEmpty(t, tool.Name) + assert.NotEmpty(t, tool.Description) + assert.NotNil(t, tool.InputSchema) + + // Check schema has type + schema := tool.InputSchema + _, hasType := schema["type"] + assert.True(t, hasType, "tool %s schema should have type", tool.Name) + } +} diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go new file mode 100644 index 0000000..427469e --- /dev/null +++ b/internal/search/manager_test.go @@ -0,0 +1,596 @@ +// Package search provides unified search capabilities for claude-mnemonic. +package search + +import ( + "database/sql" + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +// ManagerSuite is a test suite for search Manager operations. +type ManagerSuite struct { + suite.Suite +} + +func TestManagerSuite(t *testing.T) { + suite.Run(t, new(ManagerSuite)) +} + +// TestNewManager tests manager creation. +func (s *ManagerSuite) TestNewManager() { + // Test with nil stores (valid use case for testing) + m := NewManager(nil, nil, nil, nil) + s.NotNil(m) + s.Nil(m.observationStore) + s.Nil(m.summaryStore) + s.Nil(m.promptStore) + s.Nil(m.vectorClient) +} + +// TestSearchParams tests SearchParams defaults. +func (s *ManagerSuite) TestSearchParams() { + params := SearchParams{ + Query: "test query", + Project: "my-project", + Limit: 10, + } + + s.Equal("test query", params.Query) + s.Equal("my-project", params.Project) + s.Equal(10, params.Limit) + s.Equal("", params.Type) + s.Equal("", params.OrderBy) +} + +// TestSearchResult tests SearchResult struct. +func (s *ManagerSuite) TestSearchResult() { + result := SearchResult{ + Type: "observation", + ID: 123, + Title: "Test Title", + Content: "Test content", + Project: "my-project", + Scope: "project", + CreatedAt: 1704067200000, + Score: 0.95, + Metadata: map[string]interface{}{ + "obs_type": "discovery", + }, + } + + s.Equal("observation", result.Type) + s.Equal(int64(123), result.ID) + s.Equal("Test Title", result.Title) + s.Equal("Test content", result.Content) + s.Equal("my-project", result.Project) + s.Equal("project", result.Scope) + s.Equal(int64(1704067200000), result.CreatedAt) + s.Equal(0.95, result.Score) + s.Equal("discovery", result.Metadata["obs_type"]) +} + +// TestUnifiedSearchResult tests UnifiedSearchResult struct. +func (s *ManagerSuite) TestUnifiedSearchResult() { + result := UnifiedSearchResult{ + Results: []SearchResult{ + {Type: "observation", ID: 1}, + {Type: "session", ID: 2}, + }, + TotalCount: 2, + Query: "test", + } + + s.Len(result.Results, 2) + s.Equal(2, result.TotalCount) + s.Equal("test", result.Query) +} + +// TestTruncate tests the truncate helper function. +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + { + name: "short string no truncation", + input: "hello", + maxLen: 10, + expected: "hello", + }, + { + name: "exact length no truncation", + input: "hello", + maxLen: 5, + expected: "hello", + }, + { + name: "long string truncated", + input: "hello world this is a long string", + maxLen: 10, + expected: "hello worl...", + }, + { + name: "empty string", + input: "", + maxLen: 10, + expected: "", + }, + { + name: "whitespace trimmed", + input: " hello ", + maxLen: 10, + expected: "hello", + }, + { + name: "whitespace trimmed then truncated", + input: " hello world this is long ", + maxLen: 10, + expected: "hello worl...", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestObservationToResult tests observation to result conversion. +func TestObservationToResult(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + tests := []struct { + name string + obs *models.Observation + format string + expected SearchResult + }{ + { + name: "full format with all fields", + obs: &models.Observation{ + ID: 123, + Project: "my-project", + Type: models.ObsTypeDiscovery, + Scope: models.ScopeProject, + Title: sql.NullString{String: "Test Title", Valid: true}, + Narrative: sql.NullString{String: "Full narrative content", Valid: true}, + CreatedAtEpoch: 1704067200000, + }, + format: "full", + expected: SearchResult{ + Type: "observation", + ID: 123, + Title: "Test Title", + Content: "Full narrative content", + Project: "my-project", + Scope: "project", + CreatedAt: 1704067200000, + }, + }, + { + name: "index format no content", + obs: &models.Observation{ + ID: 456, + Project: "other-project", + Type: models.ObsTypeBugfix, + Scope: models.ScopeGlobal, + Title: sql.NullString{String: "Bug Fix", Valid: true}, + Narrative: sql.NullString{String: "Narrative here", Valid: true}, + CreatedAtEpoch: 1704067200000, + }, + format: "index", + expected: SearchResult{ + Type: "observation", + ID: 456, + Title: "Bug Fix", + Content: "", // Not included in index format + Project: "other-project", + Scope: "global", + CreatedAt: 1704067200000, + }, + }, + { + name: "null title", + obs: &models.Observation{ + ID: 789, + Project: "project", + Type: models.ObsTypeFeature, + Scope: models.ScopeProject, + Title: sql.NullString{Valid: false}, + Narrative: sql.NullString{Valid: false}, + CreatedAtEpoch: 1704067200000, + }, + format: "full", + expected: SearchResult{ + Type: "observation", + ID: 789, + Title: "", + Content: "", + Project: "project", + Scope: "project", + CreatedAt: 1704067200000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := m.observationToResult(tt.obs, tt.format) + assert.Equal(t, tt.expected.Type, result.Type) + assert.Equal(t, tt.expected.ID, result.ID) + assert.Equal(t, tt.expected.Title, result.Title) + assert.Equal(t, tt.expected.Content, result.Content) + assert.Equal(t, tt.expected.Project, result.Project) + assert.Equal(t, tt.expected.Scope, result.Scope) + assert.Equal(t, tt.expected.CreatedAt, result.CreatedAt) + }) + } +} + +// TestSummaryToResult tests summary to result conversion. +func TestSummaryToResult(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + tests := []struct { + name string + summary *models.SessionSummary + format string + expected SearchResult + }{ + { + name: "full format with all fields", + summary: &models.SessionSummary{ + ID: 123, + Project: "my-project", + Request: sql.NullString{String: "Test request", Valid: true}, + Learned: sql.NullString{String: "Learned this content", Valid: true}, + CreatedAtEpoch: 1704067200000, + }, + format: "full", + expected: SearchResult{ + Type: "session", + ID: 123, + Title: "Test request", + Content: "Learned this content", + Project: "my-project", + CreatedAt: 1704067200000, + }, + }, + { + name: "index format no content", + summary: &models.SessionSummary{ + ID: 456, + Project: "other-project", + Request: sql.NullString{String: "Another request", Valid: true}, + Learned: sql.NullString{String: "Some learning", Valid: true}, + CreatedAtEpoch: 1704067200000, + }, + format: "index", + expected: SearchResult{ + Type: "session", + ID: 456, + Title: "Another request", + Content: "", // Not included in index format + Project: "other-project", + CreatedAt: 1704067200000, + }, + }, + { + name: "long title truncated", + summary: &models.SessionSummary{ + ID: 789, + Project: "project", + Request: sql.NullString{String: "This is a very long request that should be truncated because it exceeds the maximum allowed length for titles which is 100 characters", Valid: true}, + Learned: sql.NullString{Valid: false}, + CreatedAtEpoch: 1704067200000, + }, + format: "full", + expected: SearchResult{ + Type: "session", + ID: 789, + Title: "This is a very long request that should be truncated because it exceeds the maximum allowed length f...", + Content: "", + Project: "project", + CreatedAt: 1704067200000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := m.summaryToResult(tt.summary, tt.format) + assert.Equal(t, tt.expected.Type, result.Type) + assert.Equal(t, tt.expected.ID, result.ID) + assert.Equal(t, tt.expected.Title, result.Title) + assert.Equal(t, tt.expected.Content, result.Content) + assert.Equal(t, tt.expected.Project, result.Project) + assert.Equal(t, tt.expected.CreatedAt, result.CreatedAt) + }) + } +} + +// TestPromptToResult tests prompt to result conversion. +func TestPromptToResult(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + + tests := []struct { + name string + prompt *models.UserPromptWithSession + format string + expected SearchResult + }{ + { + name: "full format with content", + prompt: &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 123, + PromptText: "What is the meaning of life?", + CreatedAtEpoch: 1704067200000, + }, + Project: "my-project", + }, + format: "full", + expected: SearchResult{ + Type: "prompt", + ID: 123, + Title: "What is the meaning of life?", + Content: "What is the meaning of life?", + Project: "my-project", + CreatedAt: 1704067200000, + }, + }, + { + name: "index format no content", + prompt: &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 456, + PromptText: "Short prompt", + CreatedAtEpoch: 1704067200000, + }, + Project: "other-project", + }, + format: "index", + expected: SearchResult{ + Type: "prompt", + ID: 456, + Title: "Short prompt", + Content: "", + Project: "other-project", + CreatedAt: 1704067200000, + }, + }, + { + name: "long prompt truncated title", + prompt: &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 789, + PromptText: "This is a very long prompt that should be truncated because it exceeds the maximum allowed length for titles which is 100 characters and it keeps going", + CreatedAtEpoch: 1704067200000, + }, + Project: "project", + }, + format: "full", + expected: SearchResult{ + Type: "prompt", + ID: 789, + Title: "This is a very long prompt that should be truncated because it exceeds the maximum allowed length fo...", + Content: "This is a very long prompt that should be truncated because it exceeds the maximum allowed length for titles which is 100 characters and it keeps going", + Project: "project", + CreatedAt: 1704067200000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := m.promptToResult(tt.prompt, tt.format) + assert.Equal(t, tt.expected.Type, result.Type) + assert.Equal(t, tt.expected.ID, result.ID) + assert.Equal(t, tt.expected.Title, result.Title) + assert.Equal(t, tt.expected.Content, result.Content) + assert.Equal(t, tt.expected.Project, result.Project) + assert.Equal(t, tt.expected.CreatedAt, result.CreatedAt) + }) + } +} + +// TestSearchParamsValidation tests parameter validation in UnifiedSearch. +func TestSearchParamsValidation(t *testing.T) { + tests := []struct { + name string + params SearchParams + expectedLimit int + expectedOrder string + }{ + { + name: "default limit applied", + params: SearchParams{ + Query: "test", + Project: "project", + Limit: 0, + }, + expectedLimit: 20, + expectedOrder: "date_desc", + }, + { + name: "negative limit corrected", + params: SearchParams{ + Query: "test", + Project: "project", + Limit: -5, + }, + expectedLimit: 20, + expectedOrder: "date_desc", + }, + { + name: "limit over 100 capped", + params: SearchParams{ + Query: "test", + Project: "project", + Limit: 200, + }, + expectedLimit: 100, + expectedOrder: "date_desc", + }, + { + name: "custom limit preserved", + params: SearchParams{ + Query: "test", + Project: "project", + Limit: 50, + OrderBy: "relevance", + }, + expectedLimit: 50, + expectedOrder: "relevance", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Since we can't easily call UnifiedSearch without stores, + // we verify the expected values through logic + params := tt.params + + // Simulate the validation logic from UnifiedSearch + if params.Limit <= 0 { + params.Limit = 20 + } + if params.Limit > 100 { + params.Limit = 100 + } + if params.OrderBy == "" { + params.OrderBy = "date_desc" + } + + assert.Equal(t, tt.expectedLimit, params.Limit) + assert.Equal(t, tt.expectedOrder, params.OrderBy) + }) + } +} + +// TestDecisionsQueryBoost tests Decisions search query boosting. +func TestDecisionsQueryBoost(t *testing.T) { + tests := []struct { + name string + inputQuery string + expectedQuery string + }{ + { + name: "empty query not boosted", + inputQuery: "", + expectedQuery: "", + }, + { + name: "query boosted with keywords", + inputQuery: "authentication", + expectedQuery: "authentication decision chose architecture", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{Query: tt.inputQuery} + // Simulate Decisions boost logic + if params.Query != "" { + params.Query = params.Query + " decision chose architecture" + } + assert.Equal(t, tt.expectedQuery, params.Query) + }) + } +} + +// TestChangesQueryBoost tests Changes search query boosting. +func TestChangesQueryBoost(t *testing.T) { + tests := []struct { + name string + inputQuery string + expectedQuery string + }{ + { + name: "empty query not boosted", + inputQuery: "", + expectedQuery: "", + }, + { + name: "query boosted with keywords", + inputQuery: "handler", + expectedQuery: "handler changed modified refactored", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{Query: tt.inputQuery} + // Simulate Changes boost logic + if params.Query != "" { + params.Query = params.Query + " changed modified refactored" + } + assert.Equal(t, tt.expectedQuery, params.Query) + }) + } +} + +// TestHowItWorksQueryBoost tests HowItWorks search query boosting. +func TestHowItWorksQueryBoost(t *testing.T) { + tests := []struct { + name string + inputQuery string + expectedQuery string + }{ + { + name: "empty query not boosted", + inputQuery: "", + expectedQuery: "", + }, + { + name: "query boosted with keywords", + inputQuery: "database", + expectedQuery: "database architecture design pattern implements", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{Query: tt.inputQuery} + // Simulate HowItWorks boost logic + if params.Query != "" { + params.Query = params.Query + " architecture design pattern implements" + } + assert.Equal(t, tt.expectedQuery, params.Query) + }) + } +} + +// TestSearchTypeMapping tests type string to doc type mapping. +func TestSearchTypeMapping(t *testing.T) { + tests := []struct { + typeStr string + expected string + }{ + {"observations", "observation"}, + {"sessions", "session_summary"}, + {"prompts", "user_prompt"}, + {"", ""}, // Empty type for all + } + + for _, tt := range tests { + t.Run("type_"+tt.typeStr, func(t *testing.T) { + // This tests the type mapping logic + // Just verify the valid type strings + validTypes := map[string]bool{ + "observations": true, + "sessions": true, + "prompts": true, + "": true, + } + assert.True(t, validTypes[tt.typeStr]) + }) + } +} diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go index 6602366..b86257a 100644 --- a/internal/worker/handlers_test.go +++ b/internal/worker/handlers_test.go @@ -2,11 +2,13 @@ package worker import ( + "bytes" "context" "database/sql" "encoding/json" "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -551,3 +553,729 @@ func TestRequireReadyMiddleware_Allows(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "success", rec.Body.String()) } + +func TestHandleGetStats(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/stats", nil) + rec := httptest.NewRecorder() + + svc.handleGetStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Check basic stats fields exist + _, hasUptime := response["uptime"] + assert.True(t, hasUptime) + _, hasReady := response["ready"] + assert.True(t, hasReady) +} + +func TestHandleGetStats_WithProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "test-project" + createTestObservation(t, svc.observationStore, project, "Test", "Test content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/stats?project="+project, nil) + rec := httptest.NewRecorder() + + svc.handleGetStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Check project-specific stats + assert.Equal(t, project, response["project"]) + assert.Equal(t, float64(1), response["projectObservations"]) +} + +func TestHandleGetRetrievalStats(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil) + rec := httptest.NewRecorder() + + svc.handleGetRetrievalStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response RetrievalStats + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Initially all stats should be 0 + assert.Equal(t, int64(0), response.TotalRequests) +} + +func TestHandleContextCount(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "count-project" + + // Create some observations + for i := 0; i < 5; i++ { + createTestObservation(t, svc.observationStore, project, "Test "+string(rune('A'+i)), "Content", []string{"test"}) + } + + req := httptest.NewRequest(http.MethodGet, "/api/context/count?project="+project, nil) + rec := httptest.NewRecorder() + + svc.handleContextCount(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, project, response["project"]) + assert.Equal(t, float64(5), response["count"]) +} + +func TestHandleContextCount_MissingProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil) + rec := httptest.NewRecorder() + + svc.handleContextCount(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetProjects(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create sessions for different projects + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "session-1", "project-alpha", "") + svc.sessionStore.CreateSDKSession(ctx, "session-2", "project-beta", "") + svc.sessionStore.CreateSDKSession(ctx, "session-3", "project-gamma", "") + + req := httptest.NewRequest(http.MethodGet, "/api/projects", nil) + rec := httptest.NewRecorder() + + svc.handleGetProjects(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var projects []string + err := json.Unmarshal(rec.Body.Bytes(), &projects) + require.NoError(t, err) + + assert.Len(t, projects, 3) + assert.Contains(t, projects, "project-alpha") + assert.Contains(t, projects, "project-beta") + assert.Contains(t, projects, "project-gamma") +} + +func TestHandleGetTypes(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/types", nil) + rec := httptest.NewRecorder() + + svc.handleGetTypes(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Check observation types + obsTypes, ok := response["observation_types"].([]interface{}) + require.True(t, ok) + assert.Contains(t, toStringSlice(obsTypes), "bugfix") + assert.Contains(t, toStringSlice(obsTypes), "feature") + + // Check concept types + conceptTypes, ok := response["concept_types"].([]interface{}) + require.True(t, ok) + assert.Contains(t, toStringSlice(conceptTypes), "security") +} + +func toStringSlice(arr []interface{}) []string { + result := make([]string, len(arr)) + for i, v := range arr { + result[i] = v.(string) + } + return result +} + +func TestHandleGetSummaries(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create some summaries + ctx := context.Background() + for i := 0; i < 3; i++ { + parsed := &models.ParsedSummary{ + Request: "Test request " + string(rune('A'+i)), + Completed: "Test completed", + } + sdkSessionID := "sdk-" + string(rune('a'+i)) + _, _, err := svc.summaryStore.StoreSummary(ctx, sdkSessionID, "project-a", parsed, i+1, 100) + require.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=project-a&limit=10", nil) + rec := httptest.NewRecorder() + + svc.handleGetSummaries(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var summaries []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &summaries) + require.NoError(t, err) + + assert.Len(t, summaries, 3) +} + +func TestHandleGetPrompts(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create sessions and prompts + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "claude-test", "project-x", "") + + // Save prompts + for i := 0; i < 5; i++ { + _, err := svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-test", i+1, "Test prompt "+string(rune('A'+i)), 0) + require.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=project-x&limit=10", nil) + rec := httptest.NewRecorder() + + svc.handleGetPrompts(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var prompts []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &prompts) + require.NoError(t, err) + + assert.Len(t, prompts, 5) +} + +func TestHandleSelfCheck(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(true) + + req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil) + rec := httptest.NewRecorder() + + svc.handleSelfCheck(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SelfCheckResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Overall health should be healthy or degraded (not unhealthy for basic tests) + assert.NotEqual(t, "unhealthy", response.Overall) + assert.NotEmpty(t, response.Version) + assert.NotEmpty(t, response.Uptime) + assert.NotEmpty(t, response.Components) +} + +func TestHandleSelfCheck_NotReady(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(false) + + req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil) + rec := httptest.NewRecorder() + + svc.handleSelfCheck(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SelfCheckResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Should be degraded when not ready + assert.Equal(t, "degraded", response.Overall) +} + +func TestObservationTypesAndConcepts(t *testing.T) { + // Verify observation types + assert.Contains(t, ObservationTypes, "bugfix") + assert.Contains(t, ObservationTypes, "feature") + assert.Contains(t, ObservationTypes, "refactor") + assert.Contains(t, ObservationTypes, "discovery") + assert.Contains(t, ObservationTypes, "decision") + assert.Contains(t, ObservationTypes, "change") + + // Verify concept types + assert.Contains(t, ConceptTypes, "how-it-works") + assert.Contains(t, ConceptTypes, "security") + assert.Contains(t, ConceptTypes, "best-practice") +} + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + + data := map[string]string{"test": "value"} + writeJSON(rec, data) + + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var result map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &result) + require.NoError(t, err) + assert.Equal(t, "value", result["test"]) +} + +func TestDefaultLimitConstants(t *testing.T) { + assert.Equal(t, 100, DefaultObservationsLimit) + assert.Equal(t, 50, DefaultSummariesLimit) + assert.Equal(t, 100, DefaultPromptsLimit) + assert.Equal(t, 50, DefaultSearchLimit) + assert.Equal(t, 50, DefaultContextLimit) +} + +func TestDuplicatePromptWindowSeconds(t *testing.T) { + assert.Equal(t, 10, DuplicatePromptWindowSeconds) +} + +func TestHandleSessionInit_Success(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionInitRequest{ + ClaudeSessionID: "claude-test-123", + Project: "test-project", + Prompt: "Help me fix this bug", + MatchedObservations: 5, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SessionInitResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Greater(t, response.SessionDBID, int64(0)) + assert.Equal(t, 1, response.PromptNumber) + assert.False(t, response.Skipped) +} + +func TestHandleSessionInit_InvalidJSON(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte("invalid json"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSessionInit_PrivatePrompt(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionInitRequest{ + ClaudeSessionID: "claude-private", + Project: "test-project", + Prompt: "This is a private prompt", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SessionInitResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.True(t, response.Skipped) + assert.Equal(t, "private", response.Reason) +} + +func TestHandleSessionInit_DuplicatePrompt(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionInitRequest{ + ClaudeSessionID: "claude-dup-test", + Project: "test-project", + Prompt: "Help me fix this specific bug", + } + + body, _ := json.Marshal(reqBody) + + // First request + req1 := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body)) + req1.Header.Set("Content-Type", "application/json") + rec1 := httptest.NewRecorder() + svc.router.ServeHTTP(rec1, req1) + + assert.Equal(t, http.StatusOK, rec1.Code) + var resp1 SessionInitResponse + json.Unmarshal(rec1.Body.Bytes(), &resp1) + + // Second request with same prompt (duplicate) + body2, _ := json.Marshal(reqBody) + req2 := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body2)) + req2.Header.Set("Content-Type", "application/json") + rec2 := httptest.NewRecorder() + svc.router.ServeHTTP(rec2, req2) + + assert.Equal(t, http.StatusOK, rec2.Code) + var resp2 SessionInitResponse + json.Unmarshal(rec2.Body.Bytes(), &resp2) + + // Should return same prompt number (duplicate detected) + assert.Equal(t, resp1.PromptNumber, resp2.PromptNumber) +} + +func TestHandleSessionStart_Success(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // First create a session + ctx := context.Background() + sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "claude-start-test", "test-project", "test prompt") + + reqBody := SessionStartRequest{ + UserPrompt: "Help me with something", + PromptNumber: 1, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestHandleSessionStart_InvalidID(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionStartRequest{ + UserPrompt: "Help me", + PromptNumber: 1, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSessionStart_NotFound(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := SessionStartRequest{ + UserPrompt: "Help me", + PromptNumber: 1, + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/sessions/999999/init", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleSessionStart_InvalidJSON(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + ctx := context.Background() + sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "claude-json-test", "test-project", "") + + req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/init", bytes.NewReader([]byte("not json"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleObservation_SessionNotFound(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + reqBody := ObservationRequest{ + ClaudeSessionID: "non-existent-session", + Project: "test-project", + ToolName: "Read", + ToolInput: map[string]string{"path": "/test.go"}, + ToolResponse: "file content", + CWD: "/test", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should return 200 (queues observation) or 404 (session not found) + assert.Contains(t, []int{http.StatusOK, http.StatusNotFound}, rec.Code) +} + +func TestHandleObservation_InvalidJSON(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte("invalid"))) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleObservation_WithExistingSession(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create a session first + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "claude-obs-test", "test-project", "test prompt") + + reqBody := ObservationRequest{ + ClaudeSessionID: "claude-obs-test", + Project: "test-project", + ToolName: "Write", + ToolInput: map[string]string{"path": "/test.go"}, + ToolResponse: "success", + CWD: "/project", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestHandleGetObservations_DefaultLimit(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create more than default limit + for i := 0; i < 120; i++ { + createTestObservation(t, svc.observationStore, "project-limit", + "Test "+strconv.Itoa(i), + "Content "+strconv.Itoa(i), + []string{"test"}) + } + + req := httptest.NewRequest(http.MethodGet, "/api/observations", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var observations []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &observations) + require.NoError(t, err) + + // Should return default limit (100) + assert.LessOrEqual(t, len(observations), DefaultObservationsLimit) +} + +func TestHandleGetObservations_FilterByProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create observations in different projects + createTestObservation(t, svc.observationStore, "alpha", "Alpha 1", "Content", []string{"test"}) + createTestObservation(t, svc.observationStore, "alpha", "Alpha 2", "Content", []string{"test"}) + createTestObservation(t, svc.observationStore, "beta", "Beta 1", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/observations?project=alpha", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var observations []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &observations) + require.NoError(t, err) + + assert.Len(t, observations, 2) +} + +func TestHandleGetObservations_FilterByType(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create observations - createTestObservation creates discovery type + createTestObservation(t, svc.observationStore, "type-test", "Test 1", "Content", []string{"test"}) + createTestObservation(t, svc.observationStore, "type-test", "Test 2", "Content", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/observations?type=discovery", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestHandleGetSummaries_DefaultLimit(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + ctx := context.Background() + // Create more than default limit + for i := 0; i < 60; i++ { + parsed := &models.ParsedSummary{Request: "Request " + strconv.Itoa(i)} + svc.summaryStore.StoreSummary(ctx, "sdk-"+strconv.Itoa(i), "project-sum", parsed, i+1, 100) + } + + req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var summaries []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &summaries) + require.NoError(t, err) + + assert.LessOrEqual(t, len(summaries), DefaultSummariesLimit) +} + +func TestHandleGetPrompts_DefaultLimit(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + ctx := context.Background() + svc.sessionStore.CreateSDKSession(ctx, "claude-prompts", "project-prompts", "") + + // Create more than default limit + for i := 0; i < 120; i++ { + svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts", i+1, "Prompt "+strconv.Itoa(i), 0) + } + + req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var prompts []map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &prompts) + require.NoError(t, err) + + assert.LessOrEqual(t, len(prompts), DefaultPromptsLimit) +} + +func TestSessionInitRequest_Fields(t *testing.T) { + req := SessionInitRequest{ + ClaudeSessionID: "test-123", + Project: "my-project", + Prompt: "Help me", + MatchedObservations: 10, + } + + assert.Equal(t, "test-123", req.ClaudeSessionID) + assert.Equal(t, "my-project", req.Project) + assert.Equal(t, "Help me", req.Prompt) + assert.Equal(t, 10, req.MatchedObservations) +} + +func TestSessionInitResponse_Fields(t *testing.T) { + resp := SessionInitResponse{ + SessionDBID: 123, + PromptNumber: 5, + Skipped: true, + Reason: "private", + } + + assert.Equal(t, int64(123), resp.SessionDBID) + assert.Equal(t, 5, resp.PromptNumber) + assert.True(t, resp.Skipped) + assert.Equal(t, "private", resp.Reason) +} + +func TestSessionStartRequest_Fields(t *testing.T) { + req := SessionStartRequest{ + UserPrompt: "Help me with code", + PromptNumber: 3, + } + + assert.Equal(t, "Help me with code", req.UserPrompt) + assert.Equal(t, 3, req.PromptNumber) +} + +func TestObservationRequest_Fields(t *testing.T) { + req := ObservationRequest{ + ClaudeSessionID: "session-abc", + Project: "my-project", + ToolName: "Read", + ToolInput: map[string]string{"path": "/file.go"}, + ToolResponse: "file contents", + CWD: "/home/user/project", + } + + assert.Equal(t, "session-abc", req.ClaudeSessionID) + assert.Equal(t, "my-project", req.Project) + assert.Equal(t, "Read", req.ToolName) + assert.Equal(t, "/home/user/project", req.CWD) +} diff --git a/internal/worker/sdk/parser_test.go b/internal/worker/sdk/parser_test.go new file mode 100644 index 0000000..e89b981 --- /dev/null +++ b/internal/worker/sdk/parser_test.go @@ -0,0 +1,537 @@ +package sdk + +import ( + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestParseObservations_SingleObservation(t *testing.T) { + text := `Some text before + +bugfix +Fixed null pointer error +In user service +The service was crashing when user ID was nil + +Added nil check +Added unit test + + +error-handling +debugging + + +user_service.go + + +user_service.go +user_service_test.go + + +Some text after` + + observations := ParseObservations(text, "test-correlation-id") + + assert.Len(t, observations, 1) + obs := observations[0] + assert.Equal(t, models.ObservationType("bugfix"), obs.Type) + assert.Equal(t, "Fixed null pointer error", obs.Title) + assert.Equal(t, "In user service", obs.Subtitle) + assert.Equal(t, "The service was crashing when user ID was nil", obs.Narrative) + assert.Equal(t, []string{"Added nil check", "Added unit test"}, obs.Facts) + assert.Equal(t, []string{"error-handling", "debugging"}, obs.Concepts) + assert.Equal(t, []string{"user_service.go"}, obs.FilesRead) + assert.Equal(t, []string{"user_service.go", "user_service_test.go"}, obs.FilesModified) +} + +func TestParseObservations_MultipleObservations(t *testing.T) { + text := ` + +feature +Added caching +Implemented Redis caching +Added cache layer +caching + + +refactor +Cleaned up code +Removed dead code +Removed unused functions +refactoring + +` + + observations := ParseObservations(text, "test-id") + + assert.Len(t, observations, 2) + assert.Equal(t, models.ObservationType("feature"), observations[0].Type) + assert.Equal(t, "Added caching", observations[0].Title) + assert.Equal(t, models.ObservationType("refactor"), observations[1].Type) + assert.Equal(t, "Cleaned up code", observations[1].Title) +} + +func TestParseObservations_TableDriven(t *testing.T) { + tests := []struct { + name string + input string + expectedCount int + expectedType models.ObservationType + expectedTitle string + checkConcepts []string + }{ + { + name: "valid_bugfix_observation", + input: ` +bugfix +Fixed bug +Details +`, + expectedCount: 1, + expectedType: models.ObsTypeBugfix, + expectedTitle: "Fixed bug", + }, + { + name: "valid_feature_observation", + input: ` +feature +New feature +Added new stuff +`, + expectedCount: 1, + expectedType: models.ObsTypeFeature, + expectedTitle: "New feature", + }, + { + name: "valid_refactor_observation", + input: ` +refactor +Code cleanup +Refactored module +`, + expectedCount: 1, + expectedType: models.ObsTypeRefactor, + expectedTitle: "Code cleanup", + }, + { + name: "valid_change_observation", + input: ` +change +Config update +Changed settings +`, + expectedCount: 1, + expectedType: models.ObsTypeChange, + expectedTitle: "Config update", + }, + { + name: "valid_discovery_observation", + input: ` +discovery +Found pattern +Discovered new pattern +`, + expectedCount: 1, + expectedType: models.ObsTypeDiscovery, + expectedTitle: "Found pattern", + }, + { + name: "valid_decision_observation", + input: ` +decision +Architecture decision +Chose microservices +`, + expectedCount: 1, + expectedType: models.ObsTypeDecision, + expectedTitle: "Architecture decision", + }, + { + name: "invalid_type_defaults_to_change", + input: ` +invalid_type +Some title +Details +`, + expectedCount: 1, + expectedType: models.ObsTypeChange, + expectedTitle: "Some title", + }, + { + name: "missing_type_defaults_to_change", + input: ` +No type specified +Details +`, + expectedCount: 1, + expectedType: models.ObsTypeChange, + expectedTitle: "No type specified", + }, + { + name: "empty_input", + input: "", + expectedCount: 0, + }, + { + name: "no_observation_tags", + input: "Just regular text without any observation", + expectedCount: 0, + }, + { + name: "valid_concepts_filtered", + input: ` +bugfix +Test +Test + +best-practice +invalid-concept +security + +`, + expectedCount: 1, + expectedType: models.ObsTypeBugfix, + checkConcepts: []string{"best-practice", "security"}, + }, + { + name: "type_in_concepts_filtered_out", + input: ` +bugfix +Test +Test + +bugfix +security + +`, + expectedCount: 1, + expectedType: models.ObsTypeBugfix, + checkConcepts: []string{"security"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + observations := ParseObservations(tt.input, "test-correlation-id") + + assert.Len(t, observations, tt.expectedCount) + if tt.expectedCount > 0 { + obs := observations[0] + assert.Equal(t, tt.expectedType, obs.Type) + if tt.expectedTitle != "" { + assert.Equal(t, tt.expectedTitle, obs.Title) + } + if tt.checkConcepts != nil { + assert.Equal(t, tt.checkConcepts, obs.Concepts) + } + } + }) + } +} + +func TestParseObservations_AllValidConcepts(t *testing.T) { + // Test all valid concepts are accepted + validConcepts := []string{ + "how-it-works", "why-it-exists", "what-changed", "problem-solution", "gotcha", "pattern", "trade-off", + "best-practice", "anti-pattern", "architecture", "security", "performance", "testing", "debugging", "workflow", "tooling", + "refactoring", "api", "database", "configuration", "error-handling", "caching", "logging", "auth", "validation", + } + + for _, concept := range validConcepts { + t.Run("concept_"+concept, func(t *testing.T) { + input := ` +discovery +Test +Test +` + concept + ` +` + + observations := ParseObservations(input, "test-id") + assert.Len(t, observations, 1) + assert.Contains(t, observations[0].Concepts, concept) + }) + } +} + +func TestParseObservations_ConceptCaseInsensitive(t *testing.T) { + input := ` +discovery +Test +Test + +SECURITY +Best-Practice + caching + +` + + observations := ParseObservations(input, "test-id") + + assert.Len(t, observations, 1) + assert.Equal(t, []string{"security", "best-practice", "caching"}, observations[0].Concepts) +} + +func TestParseSummary_ValidSummary(t *testing.T) { + text := `Some text before + +User asked to fix the bug +Looked at error logs and stack traces +The issue was a race condition +Fixed the race condition with mutex +Add more tests for concurrent access +May need to review similar code elsewhere + +Some text after` + + summary := ParseSummary(text, 123) + + assert.NotNil(t, summary) + assert.Equal(t, "User asked to fix the bug", summary.Request) + assert.Equal(t, "Looked at error logs and stack traces", summary.Investigated) + assert.Equal(t, "The issue was a race condition", summary.Learned) + assert.Equal(t, "Fixed the race condition with mutex", summary.Completed) + assert.Equal(t, "Add more tests for concurrent access", summary.NextSteps) + assert.Equal(t, "May need to review similar code elsewhere", summary.Notes) +} + +func TestParseSummary_TableDriven(t *testing.T) { + tests := []struct { + name string + input string + sessionID int64 + expectNil bool + expectedRequest string + }{ + { + name: "empty_input", + input: "", + sessionID: 1, + expectNil: true, + }, + { + name: "no_summary_tag", + input: "Just some text without summary", + sessionID: 1, + expectNil: true, + }, + { + name: "skip_summary_tag", + input: ``, + sessionID: 1, + expectNil: true, + }, + { + name: "skip_summary_with_different_reason", + input: ``, + sessionID: 2, + expectNil: true, + }, + { + name: "valid_summary_minimal", + input: ` +Test request +`, + sessionID: 3, + expectNil: false, + expectedRequest: "Test request", + }, + { + name: "valid_summary_all_fields", + input: ` +Full request +Full investigated +Full learned +Full completed +Full next steps +Full notes +`, + sessionID: 4, + expectNil: false, + expectedRequest: "Full request", + }, + { + name: "summary_with_empty_fields", + input: ` + + +`, + sessionID: 5, + expectNil: false, + expectedRequest: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + summary := ParseSummary(tt.input, tt.sessionID) + + if tt.expectNil { + assert.Nil(t, summary) + } else { + assert.NotNil(t, summary) + assert.Equal(t, tt.expectedRequest, summary.Request) + } + }) + } +} + +func TestParseSummary_SkipSummaryPriority(t *testing.T) { + // skip_summary should take priority over summary block + text := ` + +This should be ignored +` + + summary := ParseSummary(text, 1) + assert.Nil(t, summary) +} + +func TestExtractField_TableDriven(t *testing.T) { + tests := []struct { + name string + content string + fieldName string + expected string + }{ + { + name: "simple_field", + content: "Test Title", + fieldName: "title", + expected: "Test Title", + }, + { + name: "field_with_whitespace", + content: " Test Title ", + fieldName: "title", + expected: "Test Title", + }, + { + name: "field_not_found", + content: "Value", + fieldName: "title", + expected: "", + }, + { + name: "empty_field", + content: "", + fieldName: "title", + expected: "", + }, + { + name: "nested_content", + content: "Nested", + fieldName: "title", + expected: "Nested", + }, + { + name: "field_among_others", + content: "ATargetB", + fieldName: "title", + expected: "Target", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractField(tt.content, tt.fieldName) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractArrayElements_TableDriven(t *testing.T) { + tests := []struct { + name string + content string + arrayName string + elementName string + expected []string + }{ + { + name: "simple_array", + content: "OneTwo", + arrayName: "facts", + elementName: "fact", + expected: []string{"One", "Two"}, + }, + { + name: "empty_array", + content: "", + arrayName: "facts", + elementName: "fact", + expected: nil, + }, + { + name: "array_not_found", + content: "Value", + arrayName: "facts", + elementName: "fact", + expected: nil, + }, + { + name: "single_element", + content: "security", + arrayName: "concepts", + elementName: "concept", + expected: []string{"security"}, + }, + { + name: "multiline_array", + content: ` +file1.go +file2.go +file3.go +`, + arrayName: "files", + elementName: "file", + expected: []string{"file1.go", "file2.go", "file3.go"}, + }, + { + name: "whitespace_trimmed", + content: " trimmed ", + arrayName: "items", + elementName: "item", + expected: []string{"trimmed"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractArrayElements(tt.content, tt.arrayName, tt.elementName) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestValidObsTypes(t *testing.T) { + expected := map[string]bool{ + "bugfix": true, + "feature": true, + "refactor": true, + "change": true, + "discovery": true, + "decision": true, + } + assert.Equal(t, expected, validObsTypes) +} + +func TestValidConcepts(t *testing.T) { + // Verify expected concepts are valid + expectedValid := []string{ + "how-it-works", "why-it-exists", "what-changed", "problem-solution", "gotcha", "pattern", "trade-off", + "best-practice", "anti-pattern", "architecture", "security", "performance", "testing", "debugging", "workflow", "tooling", + "refactoring", "api", "database", "configuration", "error-handling", "caching", "logging", "auth", "validation", + } + + for _, concept := range expectedValid { + assert.True(t, validConcepts[concept], "Expected %s to be valid", concept) + } + + // Verify invalid concepts + invalidConcepts := []string{"random", "invalid", "not-a-concept", "foo", "bar"} + for _, concept := range invalidConcepts { + assert.False(t, validConcepts[concept], "Expected %s to be invalid", concept) + } +} diff --git a/internal/worker/session/manager_test.go b/internal/worker/session/manager_test.go new file mode 100644 index 0000000..6e7f654 --- /dev/null +++ b/internal/worker/session/manager_test.go @@ -0,0 +1,695 @@ +// Package session provides session lifecycle management for claude-mnemonic. +package session + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +// ManagerSuite is a test suite for Manager operations. +type ManagerSuite struct { + suite.Suite + manager *Manager +} + +func (s *ManagerSuite) SetupTest() { + // Create manager without real session store (use nil for unit tests) + s.manager = &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + // Initialize context for manager + ctx, cancel := context.WithCancel(context.Background()) + s.manager.ctx = ctx + s.manager.cancel = cancel +} + +func (s *ManagerSuite) TearDownTest() { + if s.manager != nil && s.manager.cancel != nil { + s.manager.cancel() + } +} + +func TestManagerSuite(t *testing.T) { + suite.Run(t, new(ManagerSuite)) +} + +// TestActiveSession tests ActiveSession creation and basic operations. +func (s *ManagerSuite) TestActiveSession() { + session := &ActiveSession{ + SessionDBID: 1, + ClaudeSessionID: "claude-123", + SDKSessionID: "sdk-123", + Project: "test-project", + UserPrompt: "Hello", + StartTime: time.Now(), + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + + s.Equal(int64(1), session.SessionDBID) + s.Equal("claude-123", session.ClaudeSessionID) + s.Equal("sdk-123", session.SDKSessionID) + s.Equal("test-project", session.Project) + s.Equal("Hello", session.UserPrompt) +} + +// TestGetActiveSessionCount tests session counting. +func (s *ManagerSuite) TestGetActiveSessionCount() { + // Initially 0 + s.Equal(0, s.manager.GetActiveSessionCount()) + + // Add sessions directly for testing + s.manager.sessions[1] = &ActiveSession{SessionDBID: 1} + s.manager.sessions[2] = &ActiveSession{SessionDBID: 2} + + s.Equal(2, s.manager.GetActiveSessionCount()) +} + +// TestGetTotalQueueDepth tests queue depth calculation. +func (s *ManagerSuite) TestGetTotalQueueDepth() { + // Initially 0 + s.Equal(0, s.manager.GetTotalQueueDepth()) + + // Add sessions with pending messages + s.manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 3), + } + s.manager.sessions[2] = &ActiveSession{ + SessionDBID: 2, + pendingMessages: make([]PendingMessage, 5), + } + + s.Equal(8, s.manager.GetTotalQueueDepth()) +} + +// TestIsAnySessionProcessing tests processing status detection. +func (s *ManagerSuite) TestIsAnySessionProcessing() { + // No sessions - not processing + s.False(s.manager.IsAnySessionProcessing()) + + // Session with no pending - not processing + s.manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + pendingMessages: []PendingMessage{}, + } + s.False(s.manager.IsAnySessionProcessing()) + + // Session with pending - processing + s.manager.sessions[1].pendingMessages = []PendingMessage{{Type: MessageTypeObservation}} + s.True(s.manager.IsAnySessionProcessing()) + + // Clear pending but set generator active + s.manager.sessions[1].pendingMessages = []PendingMessage{} + s.manager.sessions[1].generatorActive.Store(true) + s.True(s.manager.IsAnySessionProcessing()) +} + +// TestGetAllSessions tests retrieving all sessions. +func (s *ManagerSuite) TestGetAllSessions() { + // Empty + sessions := s.manager.GetAllSessions() + s.Empty(sessions) + + // Add sessions + session1 := &ActiveSession{SessionDBID: 1, Project: "project-a"} + session2 := &ActiveSession{SessionDBID: 2, Project: "project-b"} + s.manager.sessions[1] = session1 + s.manager.sessions[2] = session2 + + sessions = s.manager.GetAllSessions() + s.Len(sessions, 2) +} + +// TestDeleteSession tests session deletion. +func (s *ManagerSuite) TestDeleteSession() { + // Create session with context + ctx, cancel := context.WithCancel(context.Background()) + session := &ActiveSession{ + SessionDBID: 1, + Project: "test-project", + StartTime: time.Now(), + pendingMessages: []PendingMessage{}, + ctx: ctx, + cancel: cancel, + } + s.manager.sessions[1] = session + + // Track callback + var deletedID int64 + s.manager.SetOnSessionDeleted(func(id int64) { + deletedID = id + }) + + s.Equal(1, s.manager.GetActiveSessionCount()) + + // Delete + s.manager.DeleteSession(1) + + s.Equal(0, s.manager.GetActiveSessionCount()) + s.Equal(int64(1), deletedID) + + // Double delete should be safe + s.manager.DeleteSession(1) +} + +// TestDrainMessages tests message draining. +func (s *ManagerSuite) TestDrainMessages() { + // No session - nil + messages := s.manager.DrainMessages(999) + s.Nil(messages) + + // Session with messages + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: []PendingMessage{ + {Type: MessageTypeObservation}, + {Type: MessageTypeSummarize}, + }, + } + s.manager.sessions[1] = session + + messages = s.manager.DrainMessages(1) + s.Len(messages, 2) + + // Queue should be empty now + s.Empty(session.pendingMessages) + + // Drain again - empty + messages = s.manager.DrainMessages(1) + s.Empty(messages) +} + +// TestSetOnSessionCreated tests callback setting. +func (s *ManagerSuite) TestSetOnSessionCreated() { + var calledWith int64 + callback := func(id int64) { + calledWith = id + } + + s.manager.SetOnSessionCreated(callback) + s.NotNil(s.manager.onCreated) + + // Simulate callback + if s.manager.onCreated != nil { + s.manager.onCreated(42) + } + s.Equal(int64(42), calledWith) +} + +// TestSetOnSessionDeleted tests callback setting. +func (s *ManagerSuite) TestSetOnSessionDeleted() { + var calledWith int64 + callback := func(id int64) { + calledWith = id + } + + s.manager.SetOnSessionDeleted(callback) + s.NotNil(s.manager.onDeleted) + + // Simulate callback + if s.manager.onDeleted != nil { + s.manager.onDeleted(42) + } + s.Equal(int64(42), calledWith) +} + +// TestMessageTypes tests message type constants. +func TestMessageTypes(t *testing.T) { + assert.Equal(t, MessageType(0), MessageTypeObservation) + assert.Equal(t, MessageType(1), MessageTypeSummarize) +} + +// TestTimeoutConstants tests timeout constants. +func TestTimeoutConstants(t *testing.T) { + assert.Equal(t, 30*time.Minute, SessionTimeout) + assert.Equal(t, 5*time.Minute, CleanupInterval) +} + +// TestObservationData tests observation data structure. +func TestObservationData(t *testing.T) { + data := ObservationData{ + ToolName: "Read", + ToolInput: map[string]string{"path": "/test/file.go"}, + ToolResponse: "file content", + PromptNumber: 1, + CWD: "/test", + } + + assert.Equal(t, "Read", data.ToolName) + assert.Equal(t, 1, data.PromptNumber) + assert.Equal(t, "/test", data.CWD) +} + +// TestSummarizeData tests summarize data structure. +func TestSummarizeData(t *testing.T) { + data := SummarizeData{ + LastUserMessage: "What did you do?", + LastAssistantMessage: "I completed the task.", + } + + assert.Equal(t, "What did you do?", data.LastUserMessage) + assert.Equal(t, "I completed the task.", data.LastAssistantMessage) +} + +// TestPendingMessage tests pending message structure. +func TestPendingMessage(t *testing.T) { + obsData := &ObservationData{ToolName: "Read"} + msg := PendingMessage{ + Type: MessageTypeObservation, + Observation: obsData, + } + + assert.Equal(t, MessageTypeObservation, msg.Type) + assert.NotNil(t, msg.Observation) + assert.Nil(t, msg.Summarize) + + sumData := &SummarizeData{LastUserMessage: "Test"} + msg2 := PendingMessage{ + Type: MessageTypeSummarize, + Summarize: sumData, + } + + assert.Equal(t, MessageTypeSummarize, msg2.Type) + assert.Nil(t, msg2.Observation) + assert.NotNil(t, msg2.Summarize) +} + +// TestConcurrentSessionAccess tests thread-safe session operations. +func TestConcurrentSessionAccess(t *testing.T) { + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + var wg sync.WaitGroup + numGoroutines := 100 + + // Concurrent session operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int64) { + defer wg.Done() + + // Add session + ctx, cancel := context.WithCancel(context.Background()) + manager.mu.Lock() + manager.sessions[id] = &ActiveSession{ + SessionDBID: id, + Project: "test", + StartTime: time.Now(), + ctx: ctx, + cancel: cancel, + } + manager.mu.Unlock() + + // Read operations + _ = manager.GetActiveSessionCount() + _ = manager.GetTotalQueueDepth() + _ = manager.IsAnySessionProcessing() + _ = manager.GetAllSessions() + + // Delete session + manager.DeleteSession(id) + }(int64(i)) + } + + wg.Wait() + + // All sessions should be deleted + assert.Equal(t, 0, manager.GetActiveSessionCount()) +} + +// TestProcessNotifyChannel tests the process notification channel. +func TestProcessNotifyChannel(t *testing.T) { + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + + // Non-blocking send should work + select { + case manager.ProcessNotify <- struct{}{}: + // Success + default: + t.Error("ProcessNotify channel should accept first message") + } + + // Second send should not block (channel is buffered with size 1) + select { + case manager.ProcessNotify <- struct{}{}: + // Full buffer, this is expected behavior + default: + // This is fine - channel is full + } + + // Drain the channel + select { + case <-manager.ProcessNotify: + // Drained + default: + t.Error("Should be able to receive from ProcessNotify") + } +} + +// TestActiveSessionContext tests session context handling. +func TestActiveSessionContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + session := &ActiveSession{ + SessionDBID: 1, + ctx: ctx, + cancel: cancel, + } + + // Context should not be done + select { + case <-session.ctx.Done(): + t.Error("Context should not be done yet") + default: + // Expected + } + + // Cancel context + session.cancel() + + // Context should be done + select { + case <-session.ctx.Done(): + // Expected + default: + t.Error("Context should be done after cancel") + } +} + +// TestGeneratorActive tests the atomic generator active flag. +func TestGeneratorActive(t *testing.T) { + session := &ActiveSession{} + + // Initially false + assert.False(t, session.generatorActive.Load()) + + // Set to true + session.generatorActive.Store(true) + assert.True(t, session.generatorActive.Load()) + + // Set back to false + session.generatorActive.Store(false) + assert.False(t, session.generatorActive.Load()) +} + +// TestTokenAccumulation tests token accumulation fields. +func TestTokenAccumulation(t *testing.T) { + session := &ActiveSession{ + CumulativeInputTokens: 0, + CumulativeOutputTokens: 0, + } + + // Accumulate tokens + session.CumulativeInputTokens += 100 + session.CumulativeOutputTokens += 50 + + assert.Equal(t, int64(100), session.CumulativeInputTokens) + assert.Equal(t, int64(50), session.CumulativeOutputTokens) + + // Add more + session.CumulativeInputTokens += 200 + session.CumulativeOutputTokens += 100 + + assert.Equal(t, int64(300), session.CumulativeInputTokens) + assert.Equal(t, int64(150), session.CumulativeOutputTokens) +} + +// TestShutdownAll tests graceful shutdown of all sessions. +func (s *ManagerSuite) TestShutdownAll() { + // Create multiple sessions + for i := int64(1); i <= 3; i++ { + ctx, cancel := context.WithCancel(context.Background()) + s.manager.sessions[i] = &ActiveSession{ + SessionDBID: i, + Project: "test-project", + StartTime: time.Now(), + pendingMessages: []PendingMessage{}, + ctx: ctx, + cancel: cancel, + } + } + + s.Equal(3, s.manager.GetActiveSessionCount()) + + // Track deleted sessions + var deletedIDs []int64 + s.manager.SetOnSessionDeleted(func(id int64) { + deletedIDs = append(deletedIDs, id) + }) + + // Shutdown all + s.manager.ShutdownAll(context.Background()) + + // All sessions should be deleted + s.Equal(0, s.manager.GetActiveSessionCount()) + s.Len(deletedIDs, 3) +} + +// TestDeleteNonExistentSession tests deleting a session that doesn't exist. +func (s *ManagerSuite) TestDeleteNonExistentSession() { + // Track callback + callbackCalled := false + s.manager.SetOnSessionDeleted(func(id int64) { + callbackCalled = true + }) + + // Delete non-existent session + s.manager.DeleteSession(999) + + // Callback should not be called + s.False(callbackCalled) +} + +// TestLastPromptNumber tests prompt number tracking. +func TestLastPromptNumber(t *testing.T) { + session := &ActiveSession{ + SessionDBID: 1, + LastPromptNumber: 0, + } + + assert.Equal(t, 0, session.LastPromptNumber) + + session.LastPromptNumber = 5 + assert.Equal(t, 5, session.LastPromptNumber) + + session.LastPromptNumber++ + assert.Equal(t, 6, session.LastPromptNumber) +} + +// TestActiveSessionNotifyChannel tests session notification channel. +func TestActiveSessionNotifyChannel(t *testing.T) { + session := &ActiveSession{ + notify: make(chan struct{}, 1), + } + + // Non-blocking send + select { + case session.notify <- struct{}{}: + // Success + default: + t.Error("Should accept first notification") + } + + // Second send should not block + select { + case session.notify <- struct{}{}: + // Full buffer + default: + // Expected - buffer is full + } + + // Drain + select { + case <-session.notify: + // Drained + default: + t.Error("Should receive notification") + } +} + +// TestMessageMutex tests message mutex operations. +func TestMessageMutex(t *testing.T) { + session := &ActiveSession{ + pendingMessages: make([]PendingMessage, 0), + } + + var wg sync.WaitGroup + + // Concurrent message operations + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + session.messageMu.Lock() + session.pendingMessages = append(session.pendingMessages, PendingMessage{ + Type: MessageTypeObservation, + }) + session.messageMu.Unlock() + }() + } + + wg.Wait() + + assert.Len(t, session.pendingMessages, 50) +} + +// TestQueueDepthMultipleSessions tests queue depth with multiple sessions. +func (s *ManagerSuite) TestQueueDepthMultipleSessions() { + // Add sessions with varying queue depths + s.manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 10), + } + s.manager.sessions[2] = &ActiveSession{ + SessionDBID: 2, + pendingMessages: make([]PendingMessage, 0), + } + s.manager.sessions[3] = &ActiveSession{ + SessionDBID: 3, + pendingMessages: make([]PendingMessage, 5), + } + + s.Equal(15, s.manager.GetTotalQueueDepth()) +} + +// TestIsAnySessionProcessing_GeneratorOnly tests processing status with only generator active. +func (s *ManagerSuite) TestIsAnySessionProcessingGeneratorOnly() { + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: []PendingMessage{}, + } + s.manager.sessions[1] = session + + // No processing initially + s.False(s.manager.IsAnySessionProcessing()) + + // Set generator active + session.generatorActive.Store(true) + s.True(s.manager.IsAnySessionProcessing()) + + // Clear generator + session.generatorActive.Store(false) + s.False(s.manager.IsAnySessionProcessing()) +} + +// TestPendingMessageWithBothTypes tests pending messages with both types. +func TestPendingMessageWithBothTypes(t *testing.T) { + messages := []PendingMessage{ + { + Type: MessageTypeObservation, + Observation: &ObservationData{ToolName: "Read"}, + }, + { + Type: MessageTypeSummarize, + Summarize: &SummarizeData{LastUserMessage: "Test"}, + }, + { + Type: MessageTypeObservation, + Observation: &ObservationData{ToolName: "Write"}, + }, + } + + assert.Len(t, messages, 3) + + // Verify types + assert.Equal(t, MessageTypeObservation, messages[0].Type) + assert.Equal(t, MessageTypeSummarize, messages[1].Type) + assert.Equal(t, MessageTypeObservation, messages[2].Type) + + // Verify data + assert.Equal(t, "Read", messages[0].Observation.ToolName) + assert.Nil(t, messages[0].Summarize) + + assert.Equal(t, "Test", messages[1].Summarize.LastUserMessage) + assert.Nil(t, messages[1].Observation) + + assert.Equal(t, "Write", messages[2].Observation.ToolName) +} + +// TestDrainMessagesPreservesOrder tests that draining preserves message order. +func (s *ManagerSuite) TestDrainMessagesPreservesOrder() { + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: []PendingMessage{ + {Type: MessageTypeObservation, Observation: &ObservationData{ToolName: "Tool1"}}, + {Type: MessageTypeSummarize, Summarize: &SummarizeData{LastUserMessage: "Msg1"}}, + {Type: MessageTypeObservation, Observation: &ObservationData{ToolName: "Tool2"}}, + }, + } + s.manager.sessions[1] = session + + messages := s.manager.DrainMessages(1) + + s.Len(messages, 3) + s.Equal("Tool1", messages[0].Observation.ToolName) + s.Equal("Msg1", messages[1].Summarize.LastUserMessage) + s.Equal("Tool2", messages[2].Observation.ToolName) +} + +// TestActiveSessionCWD tests CWD field in ObservationData. +func TestActiveSessionCWD(t *testing.T) { + tests := []struct { + name string + cwd string + }{ + {"empty_cwd", ""}, + {"absolute_path", "/home/user/project"}, + {"windows_path", "C:\\Users\\test\\project"}, + {"path_with_spaces", "/home/user/my project"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := ObservationData{ + ToolName: "Test", + CWD: tt.cwd, + } + assert.Equal(t, tt.cwd, data.CWD) + }) + } +} + +// TestToolInputResponse tests various tool input/response types. +func TestToolInputResponse(t *testing.T) { + tests := []struct { + name string + input interface{} + response interface{} + }{ + {"nil_values", nil, nil}, + {"string_values", "input string", "response string"}, + {"map_values", map[string]string{"key": "value"}, map[string]interface{}{"result": true}}, + {"slice_values", []string{"a", "b"}, []int{1, 2, 3}}, + {"int_values", 42, 100}, + {"bool_values", true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := ObservationData{ + ToolName: "TestTool", + ToolInput: tt.input, + ToolResponse: tt.response, + } + assert.Equal(t, tt.input, data.ToolInput) + assert.Equal(t, tt.response, data.ToolResponse) + }) + } +} diff --git a/internal/worker/sse/broadcaster_test.go b/internal/worker/sse/broadcaster_test.go new file mode 100644 index 0000000..9b69459 --- /dev/null +++ b/internal/worker/sse/broadcaster_test.go @@ -0,0 +1,383 @@ +// Package sse provides Server-Sent Events broadcasting for claude-mnemonic. +package sse + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// BroadcasterSuite is a test suite for Broadcaster operations. +type BroadcasterSuite struct { + suite.Suite + broadcaster *Broadcaster +} + +func (s *BroadcasterSuite) SetupTest() { + s.broadcaster = NewBroadcaster() +} + +func TestBroadcasterSuite(t *testing.T) { + suite.Run(t, new(BroadcasterSuite)) +} + +// TestNewBroadcaster tests broadcaster creation. +func (s *BroadcasterSuite) TestNewBroadcaster() { + b := NewBroadcaster() + s.NotNil(b) + s.NotNil(b.clients) + s.Equal(0, b.ClientCount()) +} + +// TestClientCount tests client counting. +func (s *BroadcasterSuite) TestClientCount() { + s.Equal(0, s.broadcaster.ClientCount()) +} + +// mockResponseWriter implements http.ResponseWriter and http.Flusher for testing. +type mockResponseWriter struct { + header http.Header + body []byte + statusCode int + mu sync.Mutex +} + +func newMockResponseWriter() *mockResponseWriter { + return &mockResponseWriter{ + header: make(http.Header), + statusCode: http.StatusOK, + } +} + +func (m *mockResponseWriter) Header() http.Header { + return m.header +} + +func (m *mockResponseWriter) Write(data []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.body = append(m.body, data...) + return len(data), nil +} + +func (m *mockResponseWriter) WriteHeader(statusCode int) { + m.statusCode = statusCode +} + +func (m *mockResponseWriter) Flush() { + // No-op for testing +} + +func (m *mockResponseWriter) GetBody() []byte { + m.mu.Lock() + defer m.mu.Unlock() + return m.body +} + +// TestAddClient tests adding clients. +func (s *BroadcasterSuite) TestAddClient() { + w := newMockResponseWriter() + + client, err := s.broadcaster.AddClient(w) + s.NoError(err) + s.NotNil(client) + s.NotEmpty(client.ID) + s.NotNil(client.Done) + s.Equal(1, s.broadcaster.ClientCount()) +} + +// TestAddMultipleClients tests adding multiple clients. +func (s *BroadcasterSuite) TestAddMultipleClients() { + for i := 0; i < 5; i++ { + w := newMockResponseWriter() + _, err := s.broadcaster.AddClient(w) + s.NoError(err) + } + + s.Equal(5, s.broadcaster.ClientCount()) +} + +// TestRemoveClient tests removing clients. +func (s *BroadcasterSuite) TestRemoveClient() { + w := newMockResponseWriter() + client, err := s.broadcaster.AddClient(w) + s.NoError(err) + + s.Equal(1, s.broadcaster.ClientCount()) + + s.broadcaster.RemoveClient(client) + + s.Equal(0, s.broadcaster.ClientCount()) + + // Check that Done channel is closed + select { + case <-client.Done: + // Expected - channel is closed + default: + s.Fail("Done channel should be closed") + } +} + +// TestBroadcast tests broadcasting messages. +func (s *BroadcasterSuite) TestBroadcast() { + w := newMockResponseWriter() + _, err := s.broadcaster.AddClient(w) + s.NoError(err) + + // Broadcast a message + s.broadcaster.Broadcast(map[string]string{"type": "test", "message": "hello"}) + + // Give time for async write + time.Sleep(50 * time.Millisecond) + + body := string(w.GetBody()) + s.Contains(body, "data:") + s.Contains(body, "test") + s.Contains(body, "hello") +} + +// TestBroadcastNoClients tests broadcasting with no clients. +func (s *BroadcasterSuite) TestBroadcastNoClients() { + // Should not panic + s.broadcaster.Broadcast(map[string]string{"type": "test"}) +} + +// TestBroadcastMultipleClients tests broadcasting to multiple clients. +func (s *BroadcasterSuite) TestBroadcastMultipleClients() { + writers := make([]*mockResponseWriter, 3) + for i := 0; i < 3; i++ { + writers[i] = newMockResponseWriter() + _, err := s.broadcaster.AddClient(writers[i]) + s.NoError(err) + } + + // Broadcast + s.broadcaster.Broadcast(map[string]string{"type": "test"}) + + // Give time for async writes + time.Sleep(100 * time.Millisecond) + + // All clients should receive the message + for i, w := range writers { + body := string(w.GetBody()) + s.Contains(body, "data:", "Client %d should receive data", i) + } +} + +// TestClient tests Client structure. +func TestClient(t *testing.T) { + w := newMockResponseWriter() + client := &Client{ + ID: "test-client", + Writer: w, + Flusher: w, + Done: make(chan struct{}), + } + + assert.Equal(t, "test-client", client.ID) + assert.NotNil(t, client.Writer) + assert.NotNil(t, client.Flusher) + assert.NotNil(t, client.Done) + + // Close done channel + close(client.Done) + + select { + case <-client.Done: + // Expected + default: + t.Error("Done channel should be closed") + } +} + +// TestClientUniqueIDs tests that clients get unique IDs. +func TestClientUniqueIDs(t *testing.T) { + b := NewBroadcaster() + ids := make(map[string]bool) + + for i := 0; i < 100; i++ { + w := newMockResponseWriter() + client, err := b.AddClient(w) + require.NoError(t, err) + + // ID should be unique + assert.False(t, ids[client.ID], "ID %s should be unique", client.ID) + ids[client.ID] = true + } +} + +// TestWriteTimeout tests the write timeout constant. +func TestWriteTimeout(t *testing.T) { + assert.Equal(t, 2*time.Second, WriteTimeout) +} + +// TestHandleSSE tests the HandleSSE HTTP handler. +func TestHandleSSE(t *testing.T) { + b := NewBroadcaster() + + // Create a test server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set up context that will be cancelled + ctx := r.Context() + + // Start goroutine to cancel context after short delay + go func() { + time.Sleep(50 * time.Millisecond) + // Request will be cancelled by the test client + }() + + // This will block until context is cancelled + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + return + } + }) + + _ = handler + _ = b + + // Just verify the handler exists and broadcaster can handle SSE + req := httptest.NewRequest(http.MethodGet, "/events", nil) + rec := httptest.NewRecorder() + + // Can't easily test HandleSSE since it blocks, but we can verify setup + assert.NotNil(t, req) + assert.NotNil(t, rec) +} + +// TestBroadcastJSON tests broadcasting various JSON types. +func TestBroadcastJSON(t *testing.T) { + tests := []struct { + name string + data interface{} + wantErr bool + }{ + { + name: "string map", + data: map[string]string{"key": "value"}, + wantErr: false, + }, + { + name: "int map", + data: map[string]int{"count": 42}, + wantErr: false, + }, + { + name: "nested struct", + data: struct{ Name string }{Name: "test"}, + wantErr: false, + }, + { + name: "array", + data: []string{"a", "b", "c"}, + wantErr: false, + }, + { + name: "interface map", + data: map[string]interface{}{"type": "test", "count": 1, "active": true}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := NewBroadcaster() + w := newMockResponseWriter() + _, err := b.AddClient(w) + require.NoError(t, err) + + // Should not panic + b.Broadcast(tt.data) + + time.Sleep(50 * time.Millisecond) + + body := string(w.GetBody()) + assert.Contains(t, body, "data:") + }) + } +} + +// TestConcurrentBroadcast tests concurrent broadcasting. +func TestConcurrentBroadcast(t *testing.T) { + b := NewBroadcaster() + + // Add clients + for i := 0; i < 10; i++ { + w := newMockResponseWriter() + _, err := b.AddClient(w) + require.NoError(t, err) + } + + // Broadcast concurrently + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + b.Broadcast(map[string]int{"index": i}) + }(i) + } + + wg.Wait() + + // Should complete without panics + assert.Equal(t, 10, b.ClientCount()) +} + +// TestRemoveNonExistentClient tests removing a non-existent client. +func TestRemoveNonExistentClient(t *testing.T) { + b := NewBroadcaster() + + // Create a client but don't add it + client := &Client{ + ID: "fake-client", + Done: make(chan struct{}), + } + + // Should not panic + b.RemoveClient(client) + + // Done channel should be closed + select { + case <-client.Done: + // Expected + default: + t.Error("Done channel should be closed") + } +} + +// TestBroadcasterConcurrentAddRemove tests concurrent add/remove operations. +func TestBroadcasterConcurrentAddRemove(t *testing.T) { + b := NewBroadcaster() + var wg sync.WaitGroup + + // Concurrent adds + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + w := newMockResponseWriter() + client, err := b.AddClient(w) + if err == nil { + // Random chance to remove + if time.Now().UnixNano()%2 == 0 { + b.RemoveClient(client) + } + } + }() + } + + wg.Wait() + + // Should not panic and have some clients + count := b.ClientCount() + assert.GreaterOrEqual(t, count, 0) +} diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index c3ab458..416409e 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -3,9 +3,11 @@ package hooks import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -190,3 +192,521 @@ func TestFindWorkerBinary(t *testing.T) { // Result depends on whether worker is installed, so we just check it doesn't panic t.Logf("findWorkerBinary returned: %s", result) } + +// TestVersionsCompatible tests the versionsCompatible function. +func TestVersionsCompatible(t *testing.T) { + tests := []struct { + name string + v1 string + v2 string + expected bool + }{ + { + name: "identical versions", + v1: "v1.0.0", + v2: "v1.0.0", + expected: true, + }, + { + name: "same base different suffix", + v1: "v1.0.0", + v2: "v1.0.0-dirty", + expected: true, + }, + { + name: "same base with commit hash", + v1: "v1.0.0-2-gca711a8", + v2: "v1.0.0-5-gabcdef1-dirty", + expected: true, + }, + { + name: "different base versions", + v1: "v1.0.0", + v2: "v2.0.0", + expected: false, + }, + { + name: "dev version compatible with anything", + v1: "dev", + v2: "v1.0.0", + expected: true, + }, + { + name: "anything compatible with dev", + v1: "v2.0.0-dirty", + v2: "dev", + expected: true, + }, + { + name: "both dev versions", + v1: "dev", + v2: "dev", + expected: true, + }, + { + name: "minor version difference", + v1: "v1.2.0", + v2: "v1.3.0", + expected: false, + }, + { + name: "patch version difference", + v1: "v1.0.1", + v2: "v1.0.2", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := versionsCompatible(tt.v1, tt.v2) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestExtractBaseVersion tests the extractBaseVersion function. +func TestExtractBaseVersion(t *testing.T) { + tests := []struct { + name string + version string + expected string + }{ + { + name: "simple version with v prefix", + version: "v1.0.0", + expected: "1.0.0", + }, + { + name: "version without v prefix", + version: "1.0.0", + expected: "1.0.0", + }, + { + name: "version with commit suffix", + version: "v0.3.5-2-gca711a8", + expected: "0.3.5", + }, + { + name: "version with dirty suffix", + version: "v0.3.5-dirty", + expected: "0.3.5", + }, + { + name: "version with full suffix", + version: "v0.3.5-2-gca711a8-dirty", + expected: "0.3.5", + }, + { + name: "dev version", + version: "dev", + expected: "dev", + }, + { + name: "empty version", + version: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractBaseVersion(tt.version) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestPOST tests the POST function with a mock server. +func TestPOST(t *testing.T) { + tests := []struct { + name string + serverHandler func(w http.ResponseWriter, r *http.Request) + body interface{} + expectError bool + expectedResult map[string]interface{} + }{ + { + name: "successful POST with JSON response", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"}) + }, + body: map[string]string{"key": "value"}, + expectError: false, + expectedResult: map[string]interface{}{"status": "ok"}, + }, + { + name: "POST with 400 error", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + body: map[string]string{"key": "value"}, + expectError: true, + }, + { + name: "POST with 500 error", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }, + body: map[string]string{"key": "value"}, + expectError: true, + }, + { + name: "POST with non-JSON response", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not json")) + }, + body: map[string]string{"key": "value"}, + expectError: false, + expectedResult: nil, // Non-JSON returns nil + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(tt.serverHandler)) + defer server.Close() + + // Extract port from test server + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + result, err := POST(port, "/test", tt.body) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.expectedResult != nil { + assert.Equal(t, tt.expectedResult["status"], result["status"]) + } + } + }) + } +} + +// TestGET tests the GET function with a mock server. +func TestGET(t *testing.T) { + tests := []struct { + name string + serverHandler func(w http.ResponseWriter, r *http.Request) + expectError bool + expectedResult map[string]interface{} + }{ + { + name: "successful GET with JSON response", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"}) + }, + expectError: false, + expectedResult: map[string]interface{}{"data": "test"}, + }, + { + name: "GET with 404 error", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + expectError: true, + }, + { + name: "GET with invalid JSON", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not valid json")) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(tt.serverHandler)) + defer server.Close() + + // Extract port from test server + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + result, err := GET(port, "/test") + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.expectedResult != nil { + assert.Equal(t, tt.expectedResult["data"], result["data"]) + } + } + }) + } +} + +// TestProjectIDWithName_Comprehensive tests ProjectIDWithName more thoroughly. +func TestProjectIDWithName_Comprehensive(t *testing.T) { + tests := []struct { + name string + cwd string + expectedPrefix string + expectedLen int // Expected minimum length (prefix + _ + 6 char hash) + }{ + { + name: "standard project path", + cwd: "/Users/test/projects/my-project", + expectedPrefix: "my-project_", + expectedLen: 17, // "my-project_" + 6 char hash + }, + { + name: "short directory name", + cwd: "/tmp", + expectedPrefix: "tmp_", + expectedLen: 10, // "tmp_" + 6 char hash + }, + { + name: "nested path", + cwd: "/home/user/code/org/repo", + expectedPrefix: "repo_", + expectedLen: 11, // "repo_" + 6 char hash + }, + { + name: "path with special characters", + cwd: "/Users/test/my-special.project", + expectedPrefix: "my-special.project_", + expectedLen: 25, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ProjectIDWithName(tt.cwd) + assert.True(t, len(result) >= tt.expectedLen, "result %s should be at least %d chars", result, tt.expectedLen) + assert.Contains(t, result, tt.expectedPrefix[:len(tt.expectedPrefix)-1]) // Check without trailing underscore + assert.Contains(t, result, "_") + + // Verify hash uniqueness - same path should give same result + result2 := ProjectIDWithName(tt.cwd) + assert.Equal(t, result, result2) + }) + } +} + +// TestProjectIDWithName_Uniqueness tests that different paths produce different IDs. +func TestProjectIDWithName_Uniqueness(t *testing.T) { + paths := []string{ + "/Users/test/project-a", + "/Users/test/project-b", + "/Users/other/project-a", + "/tmp/project-a", + } + + ids := make(map[string]bool) + for _, path := range paths { + id := ProjectIDWithName(path) + assert.False(t, ids[id], "duplicate ID generated for path %s", path) + ids[id] = true + } +} + +// TestHookConstants tests hook-related constants. +func TestHookConstants(t *testing.T) { + assert.Equal(t, 37777, DefaultWorkerPort) + assert.Equal(t, 1*time.Second, HealthCheckTimeout) + assert.Equal(t, 30*time.Second, StartupTimeout) +} + +// TestExitCodes tests exit code constants. +func TestExitCodes(t *testing.T) { + assert.Equal(t, 0, ExitSuccess) + assert.Equal(t, 1, ExitFailure) + assert.Equal(t, 3, ExitUserMessageOnly) +} + +// TestHookResponse tests HookResponse struct. +func TestHookResponse(t *testing.T) { + tests := []struct { + name string + response HookResponse + expected string + }{ + { + name: "continue true", + response: HookResponse{Continue: true}, + expected: `{"continue":true}`, + }, + { + name: "continue false", + response: HookResponse{Continue: false}, + expected: `{"continue":false}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.response) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(data)) + }) + } +} + +// TestBaseInput tests BaseInput struct parsing. +func TestBaseInput(t *testing.T) { + input := `{ + "session_id": "test-session-123", + "cwd": "/Users/test/project", + "permission_mode": "standard", + "hook_event_name": "session-start" + }` + + var base BaseInput + err := json.Unmarshal([]byte(input), &base) + require.NoError(t, err) + + assert.Equal(t, "test-session-123", base.SessionID) + assert.Equal(t, "/Users/test/project", base.CWD) + assert.Equal(t, "standard", base.PermissionMode) + assert.Equal(t, "session-start", base.HookEventName) +} + +// TestHookContext tests HookContext struct. +func TestHookContext(t *testing.T) { + ctx := &HookContext{ + HookName: "session-start", + Port: 37777, + Project: "my-project_abc123", + SessionID: "test-session", + CWD: "/Users/test/project", + RawInput: []byte(`{"key":"value"}`), + } + + assert.Equal(t, "session-start", ctx.HookName) + assert.Equal(t, 37777, ctx.Port) + assert.Equal(t, "my-project_abc123", ctx.Project) + assert.Equal(t, "test-session", ctx.SessionID) + assert.Equal(t, "/Users/test/project", ctx.CWD) + assert.Equal(t, []byte(`{"key":"value"}`), ctx.RawInput) +} + +// TestIsWorkerRunning_WithServer tests IsWorkerRunning with actual server. +func TestIsWorkerRunning_WithServer(t *testing.T) { + tests := []struct { + name string + serverHandler func(w http.ResponseWriter, r *http.Request) + expectedResult bool + }{ + { + name: "healthy worker returns true", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/health" { + w.WriteHeader(http.StatusOK) + } + }, + expectedResult: true, + }, + { + name: "unhealthy worker returns false", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/health" { + w.WriteHeader(http.StatusServiceUnavailable) + } + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(tt.serverHandler)) + defer server.Close() + + // Extract port - note: test server binds to 127.0.0.1 + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + // The function uses hardcoded 127.0.0.1, which matches httptest + result := IsWorkerRunning(port) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +// TestIsPortInUse_WithServer tests IsPortInUse with actual server. +func TestIsPortInUse_WithServer(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Extract port + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + // Port should be in use + assert.True(t, IsPortInUse(port)) +} + +// TestGetWorkerVersion_WithServer tests GetWorkerVersion with actual server. +func TestGetWorkerVersion_WithServer(t *testing.T) { + tests := []struct { + name string + serverHandler func(w http.ResponseWriter, r *http.Request) + expectedResult string + }{ + { + name: "returns version from server", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/version" { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"version": "v1.2.3"}) + } + }, + expectedResult: "v1.2.3", + }, + { + name: "returns empty on non-200", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + expectedResult: "", + }, + { + name: "returns empty on invalid JSON", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not json")) + }, + expectedResult: "", + }, + { + name: "returns empty on missing version field", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"other": "field"}) + }, + expectedResult: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(tt.serverHandler)) + defer server.Close() + + var port int + _, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + + result := GetWorkerVersion(port) + assert.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/pkg/models/observation_test.go b/pkg/models/observation_test.go new file mode 100644 index 0000000..50ce357 --- /dev/null +++ b/pkg/models/observation_test.go @@ -0,0 +1,424 @@ +// Package models contains domain models for claude-mnemonic. +package models + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// ObservationSuite is a test suite for Observation operations. +type ObservationSuite struct { + suite.Suite +} + +func TestObservationSuite(t *testing.T) { + suite.Run(t, new(ObservationSuite)) +} + +// TestObservationTypeConstants tests observation type constants. +func (s *ObservationSuite) TestObservationTypeConstants() { + s.Equal(ObservationType("discovery"), ObsTypeDiscovery) + s.Equal(ObservationType("decision"), ObsTypeDecision) + s.Equal(ObservationType("bugfix"), ObsTypeBugfix) + s.Equal(ObservationType("feature"), ObsTypeFeature) + s.Equal(ObservationType("refactor"), ObsTypeRefactor) + s.Equal(ObservationType("change"), ObsTypeChange) +} + +// TestScopeConstants tests scope constants. +func (s *ObservationSuite) TestScopeConstants() { + s.Equal(ObservationScope("project"), ScopeProject) + s.Equal(ObservationScope("global"), ScopeGlobal) +} + +// TestGlobalizableConcepts tests that globalizable concepts are defined. +func (s *ObservationSuite) TestGlobalizableConcepts() { + expected := []string{ + "best-practice", "pattern", "anti-pattern", "architecture", + "security", "performance", "testing", + "debugging", "workflow", "tooling", + } + s.Equal(expected, GlobalizableConcepts) +} + +// TestDetermineScope_TableDriven tests scope determination with various concepts. +func (s *ObservationSuite) TestDetermineScope_TableDriven() { + tests := []struct { + name string + concepts []string + expected ObservationScope + }{ + { + name: "empty concepts - project scope", + concepts: []string{}, + expected: ScopeProject, + }, + { + name: "no globalizable concepts - project scope", + concepts: []string{"how-it-works", "custom-tag"}, + expected: ScopeProject, + }, + { + name: "security concept - global scope", + concepts: []string{"security"}, + expected: ScopeGlobal, + }, + { + name: "best-practice concept - global scope", + concepts: []string{"best-practice"}, + expected: ScopeGlobal, + }, + { + name: "mixed concepts with globalizable - global scope", + concepts: []string{"how-it-works", "security"}, + expected: ScopeGlobal, + }, + { + name: "performance concept - global scope", + concepts: []string{"performance"}, + expected: ScopeGlobal, + }, + { + name: "testing concept - global scope", + concepts: []string{"testing"}, + expected: ScopeGlobal, + }, + { + name: "pattern concept - global scope", + concepts: []string{"pattern"}, + expected: ScopeGlobal, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := DetermineScope(tt.concepts) + s.Equal(tt.expected, result) + }) + } +} + +// TestParsedObservation_FileMtimesJSON tests FileMtimes JSON serialization. +func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() { + obs := &ParsedObservation{ + Type: ObsTypeDiscovery, + Title: "Test", + FileMtimes: map[string]int64{"file1.go": 1234567890, "file2.go": 1234567891}, + } + + // Verify mtimes can be marshaled + data, err := json.Marshal(obs.FileMtimes) + s.NoError(err) + s.Contains(string(data), "file1.go") + s.Contains(string(data), "1234567890") +} + +// TestObservation_CheckStaleness_TableDriven tests staleness checking. +func (s *ObservationSuite) TestObservation_CheckStaleness_TableDriven() { + tests := []struct { + name string + storedMtimes map[string]int64 + currentMtimes map[string]int64 + expectedStale bool + }{ + { + name: "empty stored mtimes - not stale", + storedMtimes: map[string]int64{}, + currentMtimes: map[string]int64{"file.go": 1000}, + expectedStale: false, + }, + { + name: "matching mtimes - not stale", + storedMtimes: map[string]int64{"file.go": 1000}, + currentMtimes: map[string]int64{"file.go": 1000}, + expectedStale: false, + }, + { + name: "file modified - stale", + storedMtimes: map[string]int64{"file.go": 1000}, + currentMtimes: map[string]int64{"file.go": 2000}, + expectedStale: true, + }, + { + name: "file missing from current - not stale (files might not be checked)", + storedMtimes: map[string]int64{"file.go": 1000}, + currentMtimes: map[string]int64{}, + expectedStale: false, // Missing files don't mark as stale per the implementation + }, + { + name: "multiple files, one modified - stale", + storedMtimes: map[string]int64{"file1.go": 1000, "file2.go": 2000}, + currentMtimes: map[string]int64{"file1.go": 1000, "file2.go": 3000}, + expectedStale: true, + }, + { + name: "nil current mtimes - not stale", + storedMtimes: map[string]int64{"file.go": 1000}, + currentMtimes: nil, + expectedStale: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + obs := &Observation{ + FileMtimes: tt.storedMtimes, + } + result := obs.CheckStaleness(tt.currentMtimes) + s.Equal(tt.expectedStale, result) + }) + } +} + +// TestObservation_MarshalJSON tests JSON marshaling of Observation. +func (s *ObservationSuite) TestObservation_MarshalJSON() { + obs := &Observation{ + ID: 1, + Project: "test-project", + Type: ObsTypeDiscovery, + Title: sql.NullString{String: "Test Title", Valid: true}, + Scope: ScopeProject, + } + + data, err := json.Marshal(obs) + s.NoError(err) + s.Contains(string(data), `"id":1`) + s.Contains(string(data), `"project":"test-project"`) + s.Contains(string(data), `"type":"discovery"`) +} + +// TestParsedObservation_Fields tests ParsedObservation field access. +func (s *ObservationSuite) TestParsedObservation_Fields() { + obs := &ParsedObservation{ + Type: ObsTypeFeature, + Title: "Add authentication", + Subtitle: "JWT-based auth", + Narrative: "Implemented JWT authentication for API endpoints", + Facts: []string{"Uses RS256 algorithm", "Tokens expire in 24h"}, + Concepts: []string{"security", "auth"}, + FilesRead: []string{"config.go"}, + FilesModified: []string{"handler.go", "middleware.go"}, + FileMtimes: map[string]int64{"handler.go": 1234567890}, + } + + s.Equal(ObsTypeFeature, obs.Type) + s.Equal("Add authentication", obs.Title) + s.Equal("JWT-based auth", obs.Subtitle) + s.Contains(obs.Narrative, "JWT") + s.Len(obs.Facts, 2) + s.Len(obs.Concepts, 2) + s.Len(obs.FilesRead, 1) + s.Len(obs.FilesModified, 2) + s.Len(obs.FileMtimes, 1) +} + +// TestObservation_NullFields tests handling of nullable fields. +func (s *ObservationSuite) TestObservation_NullFields() { + // Test with null fields + obs := &Observation{ + ID: 1, + Project: "test", + Type: ObsTypeDiscovery, + Title: sql.NullString{Valid: false}, + Subtitle: sql.NullString{Valid: false}, + Narrative: sql.NullString{Valid: false}, + } + + s.False(obs.Title.Valid) + s.False(obs.Subtitle.Valid) + s.False(obs.Narrative.Valid) + + // Test with valid fields + obs2 := &Observation{ + ID: 2, + Project: "test", + Type: ObsTypeBugfix, + Title: sql.NullString{String: "Fix bug", Valid: true}, + Subtitle: sql.NullString{String: "Memory leak", Valid: true}, + Narrative: sql.NullString{String: "Fixed memory leak in handler", Valid: true}, + } + + s.True(obs2.Title.Valid) + s.Equal("Fix bug", obs2.Title.String) + s.True(obs2.Subtitle.Valid) + s.Equal("Memory leak", obs2.Subtitle.String) +} + +// TestNewObservation tests observation creation from parsed data. +func TestNewObservation(t *testing.T) { + parsed := &ParsedObservation{ + Type: ObsTypeFeature, + Title: "Add authentication", + Subtitle: "JWT-based", + Narrative: "Implemented JWT auth", + Facts: []string{"Uses RS256"}, + Concepts: []string{"security"}, + FilesRead: []string{"config.go"}, + FilesModified: []string{"handler.go"}, + FileMtimes: map[string]int64{"handler.go": 1234567890}, + } + + obs := NewObservation("sdk-123", "test-project", parsed, 5, 1000) + + assert.Equal(t, "sdk-123", obs.SDKSessionID) + assert.Equal(t, "test-project", obs.Project) + assert.Equal(t, ScopeGlobal, obs.Scope) // security triggers global + assert.Equal(t, ObsTypeFeature, obs.Type) + assert.Equal(t, "Add authentication", obs.Title.String) + assert.True(t, obs.Title.Valid) + assert.Equal(t, int64(5), obs.PromptNumber.Int64) + assert.Equal(t, int64(1000), obs.DiscoveryTokens) + assert.NotEmpty(t, obs.CreatedAt) + assert.Greater(t, obs.CreatedAtEpoch, int64(0)) +} + +// TestParsedObservation_ToStoredObservation tests conversion. +func TestParsedObservation_ToStoredObservation(t *testing.T) { + parsed := &ParsedObservation{ + Type: ObsTypeDiscovery, + Title: "Test Title", + Subtitle: "Test Subtitle", + Narrative: "Test narrative", + Facts: []string{"Fact 1"}, + Concepts: []string{"testing"}, + } + + obs := parsed.ToStoredObservation() + + assert.Equal(t, ObsTypeDiscovery, obs.Type) + assert.Equal(t, "Test Title", obs.Title.String) + assert.True(t, obs.Title.Valid) + assert.Equal(t, "Test Subtitle", obs.Subtitle.String) + assert.True(t, obs.Subtitle.Valid) +} + +// TestJSONStringArray tests JSONStringArray scanning. +func TestJSONStringArray(t *testing.T) { + tests := []struct { + name string + input interface{} + wantErr bool + expected JSONStringArray + }{ + { + name: "nil input", + input: nil, + wantErr: false, + expected: nil, + }, + { + name: "empty string", + input: "", + wantErr: false, + expected: nil, + }, + { + name: "json array string", + input: `["item1", "item2"]`, + wantErr: false, + expected: JSONStringArray{"item1", "item2"}, + }, + { + name: "json array bytes", + input: []byte(`["a", "b", "c"]`), + wantErr: false, + expected: JSONStringArray{"a", "b", "c"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var arr JSONStringArray + err := arr.Scan(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, arr) + } + }) + } +} + +// TestJSONInt64Map tests JSONInt64Map scanning. +func TestJSONInt64Map(t *testing.T) { + tests := []struct { + name string + input interface{} + wantErr bool + expected JSONInt64Map + }{ + { + name: "nil input", + input: nil, + wantErr: false, + expected: nil, + }, + { + name: "empty string", + input: "", + wantErr: false, + expected: nil, + }, + { + name: "json map string", + input: `{"file.go": 1234567890}`, + wantErr: false, + expected: JSONInt64Map{"file.go": 1234567890}, + }, + { + name: "json map bytes", + input: []byte(`{"a.go": 100, "b.go": 200}`), + wantErr: false, + expected: JSONInt64Map{"a.go": 100, "b.go": 200}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m JSONInt64Map + err := m.Scan(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, m) + } + }) + } +} + +// TestObservation_JSONRoundTrip tests that observations can be marshaled and unmarshaled. +func TestObservation_JSONRoundTrip(t *testing.T) { + original := &Observation{ + ID: 1, + SDKSessionID: "session-123", + Project: "test-project", + Type: ObsTypeDiscovery, + Title: sql.NullString{String: "Test Title", Valid: true}, + Subtitle: sql.NullString{String: "Test Subtitle", Valid: true}, + Narrative: sql.NullString{String: "Test narrative content", Valid: true}, + Scope: ScopeProject, + CreatedAt: "2024-01-01T00:00:00Z", + CreatedAtEpoch: 1704067200000, + } + + // Marshal + data, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal into map to check fields + var result map[string]interface{} + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, float64(1), result["id"]) + assert.Equal(t, "test-project", result["project"]) + assert.Equal(t, "discovery", result["type"]) + assert.Equal(t, "Test Title", result["title"]) +} diff --git a/pkg/models/summary_test.go b/pkg/models/summary_test.go new file mode 100644 index 0000000..39c16ef --- /dev/null +++ b/pkg/models/summary_test.go @@ -0,0 +1,267 @@ +// Package models contains domain models for claude-mnemonic. +package models + +import ( + "database/sql" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// SummarySuite is a test suite for SessionSummary operations. +type SummarySuite struct { + suite.Suite +} + +func TestSummarySuite(t *testing.T) { + suite.Run(t, new(SummarySuite)) +} + +// TestNewSessionSummary tests summary creation. +func (s *SummarySuite) TestNewSessionSummary() { + parsed := &ParsedSummary{ + Request: "Fix the bug in handler.go", + Investigated: "Looked at error logs", + Learned: "The issue was a race condition", + Completed: "Fixed the race condition", + NextSteps: "Add more tests", + Notes: "Consider adding mutex", + } + + summary := NewSessionSummary("sdk-123", "test-project", parsed, 5, 1000) + + s.NotNil(summary) + s.Equal("sdk-123", summary.SDKSessionID) + s.Equal("test-project", summary.Project) + s.True(summary.Request.Valid) + s.Equal("Fix the bug in handler.go", summary.Request.String) + s.True(summary.Investigated.Valid) + s.True(summary.Learned.Valid) + s.True(summary.Completed.Valid) + s.True(summary.NextSteps.Valid) + s.True(summary.Notes.Valid) + s.True(summary.PromptNumber.Valid) + s.Equal(int64(5), summary.PromptNumber.Int64) + s.Equal(int64(1000), summary.DiscoveryTokens) + s.NotEmpty(summary.CreatedAt) + s.Greater(summary.CreatedAtEpoch, int64(0)) +} + +// TestNewSessionSummary_EmptyFields tests summary creation with empty fields. +func (s *SummarySuite) TestNewSessionSummary_EmptyFields() { + parsed := &ParsedSummary{ + Request: "Test request", + // All other fields empty + } + + summary := NewSessionSummary("sdk-123", "project", parsed, 0, 0) + + s.True(summary.Request.Valid) + s.False(summary.Investigated.Valid) + s.False(summary.Learned.Valid) + s.False(summary.Completed.Valid) + s.False(summary.NextSteps.Valid) + s.False(summary.Notes.Valid) + s.False(summary.PromptNumber.Valid) // 0 is not valid + s.Equal(int64(0), summary.DiscoveryTokens) +} + +// TestSessionSummary_MarshalJSON tests JSON marshaling. +func (s *SummarySuite) TestSessionSummary_MarshalJSON() { + summary := &SessionSummary{ + ID: 1, + SDKSessionID: "sdk-123", + Project: "test-project", + Request: sql.NullString{String: "Test request", Valid: true}, + Investigated: sql.NullString{String: "Test investigation", Valid: true}, + Learned: sql.NullString{Valid: false}, // Invalid - should be omitted + Completed: sql.NullString{String: "Test completion", Valid: true}, + NextSteps: sql.NullString{Valid: false}, + Notes: sql.NullString{String: "Test notes", Valid: true}, + PromptNumber: sql.NullInt64{Int64: 3, Valid: true}, + DiscoveryTokens: 500, + CreatedAt: "2024-01-01T00:00:00Z", + CreatedAtEpoch: 1704067200000, + } + + data, err := json.Marshal(summary) + s.NoError(err) + + // Parse the JSON + var result map[string]interface{} + err = json.Unmarshal(data, &result) + s.NoError(err) + + // Check fields + s.Equal(float64(1), result["id"]) + s.Equal("sdk-123", result["sdk_session_id"]) + s.Equal("test-project", result["project"]) + s.Equal("Test request", result["request"]) + s.Equal("Test investigation", result["investigated"]) + s.Equal("Test completion", result["completed"]) + s.Equal("Test notes", result["notes"]) + s.Equal(float64(3), result["prompt_number"]) + s.Equal(float64(500), result["discovery_tokens"]) + + // Empty fields should be omitted + _, hasLearned := result["learned"] + s.False(hasLearned, "Empty learned should be omitted") + _, hasNextSteps := result["next_steps"] + s.False(hasNextSteps, "Empty next_steps should be omitted") +} + +// TestSessionSummary_MarshalJSON_AllEmpty tests JSON marshaling with all empty optional fields. +func (s *SummarySuite) TestSessionSummary_MarshalJSON_AllEmpty() { + summary := &SessionSummary{ + ID: 1, + SDKSessionID: "sdk-123", + Project: "test-project", + Request: sql.NullString{Valid: false}, + Investigated: sql.NullString{Valid: false}, + Learned: sql.NullString{Valid: false}, + Completed: sql.NullString{Valid: false}, + NextSteps: sql.NullString{Valid: false}, + Notes: sql.NullString{Valid: false}, + PromptNumber: sql.NullInt64{Valid: false}, + DiscoveryTokens: 0, + CreatedAt: "2024-01-01T00:00:00Z", + CreatedAtEpoch: 1704067200000, + } + + data, err := json.Marshal(summary) + s.NoError(err) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + s.NoError(err) + + // Required fields should be present + s.Equal(float64(1), result["id"]) + s.Equal("sdk-123", result["sdk_session_id"]) + s.Equal("test-project", result["project"]) + + // Optional fields should be empty strings or omitted + request, hasRequest := result["request"] + if hasRequest { + s.Equal("", request) + } +} + +// TestParsedSummary tests ParsedSummary structure. +func (s *SummarySuite) TestParsedSummary() { + parsed := &ParsedSummary{ + Request: "Request text", + Investigated: "Investigation text", + Learned: "Learned text", + Completed: "Completed text", + NextSteps: "Next steps text", + Notes: "Notes text", + } + + s.Equal("Request text", parsed.Request) + s.Equal("Investigation text", parsed.Investigated) + s.Equal("Learned text", parsed.Learned) + s.Equal("Completed text", parsed.Completed) + s.Equal("Next steps text", parsed.NextSteps) + s.Equal("Notes text", parsed.Notes) +} + +// TestSessionSummaryJSON tests the JSON-friendly type. +func (s *SummarySuite) TestSessionSummaryJSON() { + j := SessionSummaryJSON{ + ID: 1, + SDKSessionID: "sdk-123", + Project: "test-project", + Request: "Request", + Investigated: "Investigation", + Learned: "Learned", + Completed: "Completed", + NextSteps: "Next steps", + Notes: "Notes", + PromptNumber: 5, + DiscoveryTokens: 1000, + CreatedAt: "2024-01-01T00:00:00Z", + CreatedAtEpoch: 1704067200000, + } + + s.Equal(int64(1), j.ID) + s.Equal("sdk-123", j.SDKSessionID) + s.Equal("test-project", j.Project) + s.Equal("Request", j.Request) + s.Equal("Investigation", j.Investigated) + s.Equal("Learned", j.Learned) + s.Equal("Completed", j.Completed) + s.Equal("Next steps", j.NextSteps) + s.Equal("Notes", j.Notes) + s.Equal(int64(5), j.PromptNumber) + s.Equal(int64(1000), j.DiscoveryTokens) +} + +// TestSessionSummary_TimestampValidity tests that timestamps are set correctly. +func TestSessionSummary_TimestampValidity(t *testing.T) { + before := time.Now().Add(-time.Second) // Give 1 second buffer + + parsed := &ParsedSummary{Request: "Test"} + summary := NewSessionSummary("sdk-123", "project", parsed, 1, 100) + + after := time.Now().Add(time.Second) // Give 1 second buffer + + // Parse the timestamp + createdAt, err := time.Parse(time.RFC3339, summary.CreatedAt) + require.NoError(t, err) + + // Timestamp should be between before and after (with buffer) + assert.True(t, createdAt.After(before) || createdAt.Equal(before), "created_at should be >= before") + assert.True(t, createdAt.Before(after) || createdAt.Equal(after), "created_at should be <= after") + + // Epoch should also be in range (with buffer) + beforeEpoch := before.UnixMilli() + afterEpoch := after.UnixMilli() + assert.GreaterOrEqual(t, summary.CreatedAtEpoch, beforeEpoch, "epoch should be >= before epoch") + assert.LessOrEqual(t, summary.CreatedAtEpoch, afterEpoch, "epoch should be <= after epoch") +} + +// TestSessionSummary_JSONRoundTrip tests that summaries can be marshaled and unmarshaled. +func TestSessionSummary_JSONRoundTrip(t *testing.T) { + original := &SessionSummary{ + ID: 1, + SDKSessionID: "sdk-123", + Project: "test-project", + Request: sql.NullString{String: "Test request", Valid: true}, + Investigated: sql.NullString{String: "Test investigation", Valid: true}, + Learned: sql.NullString{String: "Test learned", Valid: true}, + Completed: sql.NullString{String: "Test completed", Valid: true}, + NextSteps: sql.NullString{String: "Test next steps", Valid: true}, + Notes: sql.NullString{String: "Test notes", Valid: true}, + PromptNumber: sql.NullInt64{Int64: 5, Valid: true}, + DiscoveryTokens: 1000, + CreatedAt: "2024-01-01T00:00:00Z", + CreatedAtEpoch: 1704067200000, + } + + // Marshal + data, err := json.Marshal(original) + require.NoError(t, err) + + // Unmarshal into JSON type + var result SessionSummaryJSON + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + // Verify + assert.Equal(t, original.ID, result.ID) + assert.Equal(t, original.SDKSessionID, result.SDKSessionID) + assert.Equal(t, original.Project, result.Project) + assert.Equal(t, original.Request.String, result.Request) + assert.Equal(t, original.Investigated.String, result.Investigated) + assert.Equal(t, original.Learned.String, result.Learned) + assert.Equal(t, original.Completed.String, result.Completed) + assert.Equal(t, original.NextSteps.String, result.NextSteps) + assert.Equal(t, original.Notes.String, result.Notes) + assert.Equal(t, original.PromptNumber.Int64, result.PromptNumber) + assert.Equal(t, original.DiscoveryTokens, result.DiscoveryTokens) +}