Files
claude-mnemonic/internal/worker/handlers_test.go
T
2025-12-19 02:17:02 +00:00

554 lines
17 KiB
Go

// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testService creates a Service with a test SQLite database including FTS5 for testing.
func testService(t *testing.T) (*Service, func()) {
t.Helper()
// Create test store (runs migrations to create all tables including FTS5)
store, dbCleanup := testStore(t)
// Create store wrappers
sessionStore := sqlite.NewSessionStore(store)
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
// Create domain services
sessionManager := session.NewManager(sessionStore)
sseBroadcaster := sse.NewBroadcaster()
// Create context
ctx, cancel := context.WithCancel(context.Background())
// Create router
router := chi.NewRouter()
svc := &Service{
version: "test-version",
config: config.Get(),
store: store,
sessionStore: sessionStore,
observationStore: observationStore,
summaryStore: summaryStore,
promptStore: promptStore,
sessionManager: sessionManager,
sseBroadcaster: sseBroadcaster,
router: router,
ctx: ctx,
cancel: cancel,
startTime: time.Now(),
}
svc.setupRoutes()
// Mark service as ready for tests
svc.ready.Store(true)
cleanup := func() {
cancel()
store.Close()
dbCleanup()
}
return svc, cleanup
}
// createTestObservation creates a test observation in the database.
func createTestObservation(t *testing.T, store *sqlite.ObservationStore, project, title, narrative string, concepts []string) int64 {
t.Helper()
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: title,
Narrative: narrative,
Concepts: concepts,
}
id, _, err := store.StoreObservation(context.Background(), "test-session", project, obs, 1, 100)
require.NoError(t, err)
return id
}
func TestHandleSearchByPrompt_DefaultLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create 60 observations (more than the default limit of 50)
for i := 0; i < 60; i++ {
createTestObservation(t, svc.observationStore, project,
"Test observation about authentication",
"This observation is about authentication and security patterns",
[]string{"authentication", "security"})
// Small delay to ensure different timestamps
time.Sleep(time.Millisecond)
}
// Make request without limit parameter
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=authentication", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "observations should be an array")
// The default limit is now 50, not 5
// Note: clustering may reduce the count, but we should have more than 5
t.Logf("Got %d observations", len(observations))
// Just verify we got a reasonable number, accounting for clustering
assert.True(t, len(observations) >= 1, "should return at least one observation")
}
func TestHandleSearchByPrompt_CustomLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create 20 unique observations
for i := 0; i < 20; i++ {
createTestObservation(t, svc.observationStore, project,
"Unique observation "+string(rune('A'+i))+" about testing",
"This is unique observation number "+string(rune('A'+i)),
[]string{"unique-" + string(rune('a'+i))})
time.Sleep(time.Millisecond)
}
// Request with custom limit of 15
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=observation&limit=15", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// Should respect the custom limit (accounting for clustering)
t.Logf("Got %d observations with limit=15", len(observations))
assert.LessOrEqual(t, len(observations), 15)
}
func TestHandleSearchByPrompt_NoHardcodedLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create observations with VERY different content to avoid clustering
// Each has unique words that won't match other observations
uniqueObservations := []struct {
title string
narrative string
concepts []string
}{
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
}
for _, obs := range uniqueObservations {
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
time.Sleep(time.Millisecond)
}
// Search using a common keyword that should match most observations
// Using broader query to match multiple items
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=tokens+indexes+caching+logging&limit=10", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// The key is that the limit is no longer hardcoded to 5
// With our new default of 50, we should be able to return more than 5
t.Logf("Got %d observations (limit=10)", len(observations))
// The test passes as long as the default limit (50) is being used instead of 5
// and we can request a custom limit
assert.LessOrEqual(t, len(observations), 10, "should respect the custom limit")
}
func TestHandleSearchByPrompt_RequiredParams(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
tests := []struct {
name string
query string
wantStatus int
}{
{
name: "missing project",
query: "/api/context/search?query=test",
wantStatus: http.StatusBadRequest,
},
{
name: "missing query",
query: "/api/context/search?project=test",
wantStatus: http.StatusBadRequest,
},
{
name: "both present",
query: "/api/context/search?project=test&query=test",
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tt.query, nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, tt.wantStatus, rec.Code)
})
}
}
func TestHandleContextInject_NoHardcodedLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Set a higher context observations limit in config
svc.config.ContextObservations = 50
project := "test-project"
// Create observations with VERY different content to avoid clustering
uniqueObservations := []struct {
title string
narrative string
concepts []string
}{
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
}
for _, obs := range uniqueObservations {
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
time.Sleep(time.Millisecond)
}
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// With very different content, we should get multiple observations back
// The key verification is that the hardcoded limit of 5 has been removed
t.Logf("Got %d observations from context inject", len(observations))
// Should return more than old limit of 5 with unique observations
assert.GreaterOrEqual(t, len(observations), 1, "should return at least 1 observation")
}
func TestHandleContextInject_RequiresProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/context/inject", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestHandleGetObservations_Limit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create 20 observations
for i := 0; i < 20; i++ {
createTestObservation(t, svc.observationStore, "project-"+string(rune('a'+i%5)),
"Observation "+string(rune('A'+i)),
"Content of observation "+string(rune('A'+i)),
[]string{"test"})
time.Sleep(time.Millisecond)
}
// Request with limit=10
req := httptest.NewRequest(http.MethodGet, "/api/observations?limit=10", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Parse as generic JSON array since the model uses custom marshaling
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
require.NoError(t, err)
assert.Len(t, observations, 10)
}
func TestSearchObservations_GlobalScope(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a project-scoped observation
createTestObservation(t, svc.observationStore, "project-a",
"Project specific code",
"This is specific to project-a",
[]string{"project-specific"})
// Create a global-scoped observation (has a globalizable concept)
createTestObservation(t, svc.observationStore, "project-a",
"Security best practice",
"Always validate user input",
[]string{"security", "best-practice"})
// Search from project-b - should find global observation
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=project-b&query=security", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// Should find the global observation even though it was created in project-a
assert.GreaterOrEqual(t, len(observations), 1)
}
func TestClusterObservations_RemovesDuplicates(t *testing.T) {
// Create similar observations
obs1 := &models.Observation{
ID: 1,
Title: sql.NullString{String: "Authentication flow implementation", Valid: true},
Narrative: sql.NullString{String: "We implemented JWT-based authentication", Valid: true},
}
obs2 := &models.Observation{
ID: 2,
Title: sql.NullString{String: "Authentication flow update", Valid: true},
Narrative: sql.NullString{String: "Updated JWT-based authentication logic", Valid: true},
}
obs3 := &models.Observation{
ID: 3,
Title: sql.NullString{String: "Database migration guide", Valid: true},
Narrative: sql.NullString{String: "How to run database migrations", Valid: true},
}
observations := []*models.Observation{obs1, obs2, obs3}
// Cluster with 0.4 threshold
clustered := clusterObservations(observations, 0.4)
// obs1 and obs2 should be clustered together, obs3 is different
assert.LessOrEqual(t, len(clustered), 3)
assert.GreaterOrEqual(t, len(clustered), 1)
// The first observation in each cluster should be kept (obs1, obs3)
t.Logf("Clustered %d observations down to %d", len(observations), len(clustered))
}
func TestRetrievalStats(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
createTestObservation(t, svc.observationStore, project,
"Test observation",
"Test narrative",
[]string{"test"})
// Make a search request
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Check stats
stats := svc.GetRetrievalStats()
assert.Equal(t, int64(1), stats.TotalRequests)
assert.Equal(t, int64(1), stats.SearchRequests)
assert.GreaterOrEqual(t, stats.ObservationsServed, int64(1))
}
func TestHandleHealth_ReturnsVersion(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.version = "test-version-1.2.3"
svc.ready.Store(true)
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
rec := httptest.NewRecorder()
svc.handleHealth(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "ready", response["status"])
assert.Equal(t, "test-version-1.2.3", response["version"])
}
func TestHandleVersion(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.version = "v2.0.0-beta"
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
rec := httptest.NewRecorder()
svc.handleVersion(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "v2.0.0-beta", response["version"])
}
func TestHandleReady_ServiceNotReady(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Reset ready state to simulate service not being ready
svc.ready.Store(false)
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
rec := httptest.NewRecorder()
svc.handleReady(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
}
func TestHandleReady_ServiceReady(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(true)
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
rec := httptest.NewRecorder()
svc.handleReady(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "ready", response["status"])
}
func TestRequireReadyMiddleware_Blocks(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Reset ready state to simulate service not being ready
svc.ready.Store(false)
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
}
func TestRequireReadyMiddleware_Allows(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(true)
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "success", rec.Body.String())
}