diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 29731c2..bc8bc4e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -347,6 +347,58 @@ func TestLoad_ClaudeCodePath(t *testing.T) { assert.Equal(t, "/usr/local/bin/claude", cfg.ClaudeCodePath) } +// TestGet tests the global config getter. +func TestGet(t *testing.T) { + // Save and restore HOME + origHome := os.Getenv("HOME") + tempDir, err := os.MkdirTemp("", "config-get-test-*") + require.NoError(t, err) + defer func() { + os.Setenv("HOME", origHome) + os.RemoveAll(tempDir) + }() + os.Setenv("HOME", tempDir) + + // Create data dir + err = os.MkdirAll(filepath.Join(tempDir, ".claude-mnemonic"), 0750) + require.NoError(t, err) + + // Get() should return a valid config + cfg := Get() + require.NotNil(t, cfg) + assert.Greater(t, cfg.WorkerPort, 0) + assert.NotEmpty(t, cfg.Model) +} + +// TestGetWorkerPort_WithEnv tests GetWorkerPort with environment variable. +func TestGetWorkerPort_WithEnv(t *testing.T) { + // Save original env + origEnv := os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT") + defer os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", origEnv) + + // Test with valid port in env + os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "45678") + port := GetWorkerPort() + assert.Equal(t, 45678, port) + + // Test with invalid port (should fall back to config) + os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "not-a-number") + port = GetWorkerPort() + // Should return from Get().WorkerPort, which is default + assert.Greater(t, port, 0) + + // Test with zero port (should fall back to config) + os.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "0") + port = GetWorkerPort() + // Zero is invalid, so should use default + assert.Greater(t, port, 0) + + // Test with no env (should use config) + os.Unsetenv("CLAUDE_MNEMONIC_WORKER_PORT") + port = GetWorkerPort() + assert.Greater(t, port, 0) +} + // TestLoad_ContextSettings tests context-related settings loading. func TestLoad_ContextSettings(t *testing.T) { // Create temp dir diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 6afce82..ef094f4 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -957,3 +957,127 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) { assert.Equal(t, "search", params.Name) assert.NotEmpty(t, params.Arguments) } + +// TestCallTool_UnknownToolName tests callTool with various unknown tool names. +func TestCallTool_UnknownToolName(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + unknownTools := []string{ + "invalid_tool", + "nonexistent", + "search_v2", + "timeline_special", + } + + for _, name := range unknownTools { + t.Run(name, func(t *testing.T) { + result, err := server.callTool(ctx, name, json.RawMessage(`{}`)) + assert.Error(t, err) + assert.Empty(t, result) + assert.Contains(t, err.Error(), "unknown tool") + }) + } +} + +// TestTimelineParams_Validation tests TimelineParams struct field validation. +func TestTimelineParams_Validation(t *testing.T) { + tests := []struct { + name string + json string + wantOK bool + }{ + {"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true}, + {"valid with query only", `{"query":"test query"}`, true}, + {"empty params", `{}`, true}, + {"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true}, + {"invalid json", `{invalid`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var params TimelineParams + err := json.Unmarshal([]byte(tt.json), ¶ms) + if tt.wantOK { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + +// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error. +func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`{"name":"very_unknown_tool_name","arguments":{}}`), + } + + resp := server.handleToolsCall(ctx, req) + + // Should get an error response + assert.Equal(t, "2.0", resp.JSONRPC) + assert.Equal(t, 1, resp.ID) + require.NotNil(t, resp.Error) + // Error is "Tool error" with message containing "unknown tool" + assert.True(t, resp.Error.Code != 0) +} + +// TestHandleToolsCall_EmptyParams tests tools/call with empty params. +func TestHandleToolsCall_EmptyParams(t *testing.T) { + server := NewServer(nil, "1.0.0") + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`{}`), + } + + resp := server.handleToolsCall(ctx, req) + + // Should error due to missing name + require.NotNil(t, resp.Error) +} + +// TestSendResponse_WithError tests sendResponse with an error response. +func TestSendResponse_WithError(t *testing.T) { + var buf bytes.Buffer + server := &Server{stdout: &buf} + + resp := &Response{ + JSONRPC: "2.0", + ID: 1, + Error: &Error{Code: -32600, Message: "Invalid Request"}, + } + + server.sendResponse(resp) + + output := buf.String() + assert.Contains(t, output, `"error"`) + assert.Contains(t, output, `-32600`) +} + +// TestSendResponse_NilID tests sendResponse with nil ID. +func TestSendResponse_NilID(t *testing.T) { + var buf bytes.Buffer + server := &Server{stdout: &buf} + + resp := &Response{ + JSONRPC: "2.0", + ID: nil, + Result: "notification response", + } + + server.sendResponse(resp) + + output := buf.String() + assert.Contains(t, output, `"id":null`) +} diff --git a/internal/vector/sqlitevec/sync_test.go b/internal/vector/sqlitevec/sync_test.go new file mode 100644 index 0000000..5a9ae63 --- /dev/null +++ b/internal/vector/sqlitevec/sync_test.go @@ -0,0 +1,348 @@ +package sqlitevec + +import ( + "context" + "database/sql" + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testClient creates a test client for sync tests. +func testClient(t *testing.T) (*Client, func()) { + t.Helper() + + db, dbCleanup := testDB(t) + embedSvc, embedCleanup := testEmbeddingService(t) + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + cleanup := func() { + embedCleanup() + dbCleanup() + } + + return client, cleanup +} + +func TestNewSync(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + assert.NotNil(t, sync) +} + +func TestSync_SyncObservation_Empty(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // Observation with no content should be handled gracefully + obs := &models.Observation{ + ID: 1, + SDKSessionID: "test-session", + Project: "test-project", + Type: models.ObsTypeDiscovery, + } + + err := sync.SyncObservation(context.Background(), obs) + require.NoError(t, err) +} + +func TestSync_SyncObservation_WithContent(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + obs := &models.Observation{ + ID: 1, + SDKSessionID: "test-session", + Project: "test-project", + Type: models.ObsTypeDiscovery, + Scope: models.ScopeProject, + Title: sql.NullString{String: "Authentication bug fix", Valid: true}, + Subtitle: sql.NullString{String: "Fixed JWT validation", Valid: true}, + Narrative: sql.NullString{String: "Fixed the JWT token validation to handle expired tokens correctly.", Valid: true}, + Facts: []string{"JWT tokens expire after 24 hours", "Refresh tokens are used for renewal"}, + Concepts: []string{"authentication", "security"}, + FilesRead: []string{"auth.go"}, + FilesModified: []string{"handler.go"}, + } + + err := sync.SyncObservation(context.Background(), obs) + require.NoError(t, err) + + // Verify documents were added + results, err := client.Query(context.Background(), "authentication", 10, nil) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(results), 1) +} + +func TestSync_SyncObservation_DefaultScope(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // Observation without explicit scope + obs := &models.Observation{ + ID: 2, + SDKSessionID: "test-session", + Project: "test-project", + Type: models.ObsTypeBugfix, + Narrative: sql.NullString{String: "Fixed a null pointer exception.", Valid: true}, + } + + err := sync.SyncObservation(context.Background(), obs) + require.NoError(t, err) +} + +func TestSync_SyncSummary_Empty(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // Summary with no content + summary := &models.SessionSummary{ + ID: 1, + SDKSessionID: "test-session", + Project: "test-project", + } + + err := sync.SyncSummary(context.Background(), summary) + require.NoError(t, err) +} + +func TestSync_SyncSummary_WithContent(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + summary := &models.SessionSummary{ + ID: 1, + SDKSessionID: "test-session", + Project: "test-project", + Request: sql.NullString{String: "Help me fix the authentication bug", Valid: true}, + Investigated: sql.NullString{String: "Looked at auth.go and handler.go", Valid: true}, + Learned: sql.NullString{String: "JWT tokens were not being validated properly", Valid: true}, + Completed: sql.NullString{String: "Fixed the JWT validation logic", Valid: true}, + NextSteps: sql.NullString{String: "Add tests for edge cases", Valid: true}, + Notes: sql.NullString{String: "Consider using a library for JWT handling", Valid: true}, + PromptNumber: sql.NullInt64{Int64: 1, Valid: true}, + } + + err := sync.SyncSummary(context.Background(), summary) + require.NoError(t, err) + + // Verify documents were added + results, err := client.Query(context.Background(), "authentication", 10, nil) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(results), 1) +} + +func TestSync_SyncUserPrompt(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + prompt := &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 1, + PromptNumber: 1, + PromptText: "Help me fix the authentication bug in the login handler", + CreatedAtEpoch: 1234567890, + }, + SDKSessionID: "test-session", + Project: "test-project", + } + + err := sync.SyncUserPrompt(context.Background(), prompt) + require.NoError(t, err) + + // Verify document was added + results, err := client.Query(context.Background(), "authentication", 10, nil) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(results), 1) +} + +func TestSync_DeleteObservations_Empty(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // Should handle empty list + err := sync.DeleteObservations(context.Background(), []int64{}) + require.NoError(t, err) +} + +func TestSync_DeleteObservations_WithData(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // First add an observation + obs := &models.Observation{ + ID: 10, + SDKSessionID: "test-session", + Project: "test-project", + Type: models.ObsTypeDiscovery, + Narrative: sql.NullString{String: "This observation should be deleted.", Valid: true}, + Facts: []string{"Fact 1", "Fact 2"}, + } + + err := sync.SyncObservation(context.Background(), obs) + require.NoError(t, err) + + // Then delete it + err = sync.DeleteObservations(context.Background(), []int64{10}) + require.NoError(t, err) +} + +func TestSync_DeleteUserPrompts_Empty(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // Should handle empty list + err := sync.DeleteUserPrompts(context.Background(), []int64{}) + require.NoError(t, err) +} + +func TestSync_DeleteUserPrompts_WithData(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + // First add a prompt + prompt := &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: 20, + PromptNumber: 1, + PromptText: "This prompt should be deleted.", + CreatedAtEpoch: 1234567890, + }, + SDKSessionID: "test-session", + Project: "test-project", + } + + err := sync.SyncUserPrompt(context.Background(), prompt) + require.NoError(t, err) + + // Then delete it + err = sync.DeleteUserPrompts(context.Background(), []int64{20}) + require.NoError(t, err) +} + +func TestSync_FormatObservationDocs_AllFields(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + obs := &models.Observation{ + ID: 100, + SDKSessionID: "sdk-123", + Project: "my-project", + Type: models.ObsTypeFeature, + Scope: models.ScopeGlobal, + Title: sql.NullString{String: "Feature Title", Valid: true}, + Subtitle: sql.NullString{String: "Feature Subtitle", Valid: true}, + Narrative: sql.NullString{String: "Feature narrative content", Valid: true}, + Facts: []string{"Fact A", "Fact B", "Fact C"}, + Concepts: []string{"api", "performance"}, + FilesRead: []string{"file1.go", "file2.go"}, + FilesModified: []string{"file3.go"}, + } + + docs := sync.formatObservationDocs(obs) + + // Should have 1 narrative + 3 facts = 4 docs + assert.Len(t, docs, 4) + + // Check narrative doc + var narrativeDoc *Document + for i := range docs { + if docs[i].ID == "obs_100_narrative" { + narrativeDoc = &docs[i] + break + } + } + require.NotNil(t, narrativeDoc) + assert.Equal(t, "Feature narrative content", narrativeDoc.Content) + assert.Equal(t, int64(100), narrativeDoc.Metadata["sqlite_id"]) + assert.Equal(t, "observation", narrativeDoc.Metadata["doc_type"]) + assert.Equal(t, "global", narrativeDoc.Metadata["scope"]) + assert.Equal(t, "narrative", narrativeDoc.Metadata["field_type"]) +} + +func TestSync_FormatSummaryDocs_AllFields(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + summary := &models.SessionSummary{ + ID: 200, + SDKSessionID: "sdk-456", + Project: "summary-project", + Request: sql.NullString{String: "Request content", Valid: true}, + Investigated: sql.NullString{String: "Investigated content", Valid: true}, + Learned: sql.NullString{String: "Learned content", Valid: true}, + Completed: sql.NullString{String: "Completed content", Valid: true}, + NextSteps: sql.NullString{String: "Next steps content", Valid: true}, + Notes: sql.NullString{String: "Notes content", Valid: true}, + PromptNumber: sql.NullInt64{Int64: 5, Valid: true}, + } + + docs := sync.formatSummaryDocs(summary) + + // Should have 6 docs (one for each field) + assert.Len(t, docs, 6) + + // Check request doc + var requestDoc *Document + for i := range docs { + if docs[i].ID == "summary_200_request" { + requestDoc = &docs[i] + break + } + } + require.NotNil(t, requestDoc) + assert.Equal(t, "Request content", requestDoc.Content) + assert.Equal(t, int64(200), requestDoc.Metadata["sqlite_id"]) + assert.Equal(t, "session_summary", requestDoc.Metadata["doc_type"]) + assert.Equal(t, int64(5), requestDoc.Metadata["prompt_number"]) +} + +func TestSync_FormatObservationDocs_EmptyScope(t *testing.T) { + client, cleanup := testClient(t) + defer cleanup() + + sync := NewSync(client) + + obs := &models.Observation{ + ID: 300, + SDKSessionID: "sdk-789", + Project: "scope-test", + Type: models.ObsTypeDecision, + Narrative: sql.NullString{String: "Test narrative", Valid: true}, + // Scope intentionally left empty + } + + docs := sync.formatObservationDocs(obs) + assert.Len(t, docs, 1) + assert.Equal(t, "project", docs[0].Metadata["scope"]) +} diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go index f7f1b67..5a4349d 100644 --- a/internal/worker/handlers_test.go +++ b/internal/worker/handlers_test.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strconv" @@ -15,6 +16,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/lukaszraczylo/claude-mnemonic/internal/config" "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite" + "github.com/lukaszraczylo/claude-mnemonic/internal/update" "github.com/lukaszraczylo/claude-mnemonic/internal/worker/session" "github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" @@ -45,6 +47,9 @@ func testService(t *testing.T) (*Service, func()) { // Create router router := chi.NewRouter() + // Create test updater + testUpdater := update.New("test-version", t.TempDir()) + svc := &Service{ version: "test-version", config: config.Get(), @@ -55,6 +60,7 @@ func testService(t *testing.T) (*Service, func()) { promptStore: promptStore, sessionManager: sessionManager, sseBroadcaster: sseBroadcaster, + updater: testUpdater, router: router, ctx: ctx, cancel: cancel, @@ -2125,3 +2131,781 @@ func TestHandleContextInject_WithQuery(t *testing.T) { require.True(t, ok) assert.GreaterOrEqual(t, len(observations), 1) } + +// TestHandleUpdateCheck tests the update check endpoint. +func TestHandleUpdateCheck(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/update/check", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // May return 200 or 500 depending on network availability + assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError}, rec.Code) +} + +// TestHandleUpdateStatus tests the update status endpoint. +func TestHandleUpdateStatus(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/update/status", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var status update.UpdateStatus + err := json.Unmarshal(rec.Body.Bytes(), &status) + require.NoError(t, err) + + // Default state should be "idle" + assert.Equal(t, "idle", status.State) +} + +// TestHandleUpdateApply tests the update apply endpoint. +func TestHandleUpdateApply(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/update/apply", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Will return 200 with message or 500 if network check fails + assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError}, rec.Code) +} + +// TestHandleUpdateRestart_NoUpdate tests restart without applied update. +func TestHandleUpdateRestart_NoUpdate(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/update/restart", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should return error since no update has been applied + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +// TestHandleRestart tests the general restart endpoint. +func TestHandleRestart(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/restart", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should return OK with restart message + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, true, response["success"]) + assert.Equal(t, "Restarting worker...", response["message"]) +} + +// TestSetInitError tests setting init error. +func TestSetInitError(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Initially no error + assert.Nil(t, svc.GetInitError()) + + // Set an error + testErr := assert.AnError + svc.setInitError(testErr) + + // Should be set + assert.Equal(t, testErr, svc.GetInitError()) +} + +// TestQueueStaleVerification tests the stale verification queue. +func TestQueueStaleVerification(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Queue some verifications - should not panic + svc.queueStaleVerification(1, "/test/path") + svc.queueStaleVerification(2, "/test/path2") + + // Give the goroutine a moment to start + time.Sleep(10 * time.Millisecond) +} + +// TestRecordRetrievalStats tests retrieval stats recording. +func TestRecordRetrievalStats(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Initially all zeros + stats := svc.GetRetrievalStats() + assert.Equal(t, int64(0), stats.TotalRequests) + + // Record a search request + svc.recordRetrievalStats(10, 1, 0, true) + stats = svc.GetRetrievalStats() + assert.Equal(t, int64(1), stats.TotalRequests) + assert.Equal(t, int64(10), stats.ObservationsServed) + assert.Equal(t, int64(1), stats.VerifiedStale) + assert.Equal(t, int64(1), stats.SearchRequests) + + // Record a context injection + svc.recordRetrievalStats(5, 0, 1, false) + stats = svc.GetRetrievalStats() + assert.Equal(t, int64(2), stats.TotalRequests) + assert.Equal(t, int64(15), stats.ObservationsServed) + assert.Equal(t, int64(1), stats.DeletedInvalid) + assert.Equal(t, int64(1), stats.ContextInjections) +} + +// TestHandleSelfCheck_WithInitError tests self-check with init error. +func TestHandleSelfCheck_WithInitError(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(false) + svc.setInitError(assert.AnError) + + req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil) + rec := httptest.NewRecorder() + + svc.handleSelfCheck(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SelfCheckResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, "unhealthy", response.Overall) +} + +// TestHandleHealthEndpoint tests health endpoint via router. +func TestHandleHealthEndpoint(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSelfCheckEndpoint tests self-check endpoint via router. +func TestHandleSelfCheckEndpoint(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/selfcheck", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestUpdateInfo_Fields tests UpdateInfo struct. +func TestUpdateInfo_Fields(t *testing.T) { + info := update.UpdateInfo{ + Available: true, + CurrentVersion: "v1.0.0", + LatestVersion: "v2.0.0", + ReleaseNotes: "Bug fixes", + } + + assert.True(t, info.Available) + assert.Equal(t, "v1.0.0", info.CurrentVersion) + assert.Equal(t, "v2.0.0", info.LatestVersion) + assert.Equal(t, "Bug fixes", info.ReleaseNotes) +} + +// TestUpdateStatus_Fields tests UpdateStatus struct. +func TestUpdateStatus_Fields(t *testing.T) { + status := update.UpdateStatus{ + State: "downloading", + Progress: 50.0, + Message: "Downloading update...", + } + + assert.Equal(t, "downloading", status.State) + assert.Equal(t, 50.0, status.Progress) + assert.Equal(t, "Downloading update...", status.Message) +} + +// TestHandleObservation_MissingFields tests observation with partial fields. +// The handler accepts requests even with missing fields since it creates sessions on-the-fly. +func TestHandleObservation_MissingFields(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + body := `{"project": "test"}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Handler accepts requests with partial fields and creates sessions on-the-fly + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetSummaries_WithProject tests getting summaries filtered by project. +func TestHandleGetSummaries_WithProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test-project&limit=10", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetPrompts_WithProject tests getting prompts filtered by project. +func TestHandleGetPrompts_WithProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test-project&limit=10", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSearchByPrompt_WithCWD tests search with cwd parameter. +func TestHandleSearchByPrompt_WithCWD(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "search-cwd-test" + createTestObservation(t, svc.observationStore, project, "Test observation", "Content here", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&cwd=/tmp", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleContextInject_WithLimitAndCWD tests context inject with both limit and cwd. +func TestHandleContextInject_WithLimitAndCWD(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "inject-cwd-test" + createTestObservation(t, svc.observationStore, project, "Test observation", "Content here", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5&cwd=/tmp", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSelfCheck_AllComponents tests self-check response structure. +func TestHandleSelfCheck_AllComponents(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/selfcheck", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response SelfCheckResponse + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all expected components are present + componentNames := make(map[string]bool) + for _, comp := range response.Components { + componentNames[comp.Name] = true + } + + expectedComponents := []string{"Worker Service", "SQLite Database"} + for _, name := range expectedComponents { + assert.True(t, componentNames[name], "missing component: %s", name) + } +} + +// TestHandleReady_NotReady tests ready endpoint when not ready. +func TestHandleReady_NotReady(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + svc.ready.Store(false) + + req := httptest.NewRequest(http.MethodGet, "/api/ready", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// TestHandleReady_Ready tests ready endpoint when ready. +func TestHandleReady_Ready(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/ready", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestRequireReady_Middleware tests the requireReady middleware. +func TestRequireReady_Middleware(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Service is ready by default in tests + req := httptest.NewRequest(http.MethodGet, "/api/observations", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should get OK because service is ready + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleObservation_ValidRequest tests observation with valid request. +func TestHandleObservation_ValidRequest(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // First create a session + sessionBody := `{"claude_session_id": "obs-test-session", "project": "obs-test-project"}` + sessionReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(sessionBody)) + sessionReq.Header.Set("Content-Type", "application/json") + sessionRec := httptest.NewRecorder() + svc.router.ServeHTTP(sessionRec, sessionReq) + + // Now create an observation + body := `{ + "claude_session_id": "obs-test-session", + "project": "obs-test-project", + "tool_name": "Write", + "tool_input": {"file_path": "/test/file.go"}, + "tool_response": {"success": true}, + "prompt_number": 1, + "cwd": "/test" + }` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should accept the observation (may queue for processing) + assert.Contains(t, []int{http.StatusOK, http.StatusAccepted}, rec.Code) +} + +// TestHandleSubagentComplete_ValidRequest tests subagent complete with valid request. +func TestHandleSubagentComplete_ValidRequest(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // First create a session + sessionBody := `{"claude_session_id": "subagent-test-session", "project": "subagent-test-project"}` + sessionReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(sessionBody)) + sessionReq.Header.Set("Content-Type", "application/json") + sessionRec := httptest.NewRecorder() + svc.router.ServeHTTP(sessionRec, sessionReq) + + // Now complete a subagent + body := `{ + "claude_session_id": "subagent-test-session", + "project": "subagent-test-project", + "tool_name": "Task", + "tool_input": {"description": "test task"}, + "tool_response": {"result": "completed"}, + "prompt_number": 1, + "cwd": "/test" + }` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should accept the request + assert.Contains(t, []int{http.StatusOK, http.StatusAccepted}, rec.Code) +} + +// TestClusterObservations tests the observation clustering function. +func TestClusterObservations(t *testing.T) { + // Create some observations + obs := []*models.Observation{ + {ID: 1, Title: sql.NullString{String: "Test 1", Valid: true}}, + {ID: 2, Title: sql.NullString{String: "Test 2", Valid: true}}, + {ID: 3, Title: sql.NullString{String: "Test 3", Valid: true}}, + } + + // Cluster with default threshold + clustered := clusterObservations(obs, 0.4) + + // Should return at least one observation + assert.NotEmpty(t, clustered) + assert.LessOrEqual(t, len(clustered), len(obs)) +} + +// TestClusterObservations_EmptyInput tests clustering with empty input. +func TestClusterObservations_EmptyInput(t *testing.T) { + clustered := clusterObservations([]*models.Observation{}, 0.4) + assert.Empty(t, clustered) +} + +// TestClusterObservations_NilInput tests clustering with nil input. +func TestClusterObservations_NilInput(t *testing.T) { + clustered := clusterObservations(nil, 0.4) + assert.Empty(t, clustered) +} + +// TestComponentHealth tests ComponentHealth struct. +func TestComponentHealth(t *testing.T) { + comp := ComponentHealth{ + Name: "Test Component", + Status: "healthy", + Message: "All good", + } + + assert.Equal(t, "Test Component", comp.Name) + assert.Equal(t, "healthy", comp.Status) + assert.Equal(t, "All good", comp.Message) +} + +// TestHandleGetObservations_EmptyResult tests get observations with no data. +func TestHandleGetObservations_EmptyResult(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/observations?project=nonexistent-project", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + // Should return empty array, not null + var obs []interface{} + err := json.Unmarshal(rec.Body.Bytes(), &obs) + require.NoError(t, err) + assert.NotNil(t, obs) +} + +// TestHandleGetSummaries_EmptyResult tests get summaries with no data. +func TestHandleGetSummaries_EmptyResult(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=nonexistent-project", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetPrompts_EmptyResult tests get prompts with no data. +func TestHandleGetPrompts_EmptyResult(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=nonexistent-project", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleContextInject_MissingProject tests context inject without project parameter. +func TestHandleContextInject_MissingProject(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "project required") +} + +// TestHandleContextInject_DefaultCwd tests context inject with default cwd. +func TestHandleContextInject_DefaultCwd(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "inject-default-cwd-test" + createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"}) + + // Request without cwd parameter - should use default "/" + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleContextInject_ConfigLimits tests context inject respects config limits. +func TestHandleContextInject_ConfigLimits(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Set custom config values + svc.config.ContextObservations = 5 + svc.config.ContextFullCount = 2 + + project := "inject-config-test" + for i := 0; i < 10; i++ { + createTestObservation(t, svc.observationStore, project, fmt.Sprintf("Test %d", i), "Content", []string{"test"}) + } + + req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, float64(2), response["full_count"]) +} + +// TestHandleUpdateApply_NoUpdateAvailable tests update apply when no update is available. +func TestHandleUpdateApply_NoUpdateAvailable(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodPost, "/api/update/apply", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + // Update check may succeed or fail - both are valid behaviors + assert.NotNil(t, response) +} + +// TestHandleGetObservations_WithQuery tests observations with query parameter. +func TestHandleGetObservations_WithQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "obs-query-test" + createTestObservation(t, svc.observationStore, project, "Unique Test Title", "Content about testing", []string{"test"}) + + req := httptest.NewRequest(http.MethodGet, "/api/observations?project="+project+"&query=unique", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetSummaries_WithQuery tests summaries with query parameter. +func TestHandleGetSummaries_WithQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test-project&query=test", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetPrompts_WithQuery tests prompts with query parameter. +func TestHandleGetPrompts_WithQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test-project&query=test", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSearchByPrompt_SpecialCharsQuery tests search with special characters to trigger FTS fallback. +func TestHandleSearchByPrompt_SpecialCharsQuery(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + project := "special-chars-test" + createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"}) + + // Use special characters that might cause FTS issues + req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test()+special*", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // Should still return OK (falls back to recent observations) + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleHealth tests the health endpoint. +func TestHandleHealth(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/health", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "ready", response["status"]) +} + +// TestHandleSessionInit_ValidRequest tests session init with valid request. +func TestHandleSessionInit_ValidRequest(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + body := `{"claude_session_id": "init-test-session", "project": "init-test-project", "prompt": "test prompt"}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + // Response uses camelCase: sessionDbId + assert.Contains(t, response, "sessionDbId") +} + +// TestHandleSummarize tests the summarize endpoint. +func TestHandleSummarize(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // First create a session with observations + body := `{"claude_session_id": "summarize-test-session", "project": "summarize-test"}` + initReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(body)) + initReq.Header.Set("Content-Type", "application/json") + initRec := httptest.NewRecorder() + svc.router.ServeHTTP(initRec, initReq) + + // Now request summarization - using the session ID path + req := httptest.NewRequest(http.MethodPost, "/api/sessions/summarize-test-session/summarize", nil) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // May return OK, 404 (not found), or error depending on endpoint and session state + assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusInternalServerError}, rec.Code) +} + +// TestHandleGetObservations_AllProjects tests getting all observations without project filter. +func TestHandleGetObservations_AllProjects(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // Create observations in different projects + createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"}) + createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"}) + + // Request without project filter + req := httptest.NewRequest(http.MethodGet, "/api/observations", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetSummaries_AllProjects tests getting all summaries without project filter. +func TestHandleGetSummaries_AllProjects(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleGetPrompts_AllProjects tests getting all prompts without project filter. +func TestHandleGetPrompts_AllProjects(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil) + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// TestHandleSessionStart tests the session start endpoint. +func TestHandleSessionStart(t *testing.T) { + svc, cleanup := testService(t) + defer cleanup() + + // First create a session + initBody := `{"claude_session_id": "start-test-session", "project": "start-test"}` + initReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(initBody)) + initReq.Header.Set("Content-Type", "application/json") + initRec := httptest.NewRecorder() + svc.router.ServeHTTP(initRec, initReq) + + // Now start the session + body := `{"sessionId": "start-test-session"}` + req := httptest.NewRequest(http.MethodPost, "/api/sessions/start", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.router.ServeHTTP(rec, req) + + // May return various status codes depending on session state and endpoint + assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest, http.StatusNotFound, http.StatusInternalServerError}, rec.Code) +} diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go index de0c2b5..f18e6c5 100644 --- a/internal/worker/sdk/processor_test.go +++ b/internal/worker/sdk/processor_test.go @@ -505,3 +505,472 @@ func TestObservationConcepts(t *testing.T) { } assert.Equal(t, expectedConcepts, ObservationConcepts) } + +// TestProcessorStruct tests processor struct initialization and methods. +func TestProcessorStruct(t *testing.T) { + p := &Processor{ + claudePath: "/path/to/claude", + model: "haiku", + sem: make(chan struct{}, MaxConcurrentCLICalls), + } + + assert.Equal(t, "/path/to/claude", p.claudePath) + assert.Equal(t, "haiku", p.model) + assert.NotNil(t, p.sem) +} + +// TestSetBroadcastFunc tests the broadcast callback setter. +func TestSetBroadcastFunc(t *testing.T) { + p := &Processor{} + + assert.Nil(t, p.broadcastFunc) + + var called bool + var receivedEvent map[string]interface{} + fn := func(event map[string]interface{}) { + called = true + receivedEvent = event + } + + p.SetBroadcastFunc(fn) + assert.NotNil(t, p.broadcastFunc) + + // Test broadcast + p.broadcast(map[string]interface{}{"type": "test"}) + assert.True(t, called) + assert.Equal(t, "test", receivedEvent["type"]) +} + +// TestSetSyncObservationFunc tests the sync observation callback setter. +func TestSetSyncObservationFunc(t *testing.T) { + p := &Processor{} + + assert.Nil(t, p.syncObservationFunc) + + var called bool + fn := func(obs *models.Observation) { + called = true + } + + p.SetSyncObservationFunc(fn) + assert.NotNil(t, p.syncObservationFunc) + + // Verify it was set + p.syncObservationFunc(&models.Observation{}) + assert.True(t, called) +} + +// TestSetSyncSummaryFunc tests the sync summary callback setter. +func TestSetSyncSummaryFunc(t *testing.T) { + p := &Processor{} + + assert.Nil(t, p.syncSummaryFunc) + + var called bool + fn := func(summary *models.SessionSummary) { + called = true + } + + p.SetSyncSummaryFunc(fn) + assert.NotNil(t, p.syncSummaryFunc) + + // Verify it was set + p.syncSummaryFunc(&models.SessionSummary{}) + assert.True(t, called) +} + +// TestBroadcast_NilFunc tests broadcast with nil callback. +func TestBroadcast_NilFunc(t *testing.T) { + p := &Processor{} + + // Should not panic + p.broadcast(map[string]interface{}{"type": "test"}) +} + +// TestIsAvailable_NonexistentPath tests IsAvailable with non-existent path. +func TestIsAvailable_NonexistentPath(t *testing.T) { + p := &Processor{ + claudePath: "/nonexistent/path/to/claude", + } + + assert.False(t, p.IsAvailable()) +} + +// TestIsAvailable_ExistingPath tests IsAvailable with existing path. +func TestIsAvailable_ExistingPath(t *testing.T) { + // Create a temp file to simulate claude binary + tmpFile, err := os.CreateTemp("", "claude-test-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + tmpFile.Close() + + p := &Processor{ + claudePath: tmpFile.Name(), + } + + assert.True(t, p.IsAvailable()) +} + +// TestShouldSkipTrivialOperation_EdgeCases tests edge cases for trivial operation detection. +func TestShouldSkipTrivialOperation_EdgeCases(t *testing.T) { + tests := []struct { + name string + toolName string + inputStr string + outputStr string + expected bool + }{ + { + name: "gitignore_file", + toolName: "Read", + inputStr: `{"file_path": "/project/.gitignore"}`, + outputStr: "This is a gitignore file content that has more than 50 characters long", + expected: true, + }, + { + name: "eslintignore_file", + toolName: "Read", + inputStr: `{"file_path": "/project/.eslintignore"}`, + outputStr: "This is an eslintignore file content that has more than 50 characters long", + expected: true, + }, + { + name: "tsconfig_file", + toolName: "Read", + inputStr: `{"file_path": "/project/tsconfig.json"}`, + outputStr: "This is a tsconfig.json file content that has more than 50 characters long", + expected: true, + }, + { + name: "tailwind_config", + toolName: "Read", + inputStr: `{"file_path": "/project/tailwind.config.js"}`, + outputStr: "This is a tailwind.config file content that has more than 50 characters long", + expected: true, + }, + { + name: "pwd_command", + toolName: "Bash", + inputStr: `{"command": "pwd"}`, + outputStr: "/home/user/project/some/long/path/that/is/more/than/fifty/chars", + expected: true, + }, + { + name: "echo_command", + toolName: "Bash", + inputStr: `{"command": "echo Hello World"}`, + outputStr: "Hello World output that is long enough to pass the length check here", + expected: true, + }, + { + name: "npm_audit_command", + toolName: "Bash", + inputStr: `{"command": "npm audit"}`, + outputStr: "found 0 vulnerabilities in 500 packages which is more than fifty characters", + expected: true, + }, + { + name: "permission_denied", + toolName: "Read", + inputStr: `{"file_path": "/root/secret"}`, + outputStr: "Error: Permission denied accessing the file at specified path", + expected: true, + }, + { + name: "is_a_directory", + toolName: "Read", + inputStr: `{"file_path": "/some/dir"}`, + outputStr: "Error: /some/dir is a directory, not a file that can be read", + expected: true, + }, + { + name: "empty_object", + toolName: "Grep", + inputStr: `{"pattern": "nonexistent"}`, + outputStr: "{}", + expected: true, + }, + { + name: "valid_grep_result", + toolName: "Grep", + inputStr: `{"pattern": "func main"}`, + outputStr: "main.go:10:func main() {\nmain.go:11: fmt.Println(\"Hello\")\n}", + expected: false, + }, + { + name: "valid_bash_build", + toolName: "Bash", + inputStr: `{"command": "go build ./..."}`, + outputStr: "Build completed successfully. Binary output at ./bin/myapp with size 10MB.", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestIsSelfReferentialSummary_MoreCases tests additional self-referential detection cases. +func TestIsSelfReferentialSummary_MoreCases(t *testing.T) { + tests := []struct { + name string + summary *models.ParsedSummary + expected bool + }{ + { + name: "progress_checkpoint", + summary: &models.ParsedSummary{ + Request: "Progress checkpoint for current session", + Completed: "Responding to progress checkpoint request", + Learned: "No technical learnings yet", + }, + expected: true, + }, + { + name: "empty_session", + summary: &models.ParsedSummary{ + Request: "Empty session", + Completed: "Just beginning the session", + Learned: "Nothing has been completed yet", + }, + expected: true, + }, + { + name: "hook_mechanism", + summary: &models.ParsedSummary{ + Request: "Hook execution for session start", + Completed: "Hook mechanism triggered successfully", + Learned: "System hooks are working", + }, + expected: true, + }, + { + name: "api_implementation", + summary: &models.ParsedSummary{ + Request: "Implement REST API endpoints", + Completed: "Created /users and /posts endpoints with CRUD operations", + Learned: "chi router handles middleware chaining elegantly", + NextSteps: "Add authentication middleware", + }, + expected: false, + }, + { + name: "database_migration", + summary: &models.ParsedSummary{ + Request: "Add database migration for new user fields", + Completed: "Created migration 003_add_user_profile.sql with new columns", + Learned: "SQLite ALTER TABLE has limited capabilities, need to recreate table", + NextSteps: "Test migration rollback", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isSelfReferentialSummary(tt.summary) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestHasMeaningfulContent_MoreCases tests additional meaningful content detection. +func TestHasMeaningfulContent_MoreCases(t *testing.T) { + tests := []struct { + name string + content string + expected bool + }{ + { + name: "code_with_functions", + content: `I've created a new handler function in handlers.go. +The function validateRequest() checks the incoming JSON payload. +Here's the implementation: +` + "```go\nfunc validateRequest(r *http.Request) error {\n\treturn nil\n}\n```", + expected: true, + }, + { + name: "python_code_discussion", + content: `Updated the data processing module in processor.py. +Changed the filter function to use list comprehension. +def process_data(items): + return [item for item in items if item.valid] +This improved performance by 30%.`, + expected: true, + }, + { + name: "typescript_changes", + content: `I've modified the React component in UserProfile.tsx. +Added a new functional component with proper TypeScript type annotations. +Here's the updated implementation: +` + "```tsx\nconst UserProfile: FC = ({ user }) => {\n return
{user.name}
;\n};\n```" + ` +The type annotations ensure type safety across the application. +The component has been updated with proper error handling and loading states.`, + expected: true, + }, + { + name: "yaml_config_update", + content: `I've updated the kubernetes deployment config in deploy.yaml. +Changed replicas from 2 to 4 and added resource limits for memory and CPU. +The deployment.yaml file now includes the following struct configuration: +` + "```yaml\nreplicas: 4\nresources:\n limits:\n memory: 512Mi\n```" + ` +The changes have been implemented and will be applied on next deploy.`, + expected: true, + }, + { + name: "just_system_messages", + content: `SessionStart:Callback hook success +System-reminder about tools +The session is starting +Waiting for user instructions`, + expected: false, + }, + { + name: "borderline_short", + content: "Fixed bug. Updated file. Added test. Committed changes to repository.", + expected: false, // Too short (< 200 chars) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasMeaningfulContent(tt.content) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestToJSONString_ComplexTypes tests JSON conversion for complex types. +func TestToJSONString_ComplexTypes(t *testing.T) { + tests := []struct { + name string + input interface{} + contains string + }{ + { + name: "nested_map", + input: map[string]interface{}{ + "outer": map[string]string{"inner": "value"}, + }, + contains: "inner", + }, + { + name: "bool_true", + input: true, + contains: "true", + }, + { + name: "bool_false", + input: false, + contains: "false", + }, + { + name: "float_value", + input: 3.14, + contains: "3.14", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := toJSONString(tt.input) + assert.Contains(t, result, tt.contains) + }) + } +} + +// TestSystemPrompt tests that the system prompt is defined. +func TestSystemPrompt(t *testing.T) { + assert.NotEmpty(t, systemPrompt) + assert.Contains(t, systemPrompt, "memory extraction agent") + assert.Contains(t, systemPrompt, "observation") + assert.Contains(t, systemPrompt, "GUIDELINES") +} + +// TestProcessorSemaphore tests the semaphore behavior. +func TestProcessorSemaphore(t *testing.T) { + p := &Processor{ + sem: make(chan struct{}, 2), + } + + // Acquire 2 slots + p.sem <- struct{}{} + p.sem <- struct{}{} + + // Third should block (we can test with select) + select { + case p.sem <- struct{}{}: + t.Error("Semaphore should be full") + default: + // Expected - semaphore is full + } + + // Release one + <-p.sem + + // Now should be able to acquire + select { + case p.sem <- struct{}{}: + // Expected + default: + t.Error("Should be able to acquire after release") + } +} + +// TestCaptureFileMtimes_DuplicatePaths tests mtime capture with overlapping paths. +func TestCaptureFileMtimes_DuplicatePaths(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mtime-dup-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + testFile := filepath.Join(tmpDir, "shared.txt") + err = os.WriteFile(testFile, []byte("content"), 0644) + if err != nil { + t.Fatal(err) + } + + // Same file in both read and modified lists + mtimes := captureFileMtimes([]string{testFile}, []string{testFile}, "") + + // Should only have one entry (no duplicates) + assert.Len(t, mtimes, 1) + assert.Contains(t, mtimes, testFile) +} + +// TestTruncateForLog_ZeroLength tests truncation with zero length. +func TestTruncateForLog_ZeroLength(t *testing.T) { + result := truncateForLog("hello", 0) + assert.Equal(t, "...", result) +} + +// TestBroadcastFuncType tests the BroadcastFunc type. +func TestBroadcastFuncType(t *testing.T) { + var fn BroadcastFunc = func(event map[string]interface{}) { + // Do nothing + } + assert.NotNil(t, fn) +} + +// TestSyncObservationFuncType tests the SyncObservationFunc type. +func TestSyncObservationFuncType(t *testing.T) { + var fn SyncObservationFunc = func(obs *models.Observation) { + // Do nothing + } + assert.NotNil(t, fn) +} + +// TestSyncSummaryFuncType tests the SyncSummaryFunc type. +func TestSyncSummaryFuncType(t *testing.T) { + var fn SyncSummaryFunc = func(summary *models.SessionSummary) { + // Do nothing + } + assert.NotNil(t, fn) +} diff --git a/internal/worker/session/integration_test.go b/internal/worker/session/integration_test.go new file mode 100644 index 0000000..6f5a2b8 --- /dev/null +++ b/internal/worker/session/integration_test.go @@ -0,0 +1,615 @@ +// Package session provides session lifecycle management for claude-mnemonic. +package session + +import ( + "context" + "os" + "testing" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + // Import sqlite driver + _ "github.com/mattn/go-sqlite3" +) + +// hasFTS5 checks if FTS5 is available in the SQLite build. +func hasFTS5(t *testing.T) bool { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "fts5-check-*") + if err != nil { + return false + } + defer func() { _ = os.RemoveAll(tmpDir) }() + + store, err := sqlite.NewStore(sqlite.StoreConfig{ + Path: tmpDir + "/check.db", + MaxConns: 1, + WALMode: true, + }) + if err != nil { + return false + } + _ = store.Close() + return true +} + +// testStore creates a sqlite.Store with a temporary database for testing. +func testStore(t *testing.T) (*sqlite.Store, func()) { + t.Helper() + + if !hasFTS5(t) { + t.Skip("FTS5 not available in this SQLite build") + } + + tmpDir, err := os.MkdirTemp("", "session-integration-test-*") + require.NoError(t, err) + + dbPath := tmpDir + "/test.db" + + store, err := sqlite.NewStore(sqlite.StoreConfig{ + Path: dbPath, + MaxConns: 1, + WALMode: true, + }) + require.NoError(t, err) + + cleanup := func() { + _ = store.Close() + _ = os.RemoveAll(tmpDir) + } + + return store, cleanup +} + +// SessionIntegrationSuite tests session manager with real SQLite stores. +type SessionIntegrationSuite struct { + suite.Suite + store *sqlite.Store + sessionStore *sqlite.SessionStore + cleanup func() + manager *Manager +} + +func (s *SessionIntegrationSuite) SetupTest() { + if !hasFTS5(s.T()) { + s.T().Skip("FTS5 not available in this SQLite build") + } + + s.store, s.cleanup = testStore(s.T()) + s.sessionStore = sqlite.NewSessionStore(s.store) + s.manager = NewManager(s.sessionStore) +} + +func (s *SessionIntegrationSuite) TearDownTest() { + if s.manager != nil { + s.manager.ShutdownAll(context.Background()) + } + if s.cleanup != nil { + s.cleanup() + } +} + +func TestSessionIntegrationSuite(t *testing.T) { + suite.Run(t, new(SessionIntegrationSuite)) +} + +// TestNewManager_WithRealStore tests manager creation with real store. +func (s *SessionIntegrationSuite) TestNewManager_WithRealStore() { + s.NotNil(s.manager) + s.NotNil(s.manager.sessionStore) + s.NotNil(s.manager.sessions) + s.NotNil(s.manager.ProcessNotify) + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestInitializeSession_WithRealStore tests session initialization. +func (s *SessionIntegrationSuite) TestInitializeSession_WithRealStore() { + ctx := context.Background() + + // Create a session in the database first + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-test-123", "test-project", "initial prompt") + s.Require().NoError(err) + s.Require().Greater(sessionID, int64(0)) + + // Initialize in manager + session, err := s.manager.InitializeSession(ctx, sessionID, "user prompt", 1) + s.Require().NoError(err) + s.Require().NotNil(session) + + // Verify session properties + s.Equal(sessionID, session.SessionDBID) + s.Equal("claude-test-123", session.ClaudeSessionID) + s.Equal("test-project", session.Project) + s.Equal("user prompt", session.UserPrompt) + s.Equal(1, session.LastPromptNumber) + + // Verify manager state + s.Equal(1, s.manager.GetActiveSessionCount()) +} + +// TestInitializeSession_ReuseExisting tests that existing sessions are reused. +func (s *SessionIntegrationSuite) TestInitializeSession_ReuseExisting() { + ctx := context.Background() + + // Create session in database + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-reuse-123", "test-project", "prompt") + s.Require().NoError(err) + + // Initialize first time + session1, err := s.manager.InitializeSession(ctx, sessionID, "prompt 1", 1) + s.Require().NoError(err) + s.Require().NotNil(session1) + + // Initialize second time - should reuse + session2, err := s.manager.InitializeSession(ctx, sessionID, "prompt 2", 2) + s.Require().NoError(err) + s.Require().NotNil(session2) + + // Should be the same session pointer + s.Same(session1, session2) + + // Should have updated user prompt + s.Equal("prompt 2", session2.UserPrompt) + s.Equal(2, session2.LastPromptNumber) + + // Still only 1 active session + s.Equal(1, s.manager.GetActiveSessionCount()) +} + +// TestInitializeSession_NonExistentSession tests initializing non-existent session. +func (s *SessionIntegrationSuite) TestInitializeSession_NonExistentSession() { + ctx := context.Background() + + // Try to initialize non-existent session + session, err := s.manager.InitializeSession(ctx, 999999, "prompt", 1) + s.NoError(err) // No error, just nil session + s.Nil(session) + + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestInitializeSession_EmptyUserPrompt tests initialization with empty user prompt. +func (s *SessionIntegrationSuite) TestInitializeSession_EmptyUserPrompt() { + ctx := context.Background() + + // Create session with initial prompt in database + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-empty-prompt", "test-project", "db prompt") + s.Require().NoError(err) + + // Initialize with empty user prompt - should use database prompt + session, err := s.manager.InitializeSession(ctx, sessionID, "", 0) + s.Require().NoError(err) + s.Require().NotNil(session) + + // Should use database prompt + s.Equal("db prompt", session.UserPrompt) +} + +// TestQueueObservation_WithRealStore tests observation queuing. +func (s *SessionIntegrationSuite) TestQueueObservation_WithRealStore() { + ctx := context.Background() + + // Create and initialize session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-queue-obs", "test-project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Queue an observation + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ + ToolName: "Read", + ToolInput: map[string]string{"path": "/test.go"}, + ToolResponse: "file content", + PromptNumber: 1, + CWD: "/project", + }) + s.Require().NoError(err) + + // Check queue depth + s.Equal(1, s.manager.GetTotalQueueDepth()) + s.True(s.manager.IsAnySessionProcessing()) + + // Drain messages + messages := s.manager.DrainMessages(sessionID) + s.Len(messages, 1) + s.Equal(MessageTypeObservation, messages[0].Type) + s.Equal("Read", messages[0].Observation.ToolName) +} + +// TestQueueObservation_AutoInitialize tests auto-initialization on queue. +func (s *SessionIntegrationSuite) TestQueueObservation_AutoInitialize() { + ctx := context.Background() + + // Create session in database but don't initialize in manager + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-auto-init", "test-project", "prompt") + s.Require().NoError(err) + + // Queue observation without explicit initialization + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ + ToolName: "Write", + ToolInput: "test input", + ToolResponse: "success", + PromptNumber: 1, + }) + s.Require().NoError(err) + + // Session should be auto-initialized + s.Equal(1, s.manager.GetActiveSessionCount()) + s.Equal(1, s.manager.GetTotalQueueDepth()) +} + +// TestQueueObservation_NonExistentSession tests queuing to non-existent session. +func (s *SessionIntegrationSuite) TestQueueObservation_NonExistentSession() { + ctx := context.Background() + + // Try to queue to non-existent session + err := s.manager.QueueObservation(ctx, 999999, ObservationData{ + ToolName: "Test", + }) + + // Should not error, but session won't be created + s.NoError(err) + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestQueueSummarize_WithRealStore tests summarize queuing. +func (s *SessionIntegrationSuite) TestQueueSummarize_WithRealStore() { + ctx := context.Background() + + // Create and initialize session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-queue-sum", "test-project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Queue a summarize request + err = s.manager.QueueSummarize(ctx, sessionID, "What did you do?", "I completed the task.") + s.Require().NoError(err) + + // Check queue depth + s.Equal(1, s.manager.GetTotalQueueDepth()) + s.True(s.manager.IsAnySessionProcessing()) + + // Drain messages + messages := s.manager.DrainMessages(sessionID) + s.Len(messages, 1) + s.Equal(MessageTypeSummarize, messages[0].Type) + s.Equal("What did you do?", messages[0].Summarize.LastUserMessage) + s.Equal("I completed the task.", messages[0].Summarize.LastAssistantMessage) +} + +// TestQueueSummarize_AutoInitialize tests auto-initialization on summarize queue. +func (s *SessionIntegrationSuite) TestQueueSummarize_AutoInitialize() { + ctx := context.Background() + + // Create session in database but don't initialize in manager + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-sum-auto", "test-project", "prompt") + s.Require().NoError(err) + + // Queue summarize without explicit initialization + err = s.manager.QueueSummarize(ctx, sessionID, "user msg", "assistant msg") + s.Require().NoError(err) + + // Session should be auto-initialized + s.Equal(1, s.manager.GetActiveSessionCount()) + s.Equal(1, s.manager.GetTotalQueueDepth()) +} + +// TestQueueSummarize_NonExistentSession tests summarize queuing to non-existent session. +func (s *SessionIntegrationSuite) TestQueueSummarize_NonExistentSession() { + ctx := context.Background() + + // Try to queue to non-existent session + err := s.manager.QueueSummarize(ctx, 999999, "user", "assistant") + + // Should not error, but session won't be created + s.NoError(err) + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestMixedQueueOperations tests mixed observation and summarize queuing. +func (s *SessionIntegrationSuite) TestMixedQueueOperations() { + ctx := context.Background() + + // Create and initialize session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-mixed", "test-project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Queue multiple messages of different types + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Tool1"}) + s.Require().NoError(err) + + err = s.manager.QueueSummarize(ctx, sessionID, "user1", "assistant1") + s.Require().NoError(err) + + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Tool2"}) + s.Require().NoError(err) + + // Check total queue depth + s.Equal(3, s.manager.GetTotalQueueDepth()) + + // Drain and verify order + messages := s.manager.DrainMessages(sessionID) + s.Len(messages, 3) + s.Equal(MessageTypeObservation, messages[0].Type) + s.Equal("Tool1", messages[0].Observation.ToolName) + s.Equal(MessageTypeSummarize, messages[1].Type) + s.Equal(MessageTypeObservation, messages[2].Type) + s.Equal("Tool2", messages[2].Observation.ToolName) +} + +// TestProcessNotifyChannel tests the process notification channel behavior. +func (s *SessionIntegrationSuite) TestProcessNotifyChannel() { + ctx := context.Background() + + // Create and initialize session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-notify", "test-project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Drain any existing notifications + select { + case <-s.manager.ProcessNotify: + default: + } + + // Queue observation - should trigger notification + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Test"}) + s.Require().NoError(err) + + // Should be able to receive notification + select { + case <-s.manager.ProcessNotify: + // Success + case <-time.After(100 * time.Millisecond): + s.Fail("Should have received process notification") + } +} + +// TestSessionCallbacks tests session lifecycle callbacks. +func (s *SessionIntegrationSuite) TestSessionCallbacks() { + ctx := context.Background() + + var createdID, deletedID int64 + + s.manager.SetOnSessionCreated(func(id int64) { + createdID = id + }) + s.manager.SetOnSessionDeleted(func(id int64) { + deletedID = id + }) + + // Create session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-callbacks", "test-project", "prompt") + s.Require().NoError(err) + + // Initialize - should trigger created callback + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + s.Equal(sessionID, createdID) + + // Delete - should trigger deleted callback + s.manager.DeleteSession(sessionID) + + s.Equal(sessionID, deletedID) +} + +// TestMultipleSessions tests managing multiple sessions. +func (s *SessionIntegrationSuite) TestMultipleSessions() { + ctx := context.Background() + + // Create multiple sessions + var sessionIDs []int64 + for i := 0; i < 5; i++ { + id, err := s.sessionStore.CreateSDKSession(ctx, "claude-multi-"+string(rune('A'+i)), "project-"+string(rune('a'+i)), "prompt") + s.Require().NoError(err) + sessionIDs = append(sessionIDs, id) + } + + // Initialize all + for _, id := range sessionIDs { + _, err := s.manager.InitializeSession(ctx, id, "prompt", 1) + s.Require().NoError(err) + } + + s.Equal(5, s.manager.GetActiveSessionCount()) + + // Queue observations to each + for i, id := range sessionIDs { + err := s.manager.QueueObservation(ctx, id, ObservationData{ + ToolName: "Tool" + string(rune('A'+i)), + }) + s.Require().NoError(err) + } + + s.Equal(5, s.manager.GetTotalQueueDepth()) + + // Get all sessions + sessions := s.manager.GetAllSessions() + s.Len(sessions, 5) + + // Delete all + for _, id := range sessionIDs { + s.manager.DeleteSession(id) + } + + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions tests the cleanup of stale sessions. +func (s *SessionIntegrationSuite) TestCleanupStaleSessions() { + ctx := context.Background() + + // Create and initialize a session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale", "test-project", "prompt") + s.Require().NoError(err) + + session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Manually set start time to past (simulate stale session) + session.StartTime = time.Now().Add(-SessionTimeout - time.Minute) + + // Run cleanup + s.manager.cleanupStaleSessions() + + // Session should be deleted + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_WithPendingMessages tests cleanup doesn't delete sessions with pending messages. +func (s *SessionIntegrationSuite) TestCleanupStaleSessions_WithPendingMessages() { + ctx := context.Background() + + // Create and initialize a session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale-pending", "test-project", "prompt") + s.Require().NoError(err) + + session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Make session stale but add pending messages + session.StartTime = time.Now().Add(-SessionTimeout - time.Minute) + err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Test"}) + s.Require().NoError(err) + + // Run cleanup + s.manager.cleanupStaleSessions() + + // Session should NOT be deleted (has pending messages) + s.Equal(1, s.manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_WithActiveGenerator tests cleanup doesn't delete sessions with active generator. +func (s *SessionIntegrationSuite) TestCleanupStaleSessions_WithActiveGenerator() { + ctx := context.Background() + + // Create and initialize a session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale-gen", "test-project", "prompt") + s.Require().NoError(err) + + session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Make session stale but mark generator as active + session.StartTime = time.Now().Add(-SessionTimeout - time.Minute) + session.generatorActive.Store(true) + + // Run cleanup + s.manager.cleanupStaleSessions() + + // Session should NOT be deleted (generator is active) + s.Equal(1, s.manager.GetActiveSessionCount()) +} + +// TestConcurrentQueueOperations tests thread-safe queue operations. +func (s *SessionIntegrationSuite) TestConcurrentQueueOperations() { + ctx := context.Background() + + // Create and initialize session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-concurrent", "test-project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + + // Concurrent queue operations + done := make(chan bool) + numGoroutines := 50 + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + if idx%2 == 0 { + _ = s.manager.QueueObservation(ctx, sessionID, ObservationData{ + ToolName: "Tool", + }) + } else { + _ = s.manager.QueueSummarize(ctx, sessionID, "user", "assistant") + } + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < numGoroutines; i++ { + <-done + } + + // All messages should be queued + s.Equal(numGoroutines, s.manager.GetTotalQueueDepth()) +} + +// TestShutdownAll_WithRealSessions tests shutdown of all real sessions. +func (s *SessionIntegrationSuite) TestShutdownAll_WithRealSessions() { + ctx := context.Background() + + // Create and initialize multiple sessions + for i := 0; i < 3; i++ { + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-shutdown-"+string(rune('A'+i)), "project", "prompt") + s.Require().NoError(err) + + _, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + } + + s.Equal(3, s.manager.GetActiveSessionCount()) + + // Shutdown all + s.manager.ShutdownAll(ctx) + + // All sessions should be deleted + s.Equal(0, s.manager.GetActiveSessionCount()) +} + +// TestSessionSDKSessionID tests SDK session ID handling. +func (s *SessionIntegrationSuite) TestSessionSDKSessionID() { + ctx := context.Background() + + // Create session - SDK session ID is generated + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-sdk-test", "test-project", "prompt") + s.Require().NoError(err) + + // Initialize in manager + session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1) + s.Require().NoError(err) + s.Require().NotNil(session) + + // SDK session ID should be set + s.NotEmpty(session.SDKSessionID) +} + +// TestPromptNumberTracking tests prompt number tracking across operations. +func (s *SessionIntegrationSuite) TestPromptNumberTracking() { + ctx := context.Background() + + // Create session + sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-prompt-num", "test-project", "initial") + s.Require().NoError(err) + + // Initialize with prompt 1 + session, err := s.manager.InitializeSession(ctx, sessionID, "prompt 1", 1) + s.Require().NoError(err) + s.Equal(1, session.LastPromptNumber) + + // Re-initialize with prompt 2 + session, err = s.manager.InitializeSession(ctx, sessionID, "prompt 2", 2) + s.Require().NoError(err) + s.Equal(2, session.LastPromptNumber) + + // Re-initialize with prompt 5 + session, err = s.manager.InitializeSession(ctx, sessionID, "prompt 5", 5) + s.Require().NoError(err) + s.Equal(5, session.LastPromptNumber) +} diff --git a/internal/worker/sse/broadcaster_test.go b/internal/worker/sse/broadcaster_test.go index 9b69459..45776f2 100644 --- a/internal/worker/sse/broadcaster_test.go +++ b/internal/worker/sse/broadcaster_test.go @@ -381,3 +381,165 @@ func TestBroadcasterConcurrentAddRemove(t *testing.T) { count := b.ClientCount() assert.GreaterOrEqual(t, count, 0) } + +// TestRemoveClientByID tests removing a client by ID. +func TestRemoveClientByID(t *testing.T) { + b := NewBroadcaster() + w := newMockResponseWriter() + + client, err := b.AddClient(w) + require.NoError(t, err) + assert.Equal(t, 1, b.ClientCount()) + + // Remove by ID + b.removeClientByID(client.ID) + assert.Equal(t, 0, b.ClientCount()) + + // Done channel should be closed + select { + case <-client.Done: + // Expected + default: + t.Error("Done channel should be closed") + } +} + +// TestRemoveClientByID_NonExistent tests removing a non-existent client by ID. +func TestRemoveClientByID_NonExistent(t *testing.T) { + b := NewBroadcaster() + + // Should not panic + b.removeClientByID("non-existent-id") + assert.Equal(t, 0, b.ClientCount()) +} + +// TestRemoveClientByID_AlreadyClosed tests removing a client with already closed Done channel. +func TestRemoveClientByID_AlreadyClosed(t *testing.T) { + b := NewBroadcaster() + w := newMockResponseWriter() + + client, err := b.AddClient(w) + require.NoError(t, err) + + // Pre-close the Done channel + close(client.Done) + + // Should not panic when trying to close again + b.removeClientByID(client.ID) + assert.Equal(t, 0, b.ClientCount()) +} + +// TestHandleSSE_NonFlusher tests HandleSSE with a non-flusher response writer. +func TestHandleSSE_NonFlusher(t *testing.T) { + b := NewBroadcaster() + + // Create a response writer that doesn't implement Flusher + nonFlusher := &nonFlusherWriter{header: make(http.Header)} + + req := httptest.NewRequest(http.MethodGet, "/events", nil) + + // Should return immediately since writer isn't a Flusher + b.HandleSSE(nonFlusher, req) + + // No clients should be added + assert.Equal(t, 0, b.ClientCount()) +} + +// nonFlusherWriter is a response writer that doesn't implement http.Flusher. +type nonFlusherWriter struct { + header http.Header +} + +func (w *nonFlusherWriter) Header() http.Header { return w.header } +func (w *nonFlusherWriter) Write(data []byte) (int, error) { return len(data), nil } +func (w *nonFlusherWriter) WriteHeader(statusCode int) {} + +// TestWriteToClient_Timeout tests write timeout behavior. +func TestWriteToClient_Timeout(t *testing.T) { + b := NewBroadcaster() + + // Create a slow writer that blocks + slowWriter := &slowMockWriter{ + header: make(http.Header), + blockFor: 5 * time.Second, // Longer than WriteTimeout + } + + client, err := b.AddClient(slowWriter) + require.NoError(t, err) + + // Create dead client channel + deadCh := make(chan string, 1) + + // Try to write - should timeout + msg := "data: test\n\n" + b.writeToClient(client, msg, deadCh) + + // May report dead client due to timeout + // Check if client was reported as dead (with timeout) + select { + case deadID := <-deadCh: + // Client was reported as dead + assert.Equal(t, client.ID, deadID) + case <-time.After(WriteTimeout + 500*time.Millisecond): + // Timed out waiting - also acceptable + } +} + +// slowMockWriter simulates a slow writer for testing timeouts. +type slowMockWriter struct { + header http.Header + blockFor time.Duration + mu sync.Mutex +} + +func (m *slowMockWriter) Header() http.Header { + return m.header +} + +func (m *slowMockWriter) Write(data []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + time.Sleep(m.blockFor) + return len(data), nil +} + +func (m *slowMockWriter) WriteHeader(statusCode int) {} + +func (m *slowMockWriter) Flush() {} + +// TestBroadcast_InvalidJSON tests broadcasting un-marshalable data. +func TestBroadcast_InvalidJSON(t *testing.T) { + b := NewBroadcaster() + w := newMockResponseWriter() + _, err := b.AddClient(w) + require.NoError(t, err) + + // channels can't be marshaled to JSON + ch := make(chan int) + b.Broadcast(ch) // Should log error but not panic + + // Give time for async processing + time.Sleep(20 * time.Millisecond) + + // Body should be empty or not contain the channel data + body := string(w.GetBody()) + assert.NotContains(t, body, "chan") +} + +// TestBroadcast_ClientDoneChannel tests broadcasting when client Done is closed. +func TestBroadcast_ClientDoneChannel(t *testing.T) { + b := NewBroadcaster() + w := newMockResponseWriter() + + client, err := b.AddClient(w) + require.NoError(t, err) + + // Close the done channel + close(client.Done) + + // Broadcast should skip this client + b.Broadcast(map[string]string{"type": "test"}) + + // Give time for async processing + time.Sleep(20 * time.Millisecond) +}