mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Further improvements to the coverage.
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
@@ -45,6 +47,9 @@ func testService(t *testing.T) (*Service, func()) {
|
||||
// Create router
|
||||
router := chi.NewRouter()
|
||||
|
||||
// Create test updater
|
||||
testUpdater := update.New("test-version", t.TempDir())
|
||||
|
||||
svc := &Service{
|
||||
version: "test-version",
|
||||
config: config.Get(),
|
||||
@@ -55,6 +60,7 @@ func testService(t *testing.T) (*Service, func()) {
|
||||
promptStore: promptStore,
|
||||
sessionManager: sessionManager,
|
||||
sseBroadcaster: sseBroadcaster,
|
||||
updater: testUpdater,
|
||||
router: router,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -2125,3 +2131,781 @@ func TestHandleContextInject_WithQuery(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.GreaterOrEqual(t, len(observations), 1)
|
||||
}
|
||||
|
||||
// TestHandleUpdateCheck tests the update check endpoint.
|
||||
func TestHandleUpdateCheck(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/update/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// May return 200 or 500 depending on network availability
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleUpdateStatus tests the update status endpoint.
|
||||
func TestHandleUpdateStatus(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/update/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var status update.UpdateStatus
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &status)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default state should be "idle"
|
||||
assert.Equal(t, "idle", status.State)
|
||||
}
|
||||
|
||||
// TestHandleUpdateApply tests the update apply endpoint.
|
||||
func TestHandleUpdateApply(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/update/apply", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Will return 200 with message or 500 if network check fails
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleUpdateRestart_NoUpdate tests restart without applied update.
|
||||
func TestHandleUpdateRestart_NoUpdate(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/update/restart", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should return error since no update has been applied
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleRestart tests the general restart endpoint.
|
||||
func TestHandleRestart(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/restart", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should return OK with restart message
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, true, response["success"])
|
||||
assert.Equal(t, "Restarting worker...", response["message"])
|
||||
}
|
||||
|
||||
// TestSetInitError tests setting init error.
|
||||
func TestSetInitError(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Initially no error
|
||||
assert.Nil(t, svc.GetInitError())
|
||||
|
||||
// Set an error
|
||||
testErr := assert.AnError
|
||||
svc.setInitError(testErr)
|
||||
|
||||
// Should be set
|
||||
assert.Equal(t, testErr, svc.GetInitError())
|
||||
}
|
||||
|
||||
// TestQueueStaleVerification tests the stale verification queue.
|
||||
func TestQueueStaleVerification(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Queue some verifications - should not panic
|
||||
svc.queueStaleVerification(1, "/test/path")
|
||||
svc.queueStaleVerification(2, "/test/path2")
|
||||
|
||||
// Give the goroutine a moment to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestRecordRetrievalStats tests retrieval stats recording.
|
||||
func TestRecordRetrievalStats(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Initially all zeros
|
||||
stats := svc.GetRetrievalStats()
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
|
||||
// Record a search request
|
||||
svc.recordRetrievalStats(10, 1, 0, true)
|
||||
stats = svc.GetRetrievalStats()
|
||||
assert.Equal(t, int64(1), stats.TotalRequests)
|
||||
assert.Equal(t, int64(10), stats.ObservationsServed)
|
||||
assert.Equal(t, int64(1), stats.VerifiedStale)
|
||||
assert.Equal(t, int64(1), stats.SearchRequests)
|
||||
|
||||
// Record a context injection
|
||||
svc.recordRetrievalStats(5, 0, 1, false)
|
||||
stats = svc.GetRetrievalStats()
|
||||
assert.Equal(t, int64(2), stats.TotalRequests)
|
||||
assert.Equal(t, int64(15), stats.ObservationsServed)
|
||||
assert.Equal(t, int64(1), stats.DeletedInvalid)
|
||||
assert.Equal(t, int64(1), stats.ContextInjections)
|
||||
}
|
||||
|
||||
// TestHandleSelfCheck_WithInitError tests self-check with init error.
|
||||
func TestHandleSelfCheck_WithInitError(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(false)
|
||||
svc.setInitError(assert.AnError)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleSelfCheck(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SelfCheckResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "unhealthy", response.Overall)
|
||||
}
|
||||
|
||||
// TestHandleHealthEndpoint tests health endpoint via router.
|
||||
func TestHandleHealthEndpoint(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSelfCheckEndpoint tests self-check endpoint via router.
|
||||
func TestHandleSelfCheckEndpoint(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/selfcheck", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestUpdateInfo_Fields tests UpdateInfo struct.
|
||||
func TestUpdateInfo_Fields(t *testing.T) {
|
||||
info := update.UpdateInfo{
|
||||
Available: true,
|
||||
CurrentVersion: "v1.0.0",
|
||||
LatestVersion: "v2.0.0",
|
||||
ReleaseNotes: "Bug fixes",
|
||||
}
|
||||
|
||||
assert.True(t, info.Available)
|
||||
assert.Equal(t, "v1.0.0", info.CurrentVersion)
|
||||
assert.Equal(t, "v2.0.0", info.LatestVersion)
|
||||
assert.Equal(t, "Bug fixes", info.ReleaseNotes)
|
||||
}
|
||||
|
||||
// TestUpdateStatus_Fields tests UpdateStatus struct.
|
||||
func TestUpdateStatus_Fields(t *testing.T) {
|
||||
status := update.UpdateStatus{
|
||||
State: "downloading",
|
||||
Progress: 50.0,
|
||||
Message: "Downloading update...",
|
||||
}
|
||||
|
||||
assert.Equal(t, "downloading", status.State)
|
||||
assert.Equal(t, 50.0, status.Progress)
|
||||
assert.Equal(t, "Downloading update...", status.Message)
|
||||
}
|
||||
|
||||
// TestHandleObservation_MissingFields tests observation with partial fields.
|
||||
// The handler accepts requests even with missing fields since it creates sessions on-the-fly.
|
||||
func TestHandleObservation_MissingFields(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
body := `{"project": "test"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Handler accepts requests with partial fields and creates sessions on-the-fly
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSummaries_WithProject tests getting summaries filtered by project.
|
||||
func TestHandleGetSummaries_WithProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test-project&limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetPrompts_WithProject tests getting prompts filtered by project.
|
||||
func TestHandleGetPrompts_WithProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test-project&limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSearchByPrompt_WithCWD tests search with cwd parameter.
|
||||
func TestHandleSearchByPrompt_WithCWD(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "search-cwd-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test observation", "Content here", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&cwd=/tmp", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleContextInject_WithLimitAndCWD tests context inject with both limit and cwd.
|
||||
func TestHandleContextInject_WithLimitAndCWD(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "inject-cwd-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test observation", "Content here", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5&cwd=/tmp", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSelfCheck_AllComponents tests self-check response structure.
|
||||
func TestHandleSelfCheck_AllComponents(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/selfcheck", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SelfCheckResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all expected components are present
|
||||
componentNames := make(map[string]bool)
|
||||
for _, comp := range response.Components {
|
||||
componentNames[comp.Name] = true
|
||||
}
|
||||
|
||||
expectedComponents := []string{"Worker Service", "SQLite Database"}
|
||||
for _, name := range expectedComponents {
|
||||
assert.True(t, componentNames[name], "missing component: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleReady_NotReady tests ready endpoint when not ready.
|
||||
func TestHandleReady_NotReady(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleReady_Ready tests ready endpoint when ready.
|
||||
func TestHandleReady_Ready(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestRequireReady_Middleware tests the requireReady middleware.
|
||||
func TestRequireReady_Middleware(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Service is ready by default in tests
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should get OK because service is ready
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleObservation_ValidRequest tests observation with valid request.
|
||||
func TestHandleObservation_ValidRequest(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// First create a session
|
||||
sessionBody := `{"claude_session_id": "obs-test-session", "project": "obs-test-project"}`
|
||||
sessionReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(sessionBody))
|
||||
sessionReq.Header.Set("Content-Type", "application/json")
|
||||
sessionRec := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(sessionRec, sessionReq)
|
||||
|
||||
// Now create an observation
|
||||
body := `{
|
||||
"claude_session_id": "obs-test-session",
|
||||
"project": "obs-test-project",
|
||||
"tool_name": "Write",
|
||||
"tool_input": {"file_path": "/test/file.go"},
|
||||
"tool_response": {"success": true},
|
||||
"prompt_number": 1,
|
||||
"cwd": "/test"
|
||||
}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should accept the observation (may queue for processing)
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusAccepted}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSubagentComplete_ValidRequest tests subagent complete with valid request.
|
||||
func TestHandleSubagentComplete_ValidRequest(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// First create a session
|
||||
sessionBody := `{"claude_session_id": "subagent-test-session", "project": "subagent-test-project"}`
|
||||
sessionReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(sessionBody))
|
||||
sessionReq.Header.Set("Content-Type", "application/json")
|
||||
sessionRec := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(sessionRec, sessionReq)
|
||||
|
||||
// Now complete a subagent
|
||||
body := `{
|
||||
"claude_session_id": "subagent-test-session",
|
||||
"project": "subagent-test-project",
|
||||
"tool_name": "Task",
|
||||
"tool_input": {"description": "test task"},
|
||||
"tool_response": {"result": "completed"},
|
||||
"prompt_number": 1,
|
||||
"cwd": "/test"
|
||||
}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should accept the request
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusAccepted}, rec.Code)
|
||||
}
|
||||
|
||||
// TestClusterObservations tests the observation clustering function.
|
||||
func TestClusterObservations(t *testing.T) {
|
||||
// Create some observations
|
||||
obs := []*models.Observation{
|
||||
{ID: 1, Title: sql.NullString{String: "Test 1", Valid: true}},
|
||||
{ID: 2, Title: sql.NullString{String: "Test 2", Valid: true}},
|
||||
{ID: 3, Title: sql.NullString{String: "Test 3", Valid: true}},
|
||||
}
|
||||
|
||||
// Cluster with default threshold
|
||||
clustered := clusterObservations(obs, 0.4)
|
||||
|
||||
// Should return at least one observation
|
||||
assert.NotEmpty(t, clustered)
|
||||
assert.LessOrEqual(t, len(clustered), len(obs))
|
||||
}
|
||||
|
||||
// TestClusterObservations_EmptyInput tests clustering with empty input.
|
||||
func TestClusterObservations_EmptyInput(t *testing.T) {
|
||||
clustered := clusterObservations([]*models.Observation{}, 0.4)
|
||||
assert.Empty(t, clustered)
|
||||
}
|
||||
|
||||
// TestClusterObservations_NilInput tests clustering with nil input.
|
||||
func TestClusterObservations_NilInput(t *testing.T) {
|
||||
clustered := clusterObservations(nil, 0.4)
|
||||
assert.Empty(t, clustered)
|
||||
}
|
||||
|
||||
// TestComponentHealth tests ComponentHealth struct.
|
||||
func TestComponentHealth(t *testing.T) {
|
||||
comp := ComponentHealth{
|
||||
Name: "Test Component",
|
||||
Status: "healthy",
|
||||
Message: "All good",
|
||||
}
|
||||
|
||||
assert.Equal(t, "Test Component", comp.Name)
|
||||
assert.Equal(t, "healthy", comp.Status)
|
||||
assert.Equal(t, "All good", comp.Message)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_EmptyResult tests get observations with no data.
|
||||
func TestHandleGetObservations_EmptyResult(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=nonexistent-project", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Should return empty array, not null
|
||||
var obs []interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &obs)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, obs)
|
||||
}
|
||||
|
||||
// TestHandleGetSummaries_EmptyResult tests get summaries with no data.
|
||||
func TestHandleGetSummaries_EmptyResult(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=nonexistent-project", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetPrompts_EmptyResult tests get prompts with no data.
|
||||
func TestHandleGetPrompts_EmptyResult(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=nonexistent-project", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleContextInject_MissingProject tests context inject without project parameter.
|
||||
func TestHandleContextInject_MissingProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "project required")
|
||||
}
|
||||
|
||||
// TestHandleContextInject_DefaultCwd tests context inject with default cwd.
|
||||
func TestHandleContextInject_DefaultCwd(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "inject-default-cwd-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
|
||||
|
||||
// Request without cwd parameter - should use default "/"
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleContextInject_ConfigLimits tests context inject respects config limits.
|
||||
func TestHandleContextInject_ConfigLimits(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Set custom config values
|
||||
svc.config.ContextObservations = 5
|
||||
svc.config.ContextFullCount = 2
|
||||
|
||||
project := "inject-config-test"
|
||||
for i := 0; i < 10; i++ {
|
||||
createTestObservation(t, svc.observationStore, project, fmt.Sprintf("Test %d", i), "Content", []string{"test"})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, float64(2), response["full_count"])
|
||||
}
|
||||
|
||||
// TestHandleUpdateApply_NoUpdateAvailable tests update apply when no update is available.
|
||||
func TestHandleUpdateApply_NoUpdateAvailable(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/update/apply", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
// Update check may succeed or fail - both are valid behaviors
|
||||
assert.NotNil(t, response)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_WithQuery tests observations with query parameter.
|
||||
func TestHandleGetObservations_WithQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "obs-query-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Unique Test Title", "Content about testing", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?project="+project+"&query=unique", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSummaries_WithQuery tests summaries with query parameter.
|
||||
func TestHandleGetSummaries_WithQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test-project&query=test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetPrompts_WithQuery tests prompts with query parameter.
|
||||
func TestHandleGetPrompts_WithQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test-project&query=test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSearchByPrompt_SpecialCharsQuery tests search with special characters to trigger FTS fallback.
|
||||
func TestHandleSearchByPrompt_SpecialCharsQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "special-chars-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
|
||||
|
||||
// Use special characters that might cause FTS issues
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test()+special*", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should still return OK (falls back to recent observations)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleHealth tests the health endpoint.
|
||||
func TestHandleHealth(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ready", response["status"])
|
||||
}
|
||||
|
||||
// TestHandleSessionInit_ValidRequest tests session init with valid request.
|
||||
func TestHandleSessionInit_ValidRequest(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
body := `{"claude_session_id": "init-test-session", "project": "init-test-project", "prompt": "test prompt"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
// Response uses camelCase: sessionDbId
|
||||
assert.Contains(t, response, "sessionDbId")
|
||||
}
|
||||
|
||||
// TestHandleSummarize tests the summarize endpoint.
|
||||
func TestHandleSummarize(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// First create a session with observations
|
||||
body := `{"claude_session_id": "summarize-test-session", "project": "summarize-test"}`
|
||||
initReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(body))
|
||||
initReq.Header.Set("Content-Type", "application/json")
|
||||
initRec := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(initRec, initReq)
|
||||
|
||||
// Now request summarization - using the session ID path
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/summarize-test-session/summarize", nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// May return OK, 404 (not found), or error depending on endpoint and session state
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusInternalServerError}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_AllProjects tests getting all observations without project filter.
|
||||
func TestHandleGetObservations_AllProjects(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create observations in different projects
|
||||
createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"})
|
||||
|
||||
// Request without project filter
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSummaries_AllProjects tests getting all summaries without project filter.
|
||||
func TestHandleGetSummaries_AllProjects(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetPrompts_AllProjects tests getting all prompts without project filter.
|
||||
func TestHandleGetPrompts_AllProjects(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSessionStart tests the session start endpoint.
|
||||
func TestHandleSessionStart(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// First create a session
|
||||
initBody := `{"claude_session_id": "start-test-session", "project": "start-test"}`
|
||||
initReq := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewBufferString(initBody))
|
||||
initReq.Header.Set("Content-Type", "application/json")
|
||||
initRec := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(initRec, initReq)
|
||||
|
||||
// Now start the session
|
||||
body := `{"sessionId": "start-test-session"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/start", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// May return various status codes depending on session state and endpoint
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest, http.StatusNotFound, http.StatusInternalServerError}, rec.Code)
|
||||
}
|
||||
|
||||
@@ -505,3 +505,472 @@ func TestObservationConcepts(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, expectedConcepts, ObservationConcepts)
|
||||
}
|
||||
|
||||
// TestProcessorStruct tests processor struct initialization and methods.
|
||||
func TestProcessorStruct(t *testing.T) {
|
||||
p := &Processor{
|
||||
claudePath: "/path/to/claude",
|
||||
model: "haiku",
|
||||
sem: make(chan struct{}, MaxConcurrentCLICalls),
|
||||
}
|
||||
|
||||
assert.Equal(t, "/path/to/claude", p.claudePath)
|
||||
assert.Equal(t, "haiku", p.model)
|
||||
assert.NotNil(t, p.sem)
|
||||
}
|
||||
|
||||
// TestSetBroadcastFunc tests the broadcast callback setter.
|
||||
func TestSetBroadcastFunc(t *testing.T) {
|
||||
p := &Processor{}
|
||||
|
||||
assert.Nil(t, p.broadcastFunc)
|
||||
|
||||
var called bool
|
||||
var receivedEvent map[string]interface{}
|
||||
fn := func(event map[string]interface{}) {
|
||||
called = true
|
||||
receivedEvent = event
|
||||
}
|
||||
|
||||
p.SetBroadcastFunc(fn)
|
||||
assert.NotNil(t, p.broadcastFunc)
|
||||
|
||||
// Test broadcast
|
||||
p.broadcast(map[string]interface{}{"type": "test"})
|
||||
assert.True(t, called)
|
||||
assert.Equal(t, "test", receivedEvent["type"])
|
||||
}
|
||||
|
||||
// TestSetSyncObservationFunc tests the sync observation callback setter.
|
||||
func TestSetSyncObservationFunc(t *testing.T) {
|
||||
p := &Processor{}
|
||||
|
||||
assert.Nil(t, p.syncObservationFunc)
|
||||
|
||||
var called bool
|
||||
fn := func(obs *models.Observation) {
|
||||
called = true
|
||||
}
|
||||
|
||||
p.SetSyncObservationFunc(fn)
|
||||
assert.NotNil(t, p.syncObservationFunc)
|
||||
|
||||
// Verify it was set
|
||||
p.syncObservationFunc(&models.Observation{})
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
// TestSetSyncSummaryFunc tests the sync summary callback setter.
|
||||
func TestSetSyncSummaryFunc(t *testing.T) {
|
||||
p := &Processor{}
|
||||
|
||||
assert.Nil(t, p.syncSummaryFunc)
|
||||
|
||||
var called bool
|
||||
fn := func(summary *models.SessionSummary) {
|
||||
called = true
|
||||
}
|
||||
|
||||
p.SetSyncSummaryFunc(fn)
|
||||
assert.NotNil(t, p.syncSummaryFunc)
|
||||
|
||||
// Verify it was set
|
||||
p.syncSummaryFunc(&models.SessionSummary{})
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
// TestBroadcast_NilFunc tests broadcast with nil callback.
|
||||
func TestBroadcast_NilFunc(t *testing.T) {
|
||||
p := &Processor{}
|
||||
|
||||
// Should not panic
|
||||
p.broadcast(map[string]interface{}{"type": "test"})
|
||||
}
|
||||
|
||||
// TestIsAvailable_NonexistentPath tests IsAvailable with non-existent path.
|
||||
func TestIsAvailable_NonexistentPath(t *testing.T) {
|
||||
p := &Processor{
|
||||
claudePath: "/nonexistent/path/to/claude",
|
||||
}
|
||||
|
||||
assert.False(t, p.IsAvailable())
|
||||
}
|
||||
|
||||
// TestIsAvailable_ExistingPath tests IsAvailable with existing path.
|
||||
func TestIsAvailable_ExistingPath(t *testing.T) {
|
||||
// Create a temp file to simulate claude binary
|
||||
tmpFile, err := os.CreateTemp("", "claude-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
tmpFile.Close()
|
||||
|
||||
p := &Processor{
|
||||
claudePath: tmpFile.Name(),
|
||||
}
|
||||
|
||||
assert.True(t, p.IsAvailable())
|
||||
}
|
||||
|
||||
// TestShouldSkipTrivialOperation_EdgeCases tests edge cases for trivial operation detection.
|
||||
func TestShouldSkipTrivialOperation_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
inputStr string
|
||||
outputStr string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "gitignore_file",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/.gitignore"}`,
|
||||
outputStr: "This is a gitignore file content that has more than 50 characters long",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "eslintignore_file",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/.eslintignore"}`,
|
||||
outputStr: "This is an eslintignore file content that has more than 50 characters long",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tsconfig_file",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/tsconfig.json"}`,
|
||||
outputStr: "This is a tsconfig.json file content that has more than 50 characters long",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tailwind_config",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/tailwind.config.js"}`,
|
||||
outputStr: "This is a tailwind.config file content that has more than 50 characters long",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pwd_command",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "pwd"}`,
|
||||
outputStr: "/home/user/project/some/long/path/that/is/more/than/fifty/chars",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "echo_command",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "echo Hello World"}`,
|
||||
outputStr: "Hello World output that is long enough to pass the length check here",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "npm_audit_command",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "npm audit"}`,
|
||||
outputStr: "found 0 vulnerabilities in 500 packages which is more than fifty characters",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "permission_denied",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/root/secret"}`,
|
||||
outputStr: "Error: Permission denied accessing the file at specified path",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_a_directory",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/some/dir"}`,
|
||||
outputStr: "Error: /some/dir is a directory, not a file that can be read",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty_object",
|
||||
toolName: "Grep",
|
||||
inputStr: `{"pattern": "nonexistent"}`,
|
||||
outputStr: "{}",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid_grep_result",
|
||||
toolName: "Grep",
|
||||
inputStr: `{"pattern": "func main"}`,
|
||||
outputStr: "main.go:10:func main() {\nmain.go:11: fmt.Println(\"Hello\")\n}",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "valid_bash_build",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "go build ./..."}`,
|
||||
outputStr: "Build completed successfully. Binary output at ./bin/myapp with size 10MB.",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsSelfReferentialSummary_MoreCases tests additional self-referential detection cases.
|
||||
func TestIsSelfReferentialSummary_MoreCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
summary *models.ParsedSummary
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "progress_checkpoint",
|
||||
summary: &models.ParsedSummary{
|
||||
Request: "Progress checkpoint for current session",
|
||||
Completed: "Responding to progress checkpoint request",
|
||||
Learned: "No technical learnings yet",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty_session",
|
||||
summary: &models.ParsedSummary{
|
||||
Request: "Empty session",
|
||||
Completed: "Just beginning the session",
|
||||
Learned: "Nothing has been completed yet",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hook_mechanism",
|
||||
summary: &models.ParsedSummary{
|
||||
Request: "Hook execution for session start",
|
||||
Completed: "Hook mechanism triggered successfully",
|
||||
Learned: "System hooks are working",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "api_implementation",
|
||||
summary: &models.ParsedSummary{
|
||||
Request: "Implement REST API endpoints",
|
||||
Completed: "Created /users and /posts endpoints with CRUD operations",
|
||||
Learned: "chi router handles middleware chaining elegantly",
|
||||
NextSteps: "Add authentication middleware",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "database_migration",
|
||||
summary: &models.ParsedSummary{
|
||||
Request: "Add database migration for new user fields",
|
||||
Completed: "Created migration 003_add_user_profile.sql with new columns",
|
||||
Learned: "SQLite ALTER TABLE has limited capabilities, need to recreate table",
|
||||
NextSteps: "Test migration rollback",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isSelfReferentialSummary(tt.summary)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasMeaningfulContent_MoreCases tests additional meaningful content detection.
|
||||
func TestHasMeaningfulContent_MoreCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "code_with_functions",
|
||||
content: `I've created a new handler function in handlers.go.
|
||||
The function validateRequest() checks the incoming JSON payload.
|
||||
Here's the implementation:
|
||||
` + "```go\nfunc validateRequest(r *http.Request) error {\n\treturn nil\n}\n```",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "python_code_discussion",
|
||||
content: `Updated the data processing module in processor.py.
|
||||
Changed the filter function to use list comprehension.
|
||||
def process_data(items):
|
||||
return [item for item in items if item.valid]
|
||||
This improved performance by 30%.`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "typescript_changes",
|
||||
content: `I've modified the React component in UserProfile.tsx.
|
||||
Added a new functional component with proper TypeScript type annotations.
|
||||
Here's the updated implementation:
|
||||
` + "```tsx\nconst UserProfile: FC<Props> = ({ user }) => {\n return <div>{user.name}</div>;\n};\n```" + `
|
||||
The type annotations ensure type safety across the application.
|
||||
The component has been updated with proper error handling and loading states.`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "yaml_config_update",
|
||||
content: `I've updated the kubernetes deployment config in deploy.yaml.
|
||||
Changed replicas from 2 to 4 and added resource limits for memory and CPU.
|
||||
The deployment.yaml file now includes the following struct configuration:
|
||||
` + "```yaml\nreplicas: 4\nresources:\n limits:\n memory: 512Mi\n```" + `
|
||||
The changes have been implemented and will be applied on next deploy.`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "just_system_messages",
|
||||
content: `SessionStart:Callback hook success
|
||||
System-reminder about tools
|
||||
The session is starting
|
||||
Waiting for user instructions`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "borderline_short",
|
||||
content: "Fixed bug. Updated file. Added test. Committed changes to repository.",
|
||||
expected: false, // Too short (< 200 chars)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasMeaningfulContent(tt.content)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToJSONString_ComplexTypes tests JSON conversion for complex types.
|
||||
func TestToJSONString_ComplexTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "nested_map",
|
||||
input: map[string]interface{}{
|
||||
"outer": map[string]string{"inner": "value"},
|
||||
},
|
||||
contains: "inner",
|
||||
},
|
||||
{
|
||||
name: "bool_true",
|
||||
input: true,
|
||||
contains: "true",
|
||||
},
|
||||
{
|
||||
name: "bool_false",
|
||||
input: false,
|
||||
contains: "false",
|
||||
},
|
||||
{
|
||||
name: "float_value",
|
||||
input: 3.14,
|
||||
contains: "3.14",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := toJSONString(tt.input)
|
||||
assert.Contains(t, result, tt.contains)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemPrompt tests that the system prompt is defined.
|
||||
func TestSystemPrompt(t *testing.T) {
|
||||
assert.NotEmpty(t, systemPrompt)
|
||||
assert.Contains(t, systemPrompt, "memory extraction agent")
|
||||
assert.Contains(t, systemPrompt, "observation")
|
||||
assert.Contains(t, systemPrompt, "GUIDELINES")
|
||||
}
|
||||
|
||||
// TestProcessorSemaphore tests the semaphore behavior.
|
||||
func TestProcessorSemaphore(t *testing.T) {
|
||||
p := &Processor{
|
||||
sem: make(chan struct{}, 2),
|
||||
}
|
||||
|
||||
// Acquire 2 slots
|
||||
p.sem <- struct{}{}
|
||||
p.sem <- struct{}{}
|
||||
|
||||
// Third should block (we can test with select)
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
t.Error("Semaphore should be full")
|
||||
default:
|
||||
// Expected - semaphore is full
|
||||
}
|
||||
|
||||
// Release one
|
||||
<-p.sem
|
||||
|
||||
// Now should be able to acquire
|
||||
select {
|
||||
case p.sem <- struct{}{}:
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Should be able to acquire after release")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptureFileMtimes_DuplicatePaths tests mtime capture with overlapping paths.
|
||||
func TestCaptureFileMtimes_DuplicatePaths(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "mtime-dup-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
testFile := filepath.Join(tmpDir, "shared.txt")
|
||||
err = os.WriteFile(testFile, []byte("content"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Same file in both read and modified lists
|
||||
mtimes := captureFileMtimes([]string{testFile}, []string{testFile}, "")
|
||||
|
||||
// Should only have one entry (no duplicates)
|
||||
assert.Len(t, mtimes, 1)
|
||||
assert.Contains(t, mtimes, testFile)
|
||||
}
|
||||
|
||||
// TestTruncateForLog_ZeroLength tests truncation with zero length.
|
||||
func TestTruncateForLog_ZeroLength(t *testing.T) {
|
||||
result := truncateForLog("hello", 0)
|
||||
assert.Equal(t, "...", result)
|
||||
}
|
||||
|
||||
// TestBroadcastFuncType tests the BroadcastFunc type.
|
||||
func TestBroadcastFuncType(t *testing.T) {
|
||||
var fn BroadcastFunc = func(event map[string]interface{}) {
|
||||
// Do nothing
|
||||
}
|
||||
assert.NotNil(t, fn)
|
||||
}
|
||||
|
||||
// TestSyncObservationFuncType tests the SyncObservationFunc type.
|
||||
func TestSyncObservationFuncType(t *testing.T) {
|
||||
var fn SyncObservationFunc = func(obs *models.Observation) {
|
||||
// Do nothing
|
||||
}
|
||||
assert.NotNil(t, fn)
|
||||
}
|
||||
|
||||
// TestSyncSummaryFuncType tests the SyncSummaryFunc type.
|
||||
func TestSyncSummaryFuncType(t *testing.T) {
|
||||
var fn SyncSummaryFunc = func(summary *models.SessionSummary) {
|
||||
// Do nothing
|
||||
}
|
||||
assert.NotNil(t, fn)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,615 @@
|
||||
// Package session provides session lifecycle management for claude-mnemonic.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
// Import sqlite driver
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// hasFTS5 checks if FTS5 is available in the SQLite build.
|
||||
func hasFTS5(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "fts5-check-*")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tmpDir) }()
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: tmpDir + "/check.db",
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = store.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
if !hasFTS5(t) {
|
||||
t.Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "session-integration-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
_ = store.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// SessionIntegrationSuite tests session manager with real SQLite stores.
|
||||
type SessionIntegrationSuite struct {
|
||||
suite.Suite
|
||||
store *sqlite.Store
|
||||
sessionStore *sqlite.SessionStore
|
||||
cleanup func()
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func (s *SessionIntegrationSuite) SetupTest() {
|
||||
if !hasFTS5(s.T()) {
|
||||
s.T().Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
s.store, s.cleanup = testStore(s.T())
|
||||
s.sessionStore = sqlite.NewSessionStore(s.store)
|
||||
s.manager = NewManager(s.sessionStore)
|
||||
}
|
||||
|
||||
func (s *SessionIntegrationSuite) TearDownTest() {
|
||||
if s.manager != nil {
|
||||
s.manager.ShutdownAll(context.Background())
|
||||
}
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionIntegrationSuite(t *testing.T) {
|
||||
suite.Run(t, new(SessionIntegrationSuite))
|
||||
}
|
||||
|
||||
// TestNewManager_WithRealStore tests manager creation with real store.
|
||||
func (s *SessionIntegrationSuite) TestNewManager_WithRealStore() {
|
||||
s.NotNil(s.manager)
|
||||
s.NotNil(s.manager.sessionStore)
|
||||
s.NotNil(s.manager.sessions)
|
||||
s.NotNil(s.manager.ProcessNotify)
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestInitializeSession_WithRealStore tests session initialization.
|
||||
func (s *SessionIntegrationSuite) TestInitializeSession_WithRealStore() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session in the database first
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-test-123", "test-project", "initial prompt")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Greater(sessionID, int64(0))
|
||||
|
||||
// Initialize in manager
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "user prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(session)
|
||||
|
||||
// Verify session properties
|
||||
s.Equal(sessionID, session.SessionDBID)
|
||||
s.Equal("claude-test-123", session.ClaudeSessionID)
|
||||
s.Equal("test-project", session.Project)
|
||||
s.Equal("user prompt", session.UserPrompt)
|
||||
s.Equal(1, session.LastPromptNumber)
|
||||
|
||||
// Verify manager state
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestInitializeSession_ReuseExisting tests that existing sessions are reused.
|
||||
func (s *SessionIntegrationSuite) TestInitializeSession_ReuseExisting() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session in database
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-reuse-123", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Initialize first time
|
||||
session1, err := s.manager.InitializeSession(ctx, sessionID, "prompt 1", 1)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(session1)
|
||||
|
||||
// Initialize second time - should reuse
|
||||
session2, err := s.manager.InitializeSession(ctx, sessionID, "prompt 2", 2)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(session2)
|
||||
|
||||
// Should be the same session pointer
|
||||
s.Same(session1, session2)
|
||||
|
||||
// Should have updated user prompt
|
||||
s.Equal("prompt 2", session2.UserPrompt)
|
||||
s.Equal(2, session2.LastPromptNumber)
|
||||
|
||||
// Still only 1 active session
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestInitializeSession_NonExistentSession tests initializing non-existent session.
|
||||
func (s *SessionIntegrationSuite) TestInitializeSession_NonExistentSession() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to initialize non-existent session
|
||||
session, err := s.manager.InitializeSession(ctx, 999999, "prompt", 1)
|
||||
s.NoError(err) // No error, just nil session
|
||||
s.Nil(session)
|
||||
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestInitializeSession_EmptyUserPrompt tests initialization with empty user prompt.
|
||||
func (s *SessionIntegrationSuite) TestInitializeSession_EmptyUserPrompt() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session with initial prompt in database
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-empty-prompt", "test-project", "db prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Initialize with empty user prompt - should use database prompt
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "", 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(session)
|
||||
|
||||
// Should use database prompt
|
||||
s.Equal("db prompt", session.UserPrompt)
|
||||
}
|
||||
|
||||
// TestQueueObservation_WithRealStore tests observation queuing.
|
||||
func (s *SessionIntegrationSuite) TestQueueObservation_WithRealStore() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-queue-obs", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Queue an observation
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{
|
||||
ToolName: "Read",
|
||||
ToolInput: map[string]string{"path": "/test.go"},
|
||||
ToolResponse: "file content",
|
||||
PromptNumber: 1,
|
||||
CWD: "/project",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Check queue depth
|
||||
s.Equal(1, s.manager.GetTotalQueueDepth())
|
||||
s.True(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Drain messages
|
||||
messages := s.manager.DrainMessages(sessionID)
|
||||
s.Len(messages, 1)
|
||||
s.Equal(MessageTypeObservation, messages[0].Type)
|
||||
s.Equal("Read", messages[0].Observation.ToolName)
|
||||
}
|
||||
|
||||
// TestQueueObservation_AutoInitialize tests auto-initialization on queue.
|
||||
func (s *SessionIntegrationSuite) TestQueueObservation_AutoInitialize() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session in database but don't initialize in manager
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-auto-init", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Queue observation without explicit initialization
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{
|
||||
ToolName: "Write",
|
||||
ToolInput: "test input",
|
||||
ToolResponse: "success",
|
||||
PromptNumber: 1,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Session should be auto-initialized
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
s.Equal(1, s.manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
// TestQueueObservation_NonExistentSession tests queuing to non-existent session.
|
||||
func (s *SessionIntegrationSuite) TestQueueObservation_NonExistentSession() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to queue to non-existent session
|
||||
err := s.manager.QueueObservation(ctx, 999999, ObservationData{
|
||||
ToolName: "Test",
|
||||
})
|
||||
|
||||
// Should not error, but session won't be created
|
||||
s.NoError(err)
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestQueueSummarize_WithRealStore tests summarize queuing.
|
||||
func (s *SessionIntegrationSuite) TestQueueSummarize_WithRealStore() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-queue-sum", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Queue a summarize request
|
||||
err = s.manager.QueueSummarize(ctx, sessionID, "What did you do?", "I completed the task.")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Check queue depth
|
||||
s.Equal(1, s.manager.GetTotalQueueDepth())
|
||||
s.True(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Drain messages
|
||||
messages := s.manager.DrainMessages(sessionID)
|
||||
s.Len(messages, 1)
|
||||
s.Equal(MessageTypeSummarize, messages[0].Type)
|
||||
s.Equal("What did you do?", messages[0].Summarize.LastUserMessage)
|
||||
s.Equal("I completed the task.", messages[0].Summarize.LastAssistantMessage)
|
||||
}
|
||||
|
||||
// TestQueueSummarize_AutoInitialize tests auto-initialization on summarize queue.
|
||||
func (s *SessionIntegrationSuite) TestQueueSummarize_AutoInitialize() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session in database but don't initialize in manager
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-sum-auto", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Queue summarize without explicit initialization
|
||||
err = s.manager.QueueSummarize(ctx, sessionID, "user msg", "assistant msg")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Session should be auto-initialized
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
s.Equal(1, s.manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
// TestQueueSummarize_NonExistentSession tests summarize queuing to non-existent session.
|
||||
func (s *SessionIntegrationSuite) TestQueueSummarize_NonExistentSession() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to queue to non-existent session
|
||||
err := s.manager.QueueSummarize(ctx, 999999, "user", "assistant")
|
||||
|
||||
// Should not error, but session won't be created
|
||||
s.NoError(err)
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestMixedQueueOperations tests mixed observation and summarize queuing.
|
||||
func (s *SessionIntegrationSuite) TestMixedQueueOperations() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-mixed", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Queue multiple messages of different types
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Tool1"})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.manager.QueueSummarize(ctx, sessionID, "user1", "assistant1")
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Tool2"})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Check total queue depth
|
||||
s.Equal(3, s.manager.GetTotalQueueDepth())
|
||||
|
||||
// Drain and verify order
|
||||
messages := s.manager.DrainMessages(sessionID)
|
||||
s.Len(messages, 3)
|
||||
s.Equal(MessageTypeObservation, messages[0].Type)
|
||||
s.Equal("Tool1", messages[0].Observation.ToolName)
|
||||
s.Equal(MessageTypeSummarize, messages[1].Type)
|
||||
s.Equal(MessageTypeObservation, messages[2].Type)
|
||||
s.Equal("Tool2", messages[2].Observation.ToolName)
|
||||
}
|
||||
|
||||
// TestProcessNotifyChannel tests the process notification channel behavior.
|
||||
func (s *SessionIntegrationSuite) TestProcessNotifyChannel() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-notify", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Drain any existing notifications
|
||||
select {
|
||||
case <-s.manager.ProcessNotify:
|
||||
default:
|
||||
}
|
||||
|
||||
// Queue observation - should trigger notification
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Test"})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Should be able to receive notification
|
||||
select {
|
||||
case <-s.manager.ProcessNotify:
|
||||
// Success
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
s.Fail("Should have received process notification")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionCallbacks tests session lifecycle callbacks.
|
||||
func (s *SessionIntegrationSuite) TestSessionCallbacks() {
|
||||
ctx := context.Background()
|
||||
|
||||
var createdID, deletedID int64
|
||||
|
||||
s.manager.SetOnSessionCreated(func(id int64) {
|
||||
createdID = id
|
||||
})
|
||||
s.manager.SetOnSessionDeleted(func(id int64) {
|
||||
deletedID = id
|
||||
})
|
||||
|
||||
// Create session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-callbacks", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Initialize - should trigger created callback
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Equal(sessionID, createdID)
|
||||
|
||||
// Delete - should trigger deleted callback
|
||||
s.manager.DeleteSession(sessionID)
|
||||
|
||||
s.Equal(sessionID, deletedID)
|
||||
}
|
||||
|
||||
// TestMultipleSessions tests managing multiple sessions.
|
||||
func (s *SessionIntegrationSuite) TestMultipleSessions() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple sessions
|
||||
var sessionIDs []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
id, err := s.sessionStore.CreateSDKSession(ctx, "claude-multi-"+string(rune('A'+i)), "project-"+string(rune('a'+i)), "prompt")
|
||||
s.Require().NoError(err)
|
||||
sessionIDs = append(sessionIDs, id)
|
||||
}
|
||||
|
||||
// Initialize all
|
||||
for _, id := range sessionIDs {
|
||||
_, err := s.manager.InitializeSession(ctx, id, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
s.Equal(5, s.manager.GetActiveSessionCount())
|
||||
|
||||
// Queue observations to each
|
||||
for i, id := range sessionIDs {
|
||||
err := s.manager.QueueObservation(ctx, id, ObservationData{
|
||||
ToolName: "Tool" + string(rune('A'+i)),
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
s.Equal(5, s.manager.GetTotalQueueDepth())
|
||||
|
||||
// Get all sessions
|
||||
sessions := s.manager.GetAllSessions()
|
||||
s.Len(sessions, 5)
|
||||
|
||||
// Delete all
|
||||
for _, id := range sessionIDs {
|
||||
s.manager.DeleteSession(id)
|
||||
}
|
||||
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions tests the cleanup of stale sessions.
|
||||
func (s *SessionIntegrationSuite) TestCleanupStaleSessions() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize a session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Manually set start time to past (simulate stale session)
|
||||
session.StartTime = time.Now().Add(-SessionTimeout - time.Minute)
|
||||
|
||||
// Run cleanup
|
||||
s.manager.cleanupStaleSessions()
|
||||
|
||||
// Session should be deleted
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_WithPendingMessages tests cleanup doesn't delete sessions with pending messages.
|
||||
func (s *SessionIntegrationSuite) TestCleanupStaleSessions_WithPendingMessages() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize a session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale-pending", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Make session stale but add pending messages
|
||||
session.StartTime = time.Now().Add(-SessionTimeout - time.Minute)
|
||||
err = s.manager.QueueObservation(ctx, sessionID, ObservationData{ToolName: "Test"})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Run cleanup
|
||||
s.manager.cleanupStaleSessions()
|
||||
|
||||
// Session should NOT be deleted (has pending messages)
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_WithActiveGenerator tests cleanup doesn't delete sessions with active generator.
|
||||
func (s *SessionIntegrationSuite) TestCleanupStaleSessions_WithActiveGenerator() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize a session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-stale-gen", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Make session stale but mark generator as active
|
||||
session.StartTime = time.Now().Add(-SessionTimeout - time.Minute)
|
||||
session.generatorActive.Store(true)
|
||||
|
||||
// Run cleanup
|
||||
s.manager.cleanupStaleSessions()
|
||||
|
||||
// Session should NOT be deleted (generator is active)
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestConcurrentQueueOperations tests thread-safe queue operations.
|
||||
func (s *SessionIntegrationSuite) TestConcurrentQueueOperations() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-concurrent", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Concurrent queue operations
|
||||
done := make(chan bool)
|
||||
numGoroutines := 50
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(idx int) {
|
||||
if idx%2 == 0 {
|
||||
_ = s.manager.QueueObservation(ctx, sessionID, ObservationData{
|
||||
ToolName: "Tool",
|
||||
})
|
||||
} else {
|
||||
_ = s.manager.QueueSummarize(ctx, sessionID, "user", "assistant")
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// All messages should be queued
|
||||
s.Equal(numGoroutines, s.manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
// TestShutdownAll_WithRealSessions tests shutdown of all real sessions.
|
||||
func (s *SessionIntegrationSuite) TestShutdownAll_WithRealSessions() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create and initialize multiple sessions
|
||||
for i := 0; i < 3; i++ {
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-shutdown-"+string(rune('A'+i)), "project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
s.Equal(3, s.manager.GetActiveSessionCount())
|
||||
|
||||
// Shutdown all
|
||||
s.manager.ShutdownAll(ctx)
|
||||
|
||||
// All sessions should be deleted
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestSessionSDKSessionID tests SDK session ID handling.
|
||||
func (s *SessionIntegrationSuite) TestSessionSDKSessionID() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session - SDK session ID is generated
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-sdk-test", "test-project", "prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Initialize in manager
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "prompt", 1)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(session)
|
||||
|
||||
// SDK session ID should be set
|
||||
s.NotEmpty(session.SDKSessionID)
|
||||
}
|
||||
|
||||
// TestPromptNumberTracking tests prompt number tracking across operations.
|
||||
func (s *SessionIntegrationSuite) TestPromptNumberTracking() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create session
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(ctx, "claude-prompt-num", "test-project", "initial")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Initialize with prompt 1
|
||||
session, err := s.manager.InitializeSession(ctx, sessionID, "prompt 1", 1)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(1, session.LastPromptNumber)
|
||||
|
||||
// Re-initialize with prompt 2
|
||||
session, err = s.manager.InitializeSession(ctx, sessionID, "prompt 2", 2)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(2, session.LastPromptNumber)
|
||||
|
||||
// Re-initialize with prompt 5
|
||||
session, err = s.manager.InitializeSession(ctx, sessionID, "prompt 5", 5)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(5, session.LastPromptNumber)
|
||||
}
|
||||
@@ -381,3 +381,165 @@ func TestBroadcasterConcurrentAddRemove(t *testing.T) {
|
||||
count := b.ClientCount()
|
||||
assert.GreaterOrEqual(t, count, 0)
|
||||
}
|
||||
|
||||
// TestRemoveClientByID tests removing a client by ID.
|
||||
func TestRemoveClientByID(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
w := newMockResponseWriter()
|
||||
|
||||
client, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, b.ClientCount())
|
||||
|
||||
// Remove by ID
|
||||
b.removeClientByID(client.ID)
|
||||
assert.Equal(t, 0, b.ClientCount())
|
||||
|
||||
// Done channel should be closed
|
||||
select {
|
||||
case <-client.Done:
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Done channel should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveClientByID_NonExistent tests removing a non-existent client by ID.
|
||||
func TestRemoveClientByID_NonExistent(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Should not panic
|
||||
b.removeClientByID("non-existent-id")
|
||||
assert.Equal(t, 0, b.ClientCount())
|
||||
}
|
||||
|
||||
// TestRemoveClientByID_AlreadyClosed tests removing a client with already closed Done channel.
|
||||
func TestRemoveClientByID_AlreadyClosed(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
w := newMockResponseWriter()
|
||||
|
||||
client, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pre-close the Done channel
|
||||
close(client.Done)
|
||||
|
||||
// Should not panic when trying to close again
|
||||
b.removeClientByID(client.ID)
|
||||
assert.Equal(t, 0, b.ClientCount())
|
||||
}
|
||||
|
||||
// TestHandleSSE_NonFlusher tests HandleSSE with a non-flusher response writer.
|
||||
func TestHandleSSE_NonFlusher(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Create a response writer that doesn't implement Flusher
|
||||
nonFlusher := &nonFlusherWriter{header: make(http.Header)}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/events", nil)
|
||||
|
||||
// Should return immediately since writer isn't a Flusher
|
||||
b.HandleSSE(nonFlusher, req)
|
||||
|
||||
// No clients should be added
|
||||
assert.Equal(t, 0, b.ClientCount())
|
||||
}
|
||||
|
||||
// nonFlusherWriter is a response writer that doesn't implement http.Flusher.
|
||||
type nonFlusherWriter struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (w *nonFlusherWriter) Header() http.Header { return w.header }
|
||||
func (w *nonFlusherWriter) Write(data []byte) (int, error) { return len(data), nil }
|
||||
func (w *nonFlusherWriter) WriteHeader(statusCode int) {}
|
||||
|
||||
// TestWriteToClient_Timeout tests write timeout behavior.
|
||||
func TestWriteToClient_Timeout(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Create a slow writer that blocks
|
||||
slowWriter := &slowMockWriter{
|
||||
header: make(http.Header),
|
||||
blockFor: 5 * time.Second, // Longer than WriteTimeout
|
||||
}
|
||||
|
||||
client, err := b.AddClient(slowWriter)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create dead client channel
|
||||
deadCh := make(chan string, 1)
|
||||
|
||||
// Try to write - should timeout
|
||||
msg := "data: test\n\n"
|
||||
b.writeToClient(client, msg, deadCh)
|
||||
|
||||
// May report dead client due to timeout
|
||||
// Check if client was reported as dead (with timeout)
|
||||
select {
|
||||
case deadID := <-deadCh:
|
||||
// Client was reported as dead
|
||||
assert.Equal(t, client.ID, deadID)
|
||||
case <-time.After(WriteTimeout + 500*time.Millisecond):
|
||||
// Timed out waiting - also acceptable
|
||||
}
|
||||
}
|
||||
|
||||
// slowMockWriter simulates a slow writer for testing timeouts.
|
||||
type slowMockWriter struct {
|
||||
header http.Header
|
||||
blockFor time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *slowMockWriter) Header() http.Header {
|
||||
return m.header
|
||||
}
|
||||
|
||||
func (m *slowMockWriter) Write(data []byte) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
time.Sleep(m.blockFor)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (m *slowMockWriter) WriteHeader(statusCode int) {}
|
||||
|
||||
func (m *slowMockWriter) Flush() {}
|
||||
|
||||
// TestBroadcast_InvalidJSON tests broadcasting un-marshalable data.
|
||||
func TestBroadcast_InvalidJSON(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
w := newMockResponseWriter()
|
||||
_, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
|
||||
// channels can't be marshaled to JSON
|
||||
ch := make(chan int)
|
||||
b.Broadcast(ch) // Should log error but not panic
|
||||
|
||||
// Give time for async processing
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Body should be empty or not contain the channel data
|
||||
body := string(w.GetBody())
|
||||
assert.NotContains(t, body, "chan")
|
||||
}
|
||||
|
||||
// TestBroadcast_ClientDoneChannel tests broadcasting when client Done is closed.
|
||||
func TestBroadcast_ClientDoneChannel(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
w := newMockResponseWriter()
|
||||
|
||||
client, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the done channel
|
||||
close(client.Done)
|
||||
|
||||
// Broadcast should skip this client
|
||||
b.Broadcast(map[string]string{"type": "test"})
|
||||
|
||||
// Give time for async processing
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user