mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-16 02:51:45 +00:00
Increase tests coverage.
This commit is contained in:
@@ -2,11 +2,13 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -551,3 +553,729 @@ func TestRequireReadyMiddleware_Allows(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "success", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestHandleGetStats(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetStats(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check basic stats fields exist
|
||||
_, hasUptime := response["uptime"]
|
||||
assert.True(t, hasUptime)
|
||||
_, hasReady := response["ready"]
|
||||
assert.True(t, hasReady)
|
||||
}
|
||||
|
||||
func TestHandleGetStats_WithProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "test-project"
|
||||
createTestObservation(t, svc.observationStore, project, "Test", "Test content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/stats?project="+project, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetStats(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check project-specific stats
|
||||
assert.Equal(t, project, response["project"])
|
||||
assert.Equal(t, float64(1), response["projectObservations"])
|
||||
}
|
||||
|
||||
func TestHandleGetRetrievalStats(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetRetrievalStats(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response RetrievalStats
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially all stats should be 0
|
||||
assert.Equal(t, int64(0), response.TotalRequests)
|
||||
}
|
||||
|
||||
func TestHandleContextCount(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "count-project"
|
||||
|
||||
// Create some observations
|
||||
for i := 0; i < 5; i++ {
|
||||
createTestObservation(t, svc.observationStore, project, "Test "+string(rune('A'+i)), "Content", []string{"test"})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/count?project="+project, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleContextCount(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, project, response["project"])
|
||||
assert.Equal(t, float64(5), response["count"])
|
||||
}
|
||||
|
||||
func TestHandleContextCount_MissingProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleContextCount(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleGetProjects(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create sessions for different projects
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-1", "project-alpha", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-2", "project-beta", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "session-3", "project-gamma", "")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/projects", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetProjects(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var projects []string
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &projects)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, projects, 3)
|
||||
assert.Contains(t, projects, "project-alpha")
|
||||
assert.Contains(t, projects, "project-beta")
|
||||
assert.Contains(t, projects, "project-gamma")
|
||||
}
|
||||
|
||||
func TestHandleGetTypes(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/types", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetTypes(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check observation types
|
||||
obsTypes, ok := response["observation_types"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, toStringSlice(obsTypes), "bugfix")
|
||||
assert.Contains(t, toStringSlice(obsTypes), "feature")
|
||||
|
||||
// Check concept types
|
||||
conceptTypes, ok := response["concept_types"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, toStringSlice(conceptTypes), "security")
|
||||
}
|
||||
|
||||
func toStringSlice(arr []interface{}) []string {
|
||||
result := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
result[i] = v.(string)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestHandleGetSummaries(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create some summaries
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
parsed := &models.ParsedSummary{
|
||||
Request: "Test request " + string(rune('A'+i)),
|
||||
Completed: "Test completed",
|
||||
}
|
||||
sdkSessionID := "sdk-" + string(rune('a'+i))
|
||||
_, _, err := svc.summaryStore.StoreSummary(ctx, sdkSessionID, "project-a", parsed, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=project-a&limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetSummaries(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var summaries []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &summaries)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, summaries, 3)
|
||||
}
|
||||
|
||||
func TestHandleGetPrompts(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create sessions and prompts
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-test", "project-x", "")
|
||||
|
||||
// Save prompts
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-test", i+1, "Test prompt "+string(rune('A'+i)), 0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=project-x&limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleGetPrompts(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var prompts []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &prompts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, prompts, 5)
|
||||
}
|
||||
|
||||
func TestHandleSelfCheck(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(true)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleSelfCheck(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SelfCheckResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Overall health should be healthy or degraded (not unhealthy for basic tests)
|
||||
assert.NotEqual(t, "unhealthy", response.Overall)
|
||||
assert.NotEmpty(t, response.Version)
|
||||
assert.NotEmpty(t, response.Uptime)
|
||||
assert.NotEmpty(t, response.Components)
|
||||
}
|
||||
|
||||
func TestHandleSelfCheck_NotReady(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleSelfCheck(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SelfCheckResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be degraded when not ready
|
||||
assert.Equal(t, "degraded", response.Overall)
|
||||
}
|
||||
|
||||
func TestObservationTypesAndConcepts(t *testing.T) {
|
||||
// Verify observation types
|
||||
assert.Contains(t, ObservationTypes, "bugfix")
|
||||
assert.Contains(t, ObservationTypes, "feature")
|
||||
assert.Contains(t, ObservationTypes, "refactor")
|
||||
assert.Contains(t, ObservationTypes, "discovery")
|
||||
assert.Contains(t, ObservationTypes, "decision")
|
||||
assert.Contains(t, ObservationTypes, "change")
|
||||
|
||||
// Verify concept types
|
||||
assert.Contains(t, ConceptTypes, "how-it-works")
|
||||
assert.Contains(t, ConceptTypes, "security")
|
||||
assert.Contains(t, ConceptTypes, "best-practice")
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
data := map[string]string{"test": "value"}
|
||||
writeJSON(rec, data)
|
||||
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
|
||||
var result map[string]string
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "value", result["test"])
|
||||
}
|
||||
|
||||
func TestDefaultLimitConstants(t *testing.T) {
|
||||
assert.Equal(t, 100, DefaultObservationsLimit)
|
||||
assert.Equal(t, 50, DefaultSummariesLimit)
|
||||
assert.Equal(t, 100, DefaultPromptsLimit)
|
||||
assert.Equal(t, 50, DefaultSearchLimit)
|
||||
assert.Equal(t, 50, DefaultContextLimit)
|
||||
}
|
||||
|
||||
func TestDuplicatePromptWindowSeconds(t *testing.T) {
|
||||
assert.Equal(t, 10, DuplicatePromptWindowSeconds)
|
||||
}
|
||||
|
||||
func TestHandleSessionInit_Success(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionInitRequest{
|
||||
ClaudeSessionID: "claude-test-123",
|
||||
Project: "test-project",
|
||||
Prompt: "Help me fix this bug",
|
||||
MatchedObservations: 5,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SessionInitResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Greater(t, response.SessionDBID, int64(0))
|
||||
assert.Equal(t, 1, response.PromptNumber)
|
||||
assert.False(t, response.Skipped)
|
||||
}
|
||||
|
||||
func TestHandleSessionInit_InvalidJSON(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessionInit_PrivatePrompt(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionInitRequest{
|
||||
ClaudeSessionID: "claude-private",
|
||||
Project: "test-project",
|
||||
Prompt: "<private>This is a private prompt</private>",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SessionInitResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, response.Skipped)
|
||||
assert.Equal(t, "private", response.Reason)
|
||||
}
|
||||
|
||||
func TestHandleSessionInit_DuplicatePrompt(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionInitRequest{
|
||||
ClaudeSessionID: "claude-dup-test",
|
||||
Project: "test-project",
|
||||
Prompt: "Help me fix this specific bug",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
|
||||
// First request
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
rec1 := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(rec1, req1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec1.Code)
|
||||
var resp1 SessionInitResponse
|
||||
json.Unmarshal(rec1.Body.Bytes(), &resp1)
|
||||
|
||||
// Second request with same prompt (duplicate)
|
||||
body2, _ := json.Marshal(reqBody)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body2))
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
rec2 := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(rec2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec2.Code)
|
||||
var resp2 SessionInitResponse
|
||||
json.Unmarshal(rec2.Body.Bytes(), &resp2)
|
||||
|
||||
// Should return same prompt number (duplicate detected)
|
||||
assert.Equal(t, resp1.PromptNumber, resp2.PromptNumber)
|
||||
}
|
||||
|
||||
func TestHandleSessionStart_Success(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// First create a session
|
||||
ctx := context.Background()
|
||||
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "claude-start-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := SessionStartRequest{
|
||||
UserPrompt: "Help me with something",
|
||||
PromptNumber: 1,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessionStart_InvalidID(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionStartRequest{
|
||||
UserPrompt: "Help me",
|
||||
PromptNumber: 1,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessionStart_NotFound(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionStartRequest{
|
||||
UserPrompt: "Help me",
|
||||
PromptNumber: 1,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/999999/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessionStart_InvalidJSON(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "claude-json-test", "test-project", "")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/init", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleObservation_SessionNotFound(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := ObservationRequest{
|
||||
ClaudeSessionID: "non-existent-session",
|
||||
Project: "test-project",
|
||||
ToolName: "Read",
|
||||
ToolInput: map[string]string{"path": "/test.go"},
|
||||
ToolResponse: "file content",
|
||||
CWD: "/test",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should return 200 (queues observation) or 404 (session not found)
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound}, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleObservation_InvalidJSON(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte("invalid")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleObservation_WithExistingSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-obs-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := ObservationRequest{
|
||||
ClaudeSessionID: "claude-obs-test",
|
||||
Project: "test-project",
|
||||
ToolName: "Write",
|
||||
ToolInput: map[string]string{"path": "/test.go"},
|
||||
ToolResponse: "success",
|
||||
CWD: "/project",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleGetObservations_DefaultLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create more than default limit
|
||||
for i := 0; i < 120; i++ {
|
||||
createTestObservation(t, svc.observationStore, "project-limit",
|
||||
"Test "+strconv.Itoa(i),
|
||||
"Content "+strconv.Itoa(i),
|
||||
[]string{"test"})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return default limit (100)
|
||||
assert.LessOrEqual(t, len(observations), DefaultObservationsLimit)
|
||||
}
|
||||
|
||||
func TestHandleGetObservations_FilterByProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create observations in different projects
|
||||
createTestObservation(t, svc.observationStore, "alpha", "Alpha 1", "Content", []string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "alpha", "Alpha 2", "Content", []string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "beta", "Beta 1", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=alpha", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, observations, 2)
|
||||
}
|
||||
|
||||
func TestHandleGetObservations_FilterByType(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create observations - createTestObservation creates discovery type
|
||||
createTestObservation(t, svc.observationStore, "type-test", "Test 1", "Content", []string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "type-test", "Test 2", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?type=discovery", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleGetSummaries_DefaultLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
// Create more than default limit
|
||||
for i := 0; i < 60; i++ {
|
||||
parsed := &models.ParsedSummary{Request: "Request " + strconv.Itoa(i)}
|
||||
svc.summaryStore.StoreSummary(ctx, "sdk-"+strconv.Itoa(i), "project-sum", parsed, i+1, 100)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var summaries []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &summaries)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.LessOrEqual(t, len(summaries), DefaultSummariesLimit)
|
||||
}
|
||||
|
||||
func TestHandleGetPrompts_DefaultLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts", "project-prompts", "")
|
||||
|
||||
// Create more than default limit
|
||||
for i := 0; i < 120; i++ {
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts", i+1, "Prompt "+strconv.Itoa(i), 0)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var prompts []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &prompts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.LessOrEqual(t, len(prompts), DefaultPromptsLimit)
|
||||
}
|
||||
|
||||
func TestSessionInitRequest_Fields(t *testing.T) {
|
||||
req := SessionInitRequest{
|
||||
ClaudeSessionID: "test-123",
|
||||
Project: "my-project",
|
||||
Prompt: "Help me",
|
||||
MatchedObservations: 10,
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-123", req.ClaudeSessionID)
|
||||
assert.Equal(t, "my-project", req.Project)
|
||||
assert.Equal(t, "Help me", req.Prompt)
|
||||
assert.Equal(t, 10, req.MatchedObservations)
|
||||
}
|
||||
|
||||
func TestSessionInitResponse_Fields(t *testing.T) {
|
||||
resp := SessionInitResponse{
|
||||
SessionDBID: 123,
|
||||
PromptNumber: 5,
|
||||
Skipped: true,
|
||||
Reason: "private",
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(123), resp.SessionDBID)
|
||||
assert.Equal(t, 5, resp.PromptNumber)
|
||||
assert.True(t, resp.Skipped)
|
||||
assert.Equal(t, "private", resp.Reason)
|
||||
}
|
||||
|
||||
func TestSessionStartRequest_Fields(t *testing.T) {
|
||||
req := SessionStartRequest{
|
||||
UserPrompt: "Help me with code",
|
||||
PromptNumber: 3,
|
||||
}
|
||||
|
||||
assert.Equal(t, "Help me with code", req.UserPrompt)
|
||||
assert.Equal(t, 3, req.PromptNumber)
|
||||
}
|
||||
|
||||
func TestObservationRequest_Fields(t *testing.T) {
|
||||
req := ObservationRequest{
|
||||
ClaudeSessionID: "session-abc",
|
||||
Project: "my-project",
|
||||
ToolName: "Read",
|
||||
ToolInput: map[string]string{"path": "/file.go"},
|
||||
ToolResponse: "file contents",
|
||||
CWD: "/home/user/project",
|
||||
}
|
||||
|
||||
assert.Equal(t, "session-abc", req.ClaudeSessionID)
|
||||
assert.Equal(t, "my-project", req.Project)
|
||||
assert.Equal(t, "Read", req.ToolName)
|
||||
assert.Equal(t, "/home/user/project", req.CWD)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseObservations_SingleObservation(t *testing.T) {
|
||||
text := `Some text before
|
||||
<observation>
|
||||
<type>bugfix</type>
|
||||
<title>Fixed null pointer error</title>
|
||||
<subtitle>In user service</subtitle>
|
||||
<narrative>The service was crashing when user ID was nil</narrative>
|
||||
<facts>
|
||||
<fact>Added nil check</fact>
|
||||
<fact>Added unit test</fact>
|
||||
</facts>
|
||||
<concepts>
|
||||
<concept>error-handling</concept>
|
||||
<concept>debugging</concept>
|
||||
</concepts>
|
||||
<files_read>
|
||||
<file>user_service.go</file>
|
||||
</files_read>
|
||||
<files_modified>
|
||||
<file>user_service.go</file>
|
||||
<file>user_service_test.go</file>
|
||||
</files_modified>
|
||||
</observation>
|
||||
Some text after`
|
||||
|
||||
observations := ParseObservations(text, "test-correlation-id")
|
||||
|
||||
assert.Len(t, observations, 1)
|
||||
obs := observations[0]
|
||||
assert.Equal(t, models.ObservationType("bugfix"), obs.Type)
|
||||
assert.Equal(t, "Fixed null pointer error", obs.Title)
|
||||
assert.Equal(t, "In user service", obs.Subtitle)
|
||||
assert.Equal(t, "The service was crashing when user ID was nil", obs.Narrative)
|
||||
assert.Equal(t, []string{"Added nil check", "Added unit test"}, obs.Facts)
|
||||
assert.Equal(t, []string{"error-handling", "debugging"}, obs.Concepts)
|
||||
assert.Equal(t, []string{"user_service.go"}, obs.FilesRead)
|
||||
assert.Equal(t, []string{"user_service.go", "user_service_test.go"}, obs.FilesModified)
|
||||
}
|
||||
|
||||
func TestParseObservations_MultipleObservations(t *testing.T) {
|
||||
text := `
|
||||
<observation>
|
||||
<type>feature</type>
|
||||
<title>Added caching</title>
|
||||
<narrative>Implemented Redis caching</narrative>
|
||||
<facts><fact>Added cache layer</fact></facts>
|
||||
<concepts><concept>caching</concept></concepts>
|
||||
</observation>
|
||||
<observation>
|
||||
<type>refactor</type>
|
||||
<title>Cleaned up code</title>
|
||||
<narrative>Removed dead code</narrative>
|
||||
<facts><fact>Removed unused functions</fact></facts>
|
||||
<concepts><concept>refactoring</concept></concepts>
|
||||
</observation>
|
||||
`
|
||||
|
||||
observations := ParseObservations(text, "test-id")
|
||||
|
||||
assert.Len(t, observations, 2)
|
||||
assert.Equal(t, models.ObservationType("feature"), observations[0].Type)
|
||||
assert.Equal(t, "Added caching", observations[0].Title)
|
||||
assert.Equal(t, models.ObservationType("refactor"), observations[1].Type)
|
||||
assert.Equal(t, "Cleaned up code", observations[1].Title)
|
||||
}
|
||||
|
||||
func TestParseObservations_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedCount int
|
||||
expectedType models.ObservationType
|
||||
expectedTitle string
|
||||
checkConcepts []string
|
||||
}{
|
||||
{
|
||||
name: "valid_bugfix_observation",
|
||||
input: `<observation>
|
||||
<type>bugfix</type>
|
||||
<title>Fixed bug</title>
|
||||
<narrative>Details</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeBugfix,
|
||||
expectedTitle: "Fixed bug",
|
||||
},
|
||||
{
|
||||
name: "valid_feature_observation",
|
||||
input: `<observation>
|
||||
<type>feature</type>
|
||||
<title>New feature</title>
|
||||
<narrative>Added new stuff</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeFeature,
|
||||
expectedTitle: "New feature",
|
||||
},
|
||||
{
|
||||
name: "valid_refactor_observation",
|
||||
input: `<observation>
|
||||
<type>refactor</type>
|
||||
<title>Code cleanup</title>
|
||||
<narrative>Refactored module</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeRefactor,
|
||||
expectedTitle: "Code cleanup",
|
||||
},
|
||||
{
|
||||
name: "valid_change_observation",
|
||||
input: `<observation>
|
||||
<type>change</type>
|
||||
<title>Config update</title>
|
||||
<narrative>Changed settings</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeChange,
|
||||
expectedTitle: "Config update",
|
||||
},
|
||||
{
|
||||
name: "valid_discovery_observation",
|
||||
input: `<observation>
|
||||
<type>discovery</type>
|
||||
<title>Found pattern</title>
|
||||
<narrative>Discovered new pattern</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeDiscovery,
|
||||
expectedTitle: "Found pattern",
|
||||
},
|
||||
{
|
||||
name: "valid_decision_observation",
|
||||
input: `<observation>
|
||||
<type>decision</type>
|
||||
<title>Architecture decision</title>
|
||||
<narrative>Chose microservices</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeDecision,
|
||||
expectedTitle: "Architecture decision",
|
||||
},
|
||||
{
|
||||
name: "invalid_type_defaults_to_change",
|
||||
input: `<observation>
|
||||
<type>invalid_type</type>
|
||||
<title>Some title</title>
|
||||
<narrative>Details</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeChange,
|
||||
expectedTitle: "Some title",
|
||||
},
|
||||
{
|
||||
name: "missing_type_defaults_to_change",
|
||||
input: `<observation>
|
||||
<title>No type specified</title>
|
||||
<narrative>Details</narrative>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeChange,
|
||||
expectedTitle: "No type specified",
|
||||
},
|
||||
{
|
||||
name: "empty_input",
|
||||
input: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "no_observation_tags",
|
||||
input: "Just regular text without any observation",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "valid_concepts_filtered",
|
||||
input: `<observation>
|
||||
<type>bugfix</type>
|
||||
<title>Test</title>
|
||||
<narrative>Test</narrative>
|
||||
<concepts>
|
||||
<concept>best-practice</concept>
|
||||
<concept>invalid-concept</concept>
|
||||
<concept>security</concept>
|
||||
</concepts>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeBugfix,
|
||||
checkConcepts: []string{"best-practice", "security"},
|
||||
},
|
||||
{
|
||||
name: "type_in_concepts_filtered_out",
|
||||
input: `<observation>
|
||||
<type>bugfix</type>
|
||||
<title>Test</title>
|
||||
<narrative>Test</narrative>
|
||||
<concepts>
|
||||
<concept>bugfix</concept>
|
||||
<concept>security</concept>
|
||||
</concepts>
|
||||
</observation>`,
|
||||
expectedCount: 1,
|
||||
expectedType: models.ObsTypeBugfix,
|
||||
checkConcepts: []string{"security"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
observations := ParseObservations(tt.input, "test-correlation-id")
|
||||
|
||||
assert.Len(t, observations, tt.expectedCount)
|
||||
if tt.expectedCount > 0 {
|
||||
obs := observations[0]
|
||||
assert.Equal(t, tt.expectedType, obs.Type)
|
||||
if tt.expectedTitle != "" {
|
||||
assert.Equal(t, tt.expectedTitle, obs.Title)
|
||||
}
|
||||
if tt.checkConcepts != nil {
|
||||
assert.Equal(t, tt.checkConcepts, obs.Concepts)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseObservations_AllValidConcepts(t *testing.T) {
|
||||
// Test all valid concepts are accepted
|
||||
validConcepts := []string{
|
||||
"how-it-works", "why-it-exists", "what-changed", "problem-solution", "gotcha", "pattern", "trade-off",
|
||||
"best-practice", "anti-pattern", "architecture", "security", "performance", "testing", "debugging", "workflow", "tooling",
|
||||
"refactoring", "api", "database", "configuration", "error-handling", "caching", "logging", "auth", "validation",
|
||||
}
|
||||
|
||||
for _, concept := range validConcepts {
|
||||
t.Run("concept_"+concept, func(t *testing.T) {
|
||||
input := `<observation>
|
||||
<type>discovery</type>
|
||||
<title>Test</title>
|
||||
<narrative>Test</narrative>
|
||||
<concepts><concept>` + concept + `</concept></concepts>
|
||||
</observation>`
|
||||
|
||||
observations := ParseObservations(input, "test-id")
|
||||
assert.Len(t, observations, 1)
|
||||
assert.Contains(t, observations[0].Concepts, concept)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseObservations_ConceptCaseInsensitive(t *testing.T) {
|
||||
input := `<observation>
|
||||
<type>discovery</type>
|
||||
<title>Test</title>
|
||||
<narrative>Test</narrative>
|
||||
<concepts>
|
||||
<concept>SECURITY</concept>
|
||||
<concept>Best-Practice</concept>
|
||||
<concept> caching </concept>
|
||||
</concepts>
|
||||
</observation>`
|
||||
|
||||
observations := ParseObservations(input, "test-id")
|
||||
|
||||
assert.Len(t, observations, 1)
|
||||
assert.Equal(t, []string{"security", "best-practice", "caching"}, observations[0].Concepts)
|
||||
}
|
||||
|
||||
func TestParseSummary_ValidSummary(t *testing.T) {
|
||||
text := `Some text before
|
||||
<summary>
|
||||
<request>User asked to fix the bug</request>
|
||||
<investigated>Looked at error logs and stack traces</investigated>
|
||||
<learned>The issue was a race condition</learned>
|
||||
<completed>Fixed the race condition with mutex</completed>
|
||||
<next_steps>Add more tests for concurrent access</next_steps>
|
||||
<notes>May need to review similar code elsewhere</notes>
|
||||
</summary>
|
||||
Some text after`
|
||||
|
||||
summary := ParseSummary(text, 123)
|
||||
|
||||
assert.NotNil(t, summary)
|
||||
assert.Equal(t, "User asked to fix the bug", summary.Request)
|
||||
assert.Equal(t, "Looked at error logs and stack traces", summary.Investigated)
|
||||
assert.Equal(t, "The issue was a race condition", summary.Learned)
|
||||
assert.Equal(t, "Fixed the race condition with mutex", summary.Completed)
|
||||
assert.Equal(t, "Add more tests for concurrent access", summary.NextSteps)
|
||||
assert.Equal(t, "May need to review similar code elsewhere", summary.Notes)
|
||||
}
|
||||
|
||||
func TestParseSummary_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
sessionID int64
|
||||
expectNil bool
|
||||
expectedRequest string
|
||||
}{
|
||||
{
|
||||
name: "empty_input",
|
||||
input: "",
|
||||
sessionID: 1,
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "no_summary_tag",
|
||||
input: "Just some text without summary",
|
||||
sessionID: 1,
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "skip_summary_tag",
|
||||
input: `<skip_summary reason="No significant changes made"/>`,
|
||||
sessionID: 1,
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "skip_summary_with_different_reason",
|
||||
input: `<skip_summary reason="Only read files"/>`,
|
||||
sessionID: 2,
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "valid_summary_minimal",
|
||||
input: `<summary>
|
||||
<request>Test request</request>
|
||||
</summary>`,
|
||||
sessionID: 3,
|
||||
expectNil: false,
|
||||
expectedRequest: "Test request",
|
||||
},
|
||||
{
|
||||
name: "valid_summary_all_fields",
|
||||
input: `<summary>
|
||||
<request>Full request</request>
|
||||
<investigated>Full investigated</investigated>
|
||||
<learned>Full learned</learned>
|
||||
<completed>Full completed</completed>
|
||||
<next_steps>Full next steps</next_steps>
|
||||
<notes>Full notes</notes>
|
||||
</summary>`,
|
||||
sessionID: 4,
|
||||
expectNil: false,
|
||||
expectedRequest: "Full request",
|
||||
},
|
||||
{
|
||||
name: "summary_with_empty_fields",
|
||||
input: `<summary>
|
||||
<request></request>
|
||||
<investigated></investigated>
|
||||
</summary>`,
|
||||
sessionID: 5,
|
||||
expectNil: false,
|
||||
expectedRequest: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
summary := ParseSummary(tt.input, tt.sessionID)
|
||||
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, summary)
|
||||
} else {
|
||||
assert.NotNil(t, summary)
|
||||
assert.Equal(t, tt.expectedRequest, summary.Request)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSummary_SkipSummaryPriority(t *testing.T) {
|
||||
// skip_summary should take priority over summary block
|
||||
text := `<skip_summary reason="No changes"/>
|
||||
<summary>
|
||||
<request>This should be ignored</request>
|
||||
</summary>`
|
||||
|
||||
summary := ParseSummary(text, 1)
|
||||
assert.Nil(t, summary)
|
||||
}
|
||||
|
||||
func TestExtractField_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
fieldName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple_field",
|
||||
content: "<title>Test Title</title>",
|
||||
fieldName: "title",
|
||||
expected: "Test Title",
|
||||
},
|
||||
{
|
||||
name: "field_with_whitespace",
|
||||
content: "<title> Test Title </title>",
|
||||
fieldName: "title",
|
||||
expected: "Test Title",
|
||||
},
|
||||
{
|
||||
name: "field_not_found",
|
||||
content: "<other>Value</other>",
|
||||
fieldName: "title",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty_field",
|
||||
content: "<title></title>",
|
||||
fieldName: "title",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "nested_content",
|
||||
content: "<wrapper><title>Nested</title></wrapper>",
|
||||
fieldName: "title",
|
||||
expected: "Nested",
|
||||
},
|
||||
{
|
||||
name: "field_among_others",
|
||||
content: "<a>A</a><title>Target</title><b>B</b>",
|
||||
fieldName: "title",
|
||||
expected: "Target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractField(tt.content, tt.fieldName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractArrayElements_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
arrayName string
|
||||
elementName string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple_array",
|
||||
content: "<facts><fact>One</fact><fact>Two</fact></facts>",
|
||||
arrayName: "facts",
|
||||
elementName: "fact",
|
||||
expected: []string{"One", "Two"},
|
||||
},
|
||||
{
|
||||
name: "empty_array",
|
||||
content: "<facts></facts>",
|
||||
arrayName: "facts",
|
||||
elementName: "fact",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "array_not_found",
|
||||
content: "<other><item>Value</item></other>",
|
||||
arrayName: "facts",
|
||||
elementName: "fact",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "single_element",
|
||||
content: "<concepts><concept>security</concept></concepts>",
|
||||
arrayName: "concepts",
|
||||
elementName: "concept",
|
||||
expected: []string{"security"},
|
||||
},
|
||||
{
|
||||
name: "multiline_array",
|
||||
content: `<files>
|
||||
<file>file1.go</file>
|
||||
<file>file2.go</file>
|
||||
<file>file3.go</file>
|
||||
</files>`,
|
||||
arrayName: "files",
|
||||
elementName: "file",
|
||||
expected: []string{"file1.go", "file2.go", "file3.go"},
|
||||
},
|
||||
{
|
||||
name: "whitespace_trimmed",
|
||||
content: "<items><item> trimmed </item></items>",
|
||||
arrayName: "items",
|
||||
elementName: "item",
|
||||
expected: []string{"trimmed"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractArrayElements(tt.content, tt.arrayName, tt.elementName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidObsTypes(t *testing.T) {
|
||||
expected := map[string]bool{
|
||||
"bugfix": true,
|
||||
"feature": true,
|
||||
"refactor": true,
|
||||
"change": true,
|
||||
"discovery": true,
|
||||
"decision": true,
|
||||
}
|
||||
assert.Equal(t, expected, validObsTypes)
|
||||
}
|
||||
|
||||
func TestValidConcepts(t *testing.T) {
|
||||
// Verify expected concepts are valid
|
||||
expectedValid := []string{
|
||||
"how-it-works", "why-it-exists", "what-changed", "problem-solution", "gotcha", "pattern", "trade-off",
|
||||
"best-practice", "anti-pattern", "architecture", "security", "performance", "testing", "debugging", "workflow", "tooling",
|
||||
"refactoring", "api", "database", "configuration", "error-handling", "caching", "logging", "auth", "validation",
|
||||
}
|
||||
|
||||
for _, concept := range expectedValid {
|
||||
assert.True(t, validConcepts[concept], "Expected %s to be valid", concept)
|
||||
}
|
||||
|
||||
// Verify invalid concepts
|
||||
invalidConcepts := []string{"random", "invalid", "not-a-concept", "foo", "bar"}
|
||||
for _, concept := range invalidConcepts {
|
||||
assert.False(t, validConcepts[concept], "Expected %s to be invalid", concept)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,695 @@
|
||||
// Package session provides session lifecycle management for claude-mnemonic.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ManagerSuite is a test suite for Manager operations.
|
||||
type ManagerSuite struct {
|
||||
suite.Suite
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) SetupTest() {
|
||||
// Create manager without real session store (use nil for unit tests)
|
||||
s.manager = &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
// Initialize context for manager
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.manager.ctx = ctx
|
||||
s.manager.cancel = cancel
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TearDownTest() {
|
||||
if s.manager != nil && s.manager.cancel != nil {
|
||||
s.manager.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSuite(t *testing.T) {
|
||||
suite.Run(t, new(ManagerSuite))
|
||||
}
|
||||
|
||||
// TestActiveSession tests ActiveSession creation and basic operations.
|
||||
func (s *ManagerSuite) TestActiveSession() {
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
ClaudeSessionID: "claude-123",
|
||||
SDKSessionID: "sdk-123",
|
||||
Project: "test-project",
|
||||
UserPrompt: "Hello",
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
s.Equal(int64(1), session.SessionDBID)
|
||||
s.Equal("claude-123", session.ClaudeSessionID)
|
||||
s.Equal("sdk-123", session.SDKSessionID)
|
||||
s.Equal("test-project", session.Project)
|
||||
s.Equal("Hello", session.UserPrompt)
|
||||
}
|
||||
|
||||
// TestGetActiveSessionCount tests session counting.
|
||||
func (s *ManagerSuite) TestGetActiveSessionCount() {
|
||||
// Initially 0
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
|
||||
// Add sessions directly for testing
|
||||
s.manager.sessions[1] = &ActiveSession{SessionDBID: 1}
|
||||
s.manager.sessions[2] = &ActiveSession{SessionDBID: 2}
|
||||
|
||||
s.Equal(2, s.manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestGetTotalQueueDepth tests queue depth calculation.
|
||||
func (s *ManagerSuite) TestGetTotalQueueDepth() {
|
||||
// Initially 0
|
||||
s.Equal(0, s.manager.GetTotalQueueDepth())
|
||||
|
||||
// Add sessions with pending messages
|
||||
s.manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 3),
|
||||
}
|
||||
s.manager.sessions[2] = &ActiveSession{
|
||||
SessionDBID: 2,
|
||||
pendingMessages: make([]PendingMessage, 5),
|
||||
}
|
||||
|
||||
s.Equal(8, s.manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
// TestIsAnySessionProcessing tests processing status detection.
|
||||
func (s *ManagerSuite) TestIsAnySessionProcessing() {
|
||||
// No sessions - not processing
|
||||
s.False(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Session with no pending - not processing
|
||||
s.manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: []PendingMessage{},
|
||||
}
|
||||
s.False(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Session with pending - processing
|
||||
s.manager.sessions[1].pendingMessages = []PendingMessage{{Type: MessageTypeObservation}}
|
||||
s.True(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Clear pending but set generator active
|
||||
s.manager.sessions[1].pendingMessages = []PendingMessage{}
|
||||
s.manager.sessions[1].generatorActive.Store(true)
|
||||
s.True(s.manager.IsAnySessionProcessing())
|
||||
}
|
||||
|
||||
// TestGetAllSessions tests retrieving all sessions.
|
||||
func (s *ManagerSuite) TestGetAllSessions() {
|
||||
// Empty
|
||||
sessions := s.manager.GetAllSessions()
|
||||
s.Empty(sessions)
|
||||
|
||||
// Add sessions
|
||||
session1 := &ActiveSession{SessionDBID: 1, Project: "project-a"}
|
||||
session2 := &ActiveSession{SessionDBID: 2, Project: "project-b"}
|
||||
s.manager.sessions[1] = session1
|
||||
s.manager.sessions[2] = session2
|
||||
|
||||
sessions = s.manager.GetAllSessions()
|
||||
s.Len(sessions, 2)
|
||||
}
|
||||
|
||||
// TestDeleteSession tests session deletion.
|
||||
func (s *ManagerSuite) TestDeleteSession() {
|
||||
// Create session with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
Project: "test-project",
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.manager.sessions[1] = session
|
||||
|
||||
// Track callback
|
||||
var deletedID int64
|
||||
s.manager.SetOnSessionDeleted(func(id int64) {
|
||||
deletedID = id
|
||||
})
|
||||
|
||||
s.Equal(1, s.manager.GetActiveSessionCount())
|
||||
|
||||
// Delete
|
||||
s.manager.DeleteSession(1)
|
||||
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
s.Equal(int64(1), deletedID)
|
||||
|
||||
// Double delete should be safe
|
||||
s.manager.DeleteSession(1)
|
||||
}
|
||||
|
||||
// TestDrainMessages tests message draining.
|
||||
func (s *ManagerSuite) TestDrainMessages() {
|
||||
// No session - nil
|
||||
messages := s.manager.DrainMessages(999)
|
||||
s.Nil(messages)
|
||||
|
||||
// Session with messages
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: []PendingMessage{
|
||||
{Type: MessageTypeObservation},
|
||||
{Type: MessageTypeSummarize},
|
||||
},
|
||||
}
|
||||
s.manager.sessions[1] = session
|
||||
|
||||
messages = s.manager.DrainMessages(1)
|
||||
s.Len(messages, 2)
|
||||
|
||||
// Queue should be empty now
|
||||
s.Empty(session.pendingMessages)
|
||||
|
||||
// Drain again - empty
|
||||
messages = s.manager.DrainMessages(1)
|
||||
s.Empty(messages)
|
||||
}
|
||||
|
||||
// TestSetOnSessionCreated tests callback setting.
|
||||
func (s *ManagerSuite) TestSetOnSessionCreated() {
|
||||
var calledWith int64
|
||||
callback := func(id int64) {
|
||||
calledWith = id
|
||||
}
|
||||
|
||||
s.manager.SetOnSessionCreated(callback)
|
||||
s.NotNil(s.manager.onCreated)
|
||||
|
||||
// Simulate callback
|
||||
if s.manager.onCreated != nil {
|
||||
s.manager.onCreated(42)
|
||||
}
|
||||
s.Equal(int64(42), calledWith)
|
||||
}
|
||||
|
||||
// TestSetOnSessionDeleted tests callback setting.
|
||||
func (s *ManagerSuite) TestSetOnSessionDeleted() {
|
||||
var calledWith int64
|
||||
callback := func(id int64) {
|
||||
calledWith = id
|
||||
}
|
||||
|
||||
s.manager.SetOnSessionDeleted(callback)
|
||||
s.NotNil(s.manager.onDeleted)
|
||||
|
||||
// Simulate callback
|
||||
if s.manager.onDeleted != nil {
|
||||
s.manager.onDeleted(42)
|
||||
}
|
||||
s.Equal(int64(42), calledWith)
|
||||
}
|
||||
|
||||
// TestMessageTypes tests message type constants.
|
||||
func TestMessageTypes(t *testing.T) {
|
||||
assert.Equal(t, MessageType(0), MessageTypeObservation)
|
||||
assert.Equal(t, MessageType(1), MessageTypeSummarize)
|
||||
}
|
||||
|
||||
// TestTimeoutConstants tests timeout constants.
|
||||
func TestTimeoutConstants(t *testing.T) {
|
||||
assert.Equal(t, 30*time.Minute, SessionTimeout)
|
||||
assert.Equal(t, 5*time.Minute, CleanupInterval)
|
||||
}
|
||||
|
||||
// TestObservationData tests observation data structure.
|
||||
func TestObservationData(t *testing.T) {
|
||||
data := ObservationData{
|
||||
ToolName: "Read",
|
||||
ToolInput: map[string]string{"path": "/test/file.go"},
|
||||
ToolResponse: "file content",
|
||||
PromptNumber: 1,
|
||||
CWD: "/test",
|
||||
}
|
||||
|
||||
assert.Equal(t, "Read", data.ToolName)
|
||||
assert.Equal(t, 1, data.PromptNumber)
|
||||
assert.Equal(t, "/test", data.CWD)
|
||||
}
|
||||
|
||||
// TestSummarizeData tests summarize data structure.
|
||||
func TestSummarizeData(t *testing.T) {
|
||||
data := SummarizeData{
|
||||
LastUserMessage: "What did you do?",
|
||||
LastAssistantMessage: "I completed the task.",
|
||||
}
|
||||
|
||||
assert.Equal(t, "What did you do?", data.LastUserMessage)
|
||||
assert.Equal(t, "I completed the task.", data.LastAssistantMessage)
|
||||
}
|
||||
|
||||
// TestPendingMessage tests pending message structure.
|
||||
func TestPendingMessage(t *testing.T) {
|
||||
obsData := &ObservationData{ToolName: "Read"}
|
||||
msg := PendingMessage{
|
||||
Type: MessageTypeObservation,
|
||||
Observation: obsData,
|
||||
}
|
||||
|
||||
assert.Equal(t, MessageTypeObservation, msg.Type)
|
||||
assert.NotNil(t, msg.Observation)
|
||||
assert.Nil(t, msg.Summarize)
|
||||
|
||||
sumData := &SummarizeData{LastUserMessage: "Test"}
|
||||
msg2 := PendingMessage{
|
||||
Type: MessageTypeSummarize,
|
||||
Summarize: sumData,
|
||||
}
|
||||
|
||||
assert.Equal(t, MessageTypeSummarize, msg2.Type)
|
||||
assert.Nil(t, msg2.Observation)
|
||||
assert.NotNil(t, msg2.Summarize)
|
||||
}
|
||||
|
||||
// TestConcurrentSessionAccess tests thread-safe session operations.
|
||||
func TestConcurrentSessionAccess(t *testing.T) {
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
// Concurrent session operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int64) {
|
||||
defer wg.Done()
|
||||
|
||||
// Add session
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
manager.mu.Lock()
|
||||
manager.sessions[id] = &ActiveSession{
|
||||
SessionDBID: id,
|
||||
Project: "test",
|
||||
StartTime: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
// Read operations
|
||||
_ = manager.GetActiveSessionCount()
|
||||
_ = manager.GetTotalQueueDepth()
|
||||
_ = manager.IsAnySessionProcessing()
|
||||
_ = manager.GetAllSessions()
|
||||
|
||||
// Delete session
|
||||
manager.DeleteSession(id)
|
||||
}(int64(i))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All sessions should be deleted
|
||||
assert.Equal(t, 0, manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestProcessNotifyChannel tests the process notification channel.
|
||||
func TestProcessNotifyChannel(t *testing.T) {
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
// Non-blocking send should work
|
||||
select {
|
||||
case manager.ProcessNotify <- struct{}{}:
|
||||
// Success
|
||||
default:
|
||||
t.Error("ProcessNotify channel should accept first message")
|
||||
}
|
||||
|
||||
// Second send should not block (channel is buffered with size 1)
|
||||
select {
|
||||
case manager.ProcessNotify <- struct{}{}:
|
||||
// Full buffer, this is expected behavior
|
||||
default:
|
||||
// This is fine - channel is full
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
select {
|
||||
case <-manager.ProcessNotify:
|
||||
// Drained
|
||||
default:
|
||||
t.Error("Should be able to receive from ProcessNotify")
|
||||
}
|
||||
}
|
||||
|
||||
// TestActiveSessionContext tests session context handling.
|
||||
func TestActiveSessionContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Context should not be done
|
||||
select {
|
||||
case <-session.ctx.Done():
|
||||
t.Error("Context should not be done yet")
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
session.cancel()
|
||||
|
||||
// Context should be done
|
||||
select {
|
||||
case <-session.ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Context should be done after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeneratorActive tests the atomic generator active flag.
|
||||
func TestGeneratorActive(t *testing.T) {
|
||||
session := &ActiveSession{}
|
||||
|
||||
// Initially false
|
||||
assert.False(t, session.generatorActive.Load())
|
||||
|
||||
// Set to true
|
||||
session.generatorActive.Store(true)
|
||||
assert.True(t, session.generatorActive.Load())
|
||||
|
||||
// Set back to false
|
||||
session.generatorActive.Store(false)
|
||||
assert.False(t, session.generatorActive.Load())
|
||||
}
|
||||
|
||||
// TestTokenAccumulation tests token accumulation fields.
|
||||
func TestTokenAccumulation(t *testing.T) {
|
||||
session := &ActiveSession{
|
||||
CumulativeInputTokens: 0,
|
||||
CumulativeOutputTokens: 0,
|
||||
}
|
||||
|
||||
// Accumulate tokens
|
||||
session.CumulativeInputTokens += 100
|
||||
session.CumulativeOutputTokens += 50
|
||||
|
||||
assert.Equal(t, int64(100), session.CumulativeInputTokens)
|
||||
assert.Equal(t, int64(50), session.CumulativeOutputTokens)
|
||||
|
||||
// Add more
|
||||
session.CumulativeInputTokens += 200
|
||||
session.CumulativeOutputTokens += 100
|
||||
|
||||
assert.Equal(t, int64(300), session.CumulativeInputTokens)
|
||||
assert.Equal(t, int64(150), session.CumulativeOutputTokens)
|
||||
}
|
||||
|
||||
// TestShutdownAll tests graceful shutdown of all sessions.
|
||||
func (s *ManagerSuite) TestShutdownAll() {
|
||||
// Create multiple sessions
|
||||
for i := int64(1); i <= 3; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.manager.sessions[i] = &ActiveSession{
|
||||
SessionDBID: i,
|
||||
Project: "test-project",
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
s.Equal(3, s.manager.GetActiveSessionCount())
|
||||
|
||||
// Track deleted sessions
|
||||
var deletedIDs []int64
|
||||
s.manager.SetOnSessionDeleted(func(id int64) {
|
||||
deletedIDs = append(deletedIDs, id)
|
||||
})
|
||||
|
||||
// Shutdown all
|
||||
s.manager.ShutdownAll(context.Background())
|
||||
|
||||
// All sessions should be deleted
|
||||
s.Equal(0, s.manager.GetActiveSessionCount())
|
||||
s.Len(deletedIDs, 3)
|
||||
}
|
||||
|
||||
// TestDeleteNonExistentSession tests deleting a session that doesn't exist.
|
||||
func (s *ManagerSuite) TestDeleteNonExistentSession() {
|
||||
// Track callback
|
||||
callbackCalled := false
|
||||
s.manager.SetOnSessionDeleted(func(id int64) {
|
||||
callbackCalled = true
|
||||
})
|
||||
|
||||
// Delete non-existent session
|
||||
s.manager.DeleteSession(999)
|
||||
|
||||
// Callback should not be called
|
||||
s.False(callbackCalled)
|
||||
}
|
||||
|
||||
// TestLastPromptNumber tests prompt number tracking.
|
||||
func TestLastPromptNumber(t *testing.T) {
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
LastPromptNumber: 0,
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, session.LastPromptNumber)
|
||||
|
||||
session.LastPromptNumber = 5
|
||||
assert.Equal(t, 5, session.LastPromptNumber)
|
||||
|
||||
session.LastPromptNumber++
|
||||
assert.Equal(t, 6, session.LastPromptNumber)
|
||||
}
|
||||
|
||||
// TestActiveSessionNotifyChannel tests session notification channel.
|
||||
func TestActiveSessionNotifyChannel(t *testing.T) {
|
||||
session := &ActiveSession{
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
// Non-blocking send
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
// Success
|
||||
default:
|
||||
t.Error("Should accept first notification")
|
||||
}
|
||||
|
||||
// Second send should not block
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
// Full buffer
|
||||
default:
|
||||
// Expected - buffer is full
|
||||
}
|
||||
|
||||
// Drain
|
||||
select {
|
||||
case <-session.notify:
|
||||
// Drained
|
||||
default:
|
||||
t.Error("Should receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMessageMutex tests message mutex operations.
|
||||
func TestMessageMutex(t *testing.T) {
|
||||
session := &ActiveSession{
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent message operations
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
session.messageMu.Lock()
|
||||
session.pendingMessages = append(session.pendingMessages, PendingMessage{
|
||||
Type: MessageTypeObservation,
|
||||
})
|
||||
session.messageMu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, session.pendingMessages, 50)
|
||||
}
|
||||
|
||||
// TestQueueDepthMultipleSessions tests queue depth with multiple sessions.
|
||||
func (s *ManagerSuite) TestQueueDepthMultipleSessions() {
|
||||
// Add sessions with varying queue depths
|
||||
s.manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 10),
|
||||
}
|
||||
s.manager.sessions[2] = &ActiveSession{
|
||||
SessionDBID: 2,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
}
|
||||
s.manager.sessions[3] = &ActiveSession{
|
||||
SessionDBID: 3,
|
||||
pendingMessages: make([]PendingMessage, 5),
|
||||
}
|
||||
|
||||
s.Equal(15, s.manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
// TestIsAnySessionProcessing_GeneratorOnly tests processing status with only generator active.
|
||||
func (s *ManagerSuite) TestIsAnySessionProcessingGeneratorOnly() {
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: []PendingMessage{},
|
||||
}
|
||||
s.manager.sessions[1] = session
|
||||
|
||||
// No processing initially
|
||||
s.False(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Set generator active
|
||||
session.generatorActive.Store(true)
|
||||
s.True(s.manager.IsAnySessionProcessing())
|
||||
|
||||
// Clear generator
|
||||
session.generatorActive.Store(false)
|
||||
s.False(s.manager.IsAnySessionProcessing())
|
||||
}
|
||||
|
||||
// TestPendingMessageWithBothTypes tests pending messages with both types.
|
||||
func TestPendingMessageWithBothTypes(t *testing.T) {
|
||||
messages := []PendingMessage{
|
||||
{
|
||||
Type: MessageTypeObservation,
|
||||
Observation: &ObservationData{ToolName: "Read"},
|
||||
},
|
||||
{
|
||||
Type: MessageTypeSummarize,
|
||||
Summarize: &SummarizeData{LastUserMessage: "Test"},
|
||||
},
|
||||
{
|
||||
Type: MessageTypeObservation,
|
||||
Observation: &ObservationData{ToolName: "Write"},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Len(t, messages, 3)
|
||||
|
||||
// Verify types
|
||||
assert.Equal(t, MessageTypeObservation, messages[0].Type)
|
||||
assert.Equal(t, MessageTypeSummarize, messages[1].Type)
|
||||
assert.Equal(t, MessageTypeObservation, messages[2].Type)
|
||||
|
||||
// Verify data
|
||||
assert.Equal(t, "Read", messages[0].Observation.ToolName)
|
||||
assert.Nil(t, messages[0].Summarize)
|
||||
|
||||
assert.Equal(t, "Test", messages[1].Summarize.LastUserMessage)
|
||||
assert.Nil(t, messages[1].Observation)
|
||||
|
||||
assert.Equal(t, "Write", messages[2].Observation.ToolName)
|
||||
}
|
||||
|
||||
// TestDrainMessagesPreservesOrder tests that draining preserves message order.
|
||||
func (s *ManagerSuite) TestDrainMessagesPreservesOrder() {
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: []PendingMessage{
|
||||
{Type: MessageTypeObservation, Observation: &ObservationData{ToolName: "Tool1"}},
|
||||
{Type: MessageTypeSummarize, Summarize: &SummarizeData{LastUserMessage: "Msg1"}},
|
||||
{Type: MessageTypeObservation, Observation: &ObservationData{ToolName: "Tool2"}},
|
||||
},
|
||||
}
|
||||
s.manager.sessions[1] = session
|
||||
|
||||
messages := s.manager.DrainMessages(1)
|
||||
|
||||
s.Len(messages, 3)
|
||||
s.Equal("Tool1", messages[0].Observation.ToolName)
|
||||
s.Equal("Msg1", messages[1].Summarize.LastUserMessage)
|
||||
s.Equal("Tool2", messages[2].Observation.ToolName)
|
||||
}
|
||||
|
||||
// TestActiveSessionCWD tests CWD field in ObservationData.
|
||||
func TestActiveSessionCWD(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
}{
|
||||
{"empty_cwd", ""},
|
||||
{"absolute_path", "/home/user/project"},
|
||||
{"windows_path", "C:\\Users\\test\\project"},
|
||||
{"path_with_spaces", "/home/user/my project"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data := ObservationData{
|
||||
ToolName: "Test",
|
||||
CWD: tt.cwd,
|
||||
}
|
||||
assert.Equal(t, tt.cwd, data.CWD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolInputResponse tests various tool input/response types.
|
||||
func TestToolInputResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
response interface{}
|
||||
}{
|
||||
{"nil_values", nil, nil},
|
||||
{"string_values", "input string", "response string"},
|
||||
{"map_values", map[string]string{"key": "value"}, map[string]interface{}{"result": true}},
|
||||
{"slice_values", []string{"a", "b"}, []int{1, 2, 3}},
|
||||
{"int_values", 42, 100},
|
||||
{"bool_values", true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data := ObservationData{
|
||||
ToolName: "TestTool",
|
||||
ToolInput: tt.input,
|
||||
ToolResponse: tt.response,
|
||||
}
|
||||
assert.Equal(t, tt.input, data.ToolInput)
|
||||
assert.Equal(t, tt.response, data.ToolResponse)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
// Package sse provides Server-Sent Events broadcasting for claude-mnemonic.
|
||||
package sse
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// BroadcasterSuite is a test suite for Broadcaster operations.
|
||||
type BroadcasterSuite struct {
|
||||
suite.Suite
|
||||
broadcaster *Broadcaster
|
||||
}
|
||||
|
||||
func (s *BroadcasterSuite) SetupTest() {
|
||||
s.broadcaster = NewBroadcaster()
|
||||
}
|
||||
|
||||
func TestBroadcasterSuite(t *testing.T) {
|
||||
suite.Run(t, new(BroadcasterSuite))
|
||||
}
|
||||
|
||||
// TestNewBroadcaster tests broadcaster creation.
|
||||
func (s *BroadcasterSuite) TestNewBroadcaster() {
|
||||
b := NewBroadcaster()
|
||||
s.NotNil(b)
|
||||
s.NotNil(b.clients)
|
||||
s.Equal(0, b.ClientCount())
|
||||
}
|
||||
|
||||
// TestClientCount tests client counting.
|
||||
func (s *BroadcasterSuite) TestClientCount() {
|
||||
s.Equal(0, s.broadcaster.ClientCount())
|
||||
}
|
||||
|
||||
// mockResponseWriter implements http.ResponseWriter and http.Flusher for testing.
|
||||
type mockResponseWriter struct {
|
||||
header http.Header
|
||||
body []byte
|
||||
statusCode int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockResponseWriter() *mockResponseWriter {
|
||||
return &mockResponseWriter{
|
||||
header: make(http.Header),
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Header() http.Header {
|
||||
return m.header
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Write(data []byte) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.body = append(m.body, data...)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) WriteHeader(statusCode int) {
|
||||
m.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Flush() {
|
||||
// No-op for testing
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) GetBody() []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.body
|
||||
}
|
||||
|
||||
// TestAddClient tests adding clients.
|
||||
func (s *BroadcasterSuite) TestAddClient() {
|
||||
w := newMockResponseWriter()
|
||||
|
||||
client, err := s.broadcaster.AddClient(w)
|
||||
s.NoError(err)
|
||||
s.NotNil(client)
|
||||
s.NotEmpty(client.ID)
|
||||
s.NotNil(client.Done)
|
||||
s.Equal(1, s.broadcaster.ClientCount())
|
||||
}
|
||||
|
||||
// TestAddMultipleClients tests adding multiple clients.
|
||||
func (s *BroadcasterSuite) TestAddMultipleClients() {
|
||||
for i := 0; i < 5; i++ {
|
||||
w := newMockResponseWriter()
|
||||
_, err := s.broadcaster.AddClient(w)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
s.Equal(5, s.broadcaster.ClientCount())
|
||||
}
|
||||
|
||||
// TestRemoveClient tests removing clients.
|
||||
func (s *BroadcasterSuite) TestRemoveClient() {
|
||||
w := newMockResponseWriter()
|
||||
client, err := s.broadcaster.AddClient(w)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(1, s.broadcaster.ClientCount())
|
||||
|
||||
s.broadcaster.RemoveClient(client)
|
||||
|
||||
s.Equal(0, s.broadcaster.ClientCount())
|
||||
|
||||
// Check that Done channel is closed
|
||||
select {
|
||||
case <-client.Done:
|
||||
// Expected - channel is closed
|
||||
default:
|
||||
s.Fail("Done channel should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast tests broadcasting messages.
|
||||
func (s *BroadcasterSuite) TestBroadcast() {
|
||||
w := newMockResponseWriter()
|
||||
_, err := s.broadcaster.AddClient(w)
|
||||
s.NoError(err)
|
||||
|
||||
// Broadcast a message
|
||||
s.broadcaster.Broadcast(map[string]string{"type": "test", "message": "hello"})
|
||||
|
||||
// Give time for async write
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := string(w.GetBody())
|
||||
s.Contains(body, "data:")
|
||||
s.Contains(body, "test")
|
||||
s.Contains(body, "hello")
|
||||
}
|
||||
|
||||
// TestBroadcastNoClients tests broadcasting with no clients.
|
||||
func (s *BroadcasterSuite) TestBroadcastNoClients() {
|
||||
// Should not panic
|
||||
s.broadcaster.Broadcast(map[string]string{"type": "test"})
|
||||
}
|
||||
|
||||
// TestBroadcastMultipleClients tests broadcasting to multiple clients.
|
||||
func (s *BroadcasterSuite) TestBroadcastMultipleClients() {
|
||||
writers := make([]*mockResponseWriter, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
writers[i] = newMockResponseWriter()
|
||||
_, err := s.broadcaster.AddClient(writers[i])
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Broadcast
|
||||
s.broadcaster.Broadcast(map[string]string{"type": "test"})
|
||||
|
||||
// Give time for async writes
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// All clients should receive the message
|
||||
for i, w := range writers {
|
||||
body := string(w.GetBody())
|
||||
s.Contains(body, "data:", "Client %d should receive data", i)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClient tests Client structure.
|
||||
func TestClient(t *testing.T) {
|
||||
w := newMockResponseWriter()
|
||||
client := &Client{
|
||||
ID: "test-client",
|
||||
Writer: w,
|
||||
Flusher: w,
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-client", client.ID)
|
||||
assert.NotNil(t, client.Writer)
|
||||
assert.NotNil(t, client.Flusher)
|
||||
assert.NotNil(t, client.Done)
|
||||
|
||||
// Close done channel
|
||||
close(client.Done)
|
||||
|
||||
select {
|
||||
case <-client.Done:
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Done channel should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientUniqueIDs tests that clients get unique IDs.
|
||||
func TestClientUniqueIDs(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
ids := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
w := newMockResponseWriter()
|
||||
client, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ID should be unique
|
||||
assert.False(t, ids[client.ID], "ID %s should be unique", client.ID)
|
||||
ids[client.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteTimeout tests the write timeout constant.
|
||||
func TestWriteTimeout(t *testing.T) {
|
||||
assert.Equal(t, 2*time.Second, WriteTimeout)
|
||||
}
|
||||
|
||||
// TestHandleSSE tests the HandleSSE HTTP handler.
|
||||
func TestHandleSSE(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Create a test server
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set up context that will be cancelled
|
||||
ctx := r.Context()
|
||||
|
||||
// Start goroutine to cancel context after short delay
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Request will be cancelled by the test client
|
||||
}()
|
||||
|
||||
// This will block until context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
_ = handler
|
||||
_ = b
|
||||
|
||||
// Just verify the handler exists and broadcaster can handle SSE
|
||||
req := httptest.NewRequest(http.MethodGet, "/events", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Can't easily test HandleSSE since it blocks, but we can verify setup
|
||||
assert.NotNil(t, req)
|
||||
assert.NotNil(t, rec)
|
||||
}
|
||||
|
||||
// TestBroadcastJSON tests broadcasting various JSON types.
|
||||
func TestBroadcastJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string map",
|
||||
data: map[string]string{"key": "value"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "int map",
|
||||
data: map[string]int{"count": 42},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nested struct",
|
||||
data: struct{ Name string }{Name: "test"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
data: []string{"a", "b", "c"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "interface map",
|
||||
data: map[string]interface{}{"type": "test", "count": 1, "active": true},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
w := newMockResponseWriter()
|
||||
_, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not panic
|
||||
b.Broadcast(tt.data)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := string(w.GetBody())
|
||||
assert.Contains(t, body, "data:")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentBroadcast tests concurrent broadcasting.
|
||||
func TestConcurrentBroadcast(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Add clients
|
||||
for i := 0; i < 10; i++ {
|
||||
w := newMockResponseWriter()
|
||||
_, err := b.AddClient(w)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Broadcast concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
b.Broadcast(map[string]int{"index": i})
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without panics
|
||||
assert.Equal(t, 10, b.ClientCount())
|
||||
}
|
||||
|
||||
// TestRemoveNonExistentClient tests removing a non-existent client.
|
||||
func TestRemoveNonExistentClient(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
|
||||
// Create a client but don't add it
|
||||
client := &Client{
|
||||
ID: "fake-client",
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
b.RemoveClient(client)
|
||||
|
||||
// Done channel should be closed
|
||||
select {
|
||||
case <-client.Done:
|
||||
// Expected
|
||||
default:
|
||||
t.Error("Done channel should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcasterConcurrentAddRemove tests concurrent add/remove operations.
|
||||
func TestBroadcasterConcurrentAddRemove(t *testing.T) {
|
||||
b := NewBroadcaster()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent adds
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := newMockResponseWriter()
|
||||
client, err := b.AddClient(w)
|
||||
if err == nil {
|
||||
// Random chance to remove
|
||||
if time.Now().UnixNano()%2 == 0 {
|
||||
b.RemoveClient(client)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic and have some clients
|
||||
count := b.ClientCount()
|
||||
assert.GreaterOrEqual(t, count, 0)
|
||||
}
|
||||
Reference in New Issue
Block a user