Increase test coverage to 45.6%

This commit is contained in:
2025-12-17 12:39:47 +00:00
parent 4add030bed
commit c259bb1d18
13 changed files with 5484 additions and 0 deletions
+254
View File
@@ -0,0 +1,254 @@
package sqlite
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNullString(t *testing.T) {
tests := []struct {
name string
input string
expected string
valid bool
}{
{"empty_string", "", "", false},
{"non_empty_string", "hello", "hello", true},
{"whitespace", " ", " ", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := nullString(tt.input)
assert.Equal(t, tt.expected, result.String)
assert.Equal(t, tt.valid, result.Valid)
})
}
}
func TestNullInt(t *testing.T) {
tests := []struct {
name string
input int
expected int64
valid bool
}{
{"zero", 0, 0, false},
{"positive", 42, 42, true},
{"negative", -1, -1, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := nullInt(tt.input)
assert.Equal(t, tt.expected, result.Int64)
assert.Equal(t, tt.valid, result.Valid)
})
}
}
func TestRepeatPlaceholders(t *testing.T) {
tests := []struct {
name string
n int
expected string
}{
{"zero", 0, ""},
{"negative", -1, ""},
{"one", 1, ", ?"},
{"two", 2, ", ?, ?"},
{"three", 3, ", ?, ?, ?"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := repeatPlaceholders(tt.n)
assert.Equal(t, tt.expected, result)
})
}
}
func TestInt64SliceToInterface(t *testing.T) {
tests := []struct {
name string
input []int64
expected []interface{}
}{
{"empty", []int64{}, []interface{}{}},
{"single", []int64{42}, []interface{}{int64(42)}},
{"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := int64SliceToInterface(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseLimitParam(t *testing.T) {
tests := []struct {
name string
query string
defaultLimit int
expected int
}{
{"no_param_uses_default", "", 10, 10},
{"valid_limit", "limit=20", 10, 20},
{"invalid_limit_uses_default", "limit=abc", 10, 10},
{"zero_limit_uses_default", "limit=0", 10, 10},
{"negative_limit_uses_default", "limit=-5", 10, 10},
{"large_limit", "limit=1000", 10, 1000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/?"+tt.query, nil)
result := ParseLimitParam(req, tt.defaultLimit)
assert.Equal(t, tt.expected, result)
})
}
}
func TestScanSummary(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert a test summary
_, err := db.Exec(`
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000)
`)
require.NoError(t, err)
// Query and scan
row := db.QueryRow(`
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries WHERE sdk_session_id = ?
`, "sdk-123")
summary, err := scanSummary(row)
require.NoError(t, err)
assert.NotNil(t, summary)
assert.Equal(t, "sdk-123", summary.SDKSessionID)
assert.Equal(t, "test-project", summary.Project)
assert.Equal(t, "test request", summary.Request.String)
assert.Equal(t, "test investigated", summary.Investigated.String)
assert.Equal(t, "test learned", summary.Learned.String)
assert.Equal(t, "test completed", summary.Completed.String)
assert.Equal(t, "test next steps", summary.NextSteps.String)
assert.Equal(t, "test notes", summary.Notes.String)
}
func TestScanSummaryRows(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert multiple summaries
_, err := db.Exec(`
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES
('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000),
('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000)
`)
require.NoError(t, err)
rows, err := db.Query(`
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries WHERE sdk_session_id = ? ORDER BY id
`, "sdk-123")
require.NoError(t, err)
defer rows.Close()
summaries, err := scanSummaryRows(rows)
require.NoError(t, err)
assert.Len(t, summaries, 2)
assert.Equal(t, "request 1", summaries[0].Request.String)
assert.Equal(t, "request 2", summaries[1].Request.String)
}
func TestScanPromptWithSession(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert a test prompt
_, err := db.Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000)
`)
require.NoError(t, err)
// Query with session join
row := db.QueryRow(`
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
FROM user_prompts p
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
WHERE p.claude_session_id = ?
`, "claude-123")
prompt, err := scanPromptWithSession(row)
require.NoError(t, err)
assert.NotNil(t, prompt)
assert.Equal(t, "claude-123", prompt.ClaudeSessionID)
assert.Equal(t, 1, prompt.PromptNumber)
assert.Equal(t, "test prompt", prompt.PromptText)
assert.Equal(t, 5, prompt.MatchedObservations)
assert.Equal(t, "test-project", prompt.Project)
assert.Equal(t, "sdk-123", prompt.SDKSessionID)
}
func TestScanPromptWithSessionRows(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert multiple prompts
_, err := db.Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
VALUES
('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000),
('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000)
`)
require.NoError(t, err)
rows, err := db.Query(`
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
FROM user_prompts p
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
WHERE p.claude_session_id = ? ORDER BY p.id
`, "claude-123")
require.NoError(t, err)
defer rows.Close()
prompts, err := scanPromptWithSessionRows(rows)
require.NoError(t, err)
assert.Len(t, prompts, 2)
assert.Equal(t, "prompt one", prompts[0].PromptText)
assert.Equal(t, "prompt two", prompts[1].PromptText)
}
func TestParseLimitParam_HTTPRequest(t *testing.T) {
// Test with an actual HTTP request
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limit := ParseLimitParam(r, 25)
if limit != 50 {
t.Errorf("Expected limit 50, got %d", limit)
}
})
req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
+196
View File
@@ -0,0 +1,196 @@
package sqlite
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMigrationManager(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
require.NotNil(t, manager)
assert.Equal(t, db, manager.db)
}
func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
// Should create table without error
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Table should exist
var count int
err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 0, count) // Empty table
// Calling again should not error (IF NOT EXISTS)
err = manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
}
func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.Empty(t, versions)
}
func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Insert some versions
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')")
require.NoError(t, err)
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')")
require.NoError(t, err)
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.Len(t, versions, 2)
assert.True(t, versions[1])
assert.True(t, versions[2])
assert.False(t, versions[3]) // Not applied
}
func TestMigrationManager_ApplyMigration(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Apply a simple migration
migration := Migration{
Version: 100,
Name: "test_migration",
SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)",
}
err = manager.ApplyMigration(migration)
require.NoError(t, err)
// Verify table was created
var count int
err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
// Verify migration was recorded
var version int
err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version)
require.NoError(t, err)
assert.Equal(t, 100, version)
}
func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Try to apply invalid migration
migration := Migration{
Version: 100,
Name: "invalid_migration",
SQL: "INVALID SQL SYNTAX",
}
err = manager.ApplyMigration(migration)
assert.Error(t, err)
assert.Contains(t, err.Error(), "execute migration 100")
}
func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
// Create a test migration manager with a subset of migrations
manager := NewMigrationManager(db)
// First ensure schema versions table exists
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Apply first migration manually
err = manager.ApplyMigration(Migrations[0])
require.NoError(t, err)
// Verify the first migration version was recorded
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.True(t, versions[Migrations[0].Version])
}
func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Mark some migrations as already applied
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')")
require.NoError(t, err)
// Get applied versions
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.True(t, versions[4])
}
func TestMigration_Struct(t *testing.T) {
m := Migration{
Version: 1,
Name: "test",
SQL: "SELECT 1",
}
assert.Equal(t, 1, m.Version)
assert.Equal(t, "test", m.Name)
assert.Equal(t, "SELECT 1", m.SQL)
}
func TestMigrations_List(t *testing.T) {
// Verify migrations are ordered correctly
assert.NotEmpty(t, Migrations)
// Verify all migrations have required fields
for i, m := range Migrations {
assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i)
assert.NotEmpty(t, m.Name, "Migration %d has empty name", i)
assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i)
}
// Verify key migrations exist
versionSet := make(map[int]bool)
for _, m := range Migrations {
versionSet[m.Version] = true
}
assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration")
assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration")
}
+145
View File
@@ -800,3 +800,148 @@ func TestExtractKeywords(t *testing.T) {
})
}
}
func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create project-scoped observation for project-a
projectObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project A specific",
Narrative: "Only for project-a",
Concepts: []string{"local-concept"},
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
require.NoError(t, err)
// Create global observation from project-a
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Global security practice",
Narrative: "Best practice for all",
Concepts: []string{"security", "best-practice"},
}
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100)
require.NoError(t, err)
// Create observation for project-b
projectBObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project B specific",
Narrative: "Only for project-b",
}
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100)
require.NoError(t, err)
// GetObservationsByProjectStrict for project-a should only return project-a observations
// This is different from GetRecentObservations which includes globals from other projects
results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, results, 2) // Only observations created in project-a
// Verify both are from project-a
for _, obs := range results {
assert.Equal(t, "project-a", obs.Project)
}
// GetObservationsByProjectStrict for project-b should only return project-b observations
results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "Project B specific", results[0].Title.String)
}
func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create an observation
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test observation",
Narrative: "Some content here",
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
// Search with only stop words (should return nil)
results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10)
require.NoError(t, err)
assert.Nil(t, results)
// Search with empty query
results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10)
require.NoError(t, err)
assert.Nil(t, results)
}
func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations
for i := 0; i < 15; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Authentication test " + string(rune('A'+i)),
Narrative: "Auth related content",
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Search with limit 0 (should default to 10)
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 10)
// Search with negative limit (should default to 10)
results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 10)
}
func TestObservationStore_GetAllRecentObservations(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations across different projects
projects := []string{"project-a", "project-b", "project-c"}
for _, proj := range projects {
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: proj + " observation " + string(rune('A'+i)),
Narrative: "Content for " + proj,
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
}
// Get all recent observations
results, err := obsStore.GetAllRecentObservations(ctx, 100)
require.NoError(t, err)
assert.Len(t, results, 9) // 3 projects * 3 observations
// Verify they are in descending order by epoch
for i := 1; i < len(results); i++ {
assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch)
}
// Test with limit
results, err = obsStore.GetAllRecentObservations(ctx, 5)
require.NoError(t, err)
assert.Len(t, results, 5)
}
+332
View File
@@ -0,0 +1,332 @@
package embedding
import (
"math"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestEmbeddingDim verifies the embedding dimension constant.
func TestEmbeddingDim(t *testing.T) {
assert.Equal(t, 384, EmbeddingDim)
}
// TestNewService tests creating a new embedding service.
func TestNewService(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
require.NotNil(t, svc)
defer svc.Close()
assert.NotNil(t, svc.tk)
assert.NotNil(t, svc.session)
}
// TestEmbed_SingleText tests embedding a single text.
func TestEmbed_SingleText(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
embedding, err := svc.Embed("Hello, world!")
require.NoError(t, err)
assert.Len(t, embedding, EmbeddingDim)
// Verify non-zero embedding
var sum float32
for _, v := range embedding {
sum += v * v
}
assert.Greater(t, sum, float32(0), "Embedding should not be all zeros")
}
// TestEmbed_EmptyText tests embedding an empty string.
func TestEmbed_EmptyText(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
embedding, err := svc.Embed("")
require.NoError(t, err)
assert.Len(t, embedding, EmbeddingDim)
// Empty text should return zero vector
for _, v := range embedding {
assert.Equal(t, float32(0), v)
}
}
// TestEmbed_SimilarTexts tests that similar texts produce similar embeddings.
func TestEmbed_SimilarTexts(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
emb1, err := svc.Embed("The quick brown fox jumps over the lazy dog.")
require.NoError(t, err)
emb2, err := svc.Embed("A fast brown fox leaps over a sleepy dog.")
require.NoError(t, err)
emb3, err := svc.Embed("Go programming language concurrency patterns.")
require.NoError(t, err)
// Calculate cosine similarity
sim12 := cosineSimilarity(emb1, emb2)
sim13 := cosineSimilarity(emb1, emb3)
// Similar texts should have higher similarity
assert.Greater(t, sim12, sim13, "Similar sentences should have higher similarity than dissimilar ones")
assert.Greater(t, sim12, float64(0.7), "Similar sentences should have high similarity")
}
// TestEmbedBatch_MultipleTexts tests batch embedding.
func TestEmbedBatch_MultipleTexts(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
texts := []string{
"First text about programming.",
"Second text about databases.",
"Third text about machine learning.",
}
embeddings, err := svc.EmbedBatch(texts)
require.NoError(t, err)
assert.Len(t, embeddings, len(texts))
for i, emb := range embeddings {
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
}
}
// TestEmbedBatch_EmptySlice tests batch embedding with empty slice.
func TestEmbedBatch_EmptySlice(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
embeddings, err := svc.EmbedBatch([]string{})
require.NoError(t, err)
assert.Nil(t, embeddings)
}
// TestEmbedBatch_WithEmptyTexts tests batch embedding with some empty texts.
func TestEmbedBatch_WithEmptyTexts(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
texts := []string{
"Valid text one.",
"",
"Valid text two.",
"",
}
embeddings, err := svc.EmbedBatch(texts)
require.NoError(t, err)
assert.Len(t, embeddings, 4)
// Non-empty texts should have non-zero embeddings
var sum0 float32
for _, v := range embeddings[0] {
sum0 += v * v
}
assert.Greater(t, sum0, float32(0))
// Empty texts should have zero embeddings
for _, v := range embeddings[1] {
assert.Equal(t, float32(0), v)
}
var sum2 float32
for _, v := range embeddings[2] {
sum2 += v * v
}
assert.Greater(t, sum2, float32(0))
for _, v := range embeddings[3] {
assert.Equal(t, float32(0), v)
}
}
// TestEmbedBatch_AllEmpty tests batch embedding with all empty texts.
func TestEmbedBatch_AllEmpty(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
texts := []string{"", "", ""}
embeddings, err := svc.EmbedBatch(texts)
require.NoError(t, err)
assert.Len(t, embeddings, 3)
for i, emb := range embeddings {
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
for j, v := range emb {
assert.Equal(t, float32(0), v, "Embedding %d[%d] should be zero", i, j)
}
}
}
// TestEmbed_Concurrent tests concurrent embedding calls.
func TestEmbed_Concurrent(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
var wg sync.WaitGroup
numGoroutines := 10
errors := make(chan error, numGoroutines)
embeddings := make([][]float32, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
text := "Test text for concurrent embedding test"
emb, err := svc.Embed(text)
if err != nil {
errors <- err
return
}
embeddings[idx] = emb
}(i)
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("Concurrent embedding error: %v", err)
}
// All embeddings should be valid
for i, emb := range embeddings {
if emb != nil {
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
}
}
}
// TestEmbed_SpecialCharacters tests embedding text with special characters.
func TestEmbed_SpecialCharacters(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
texts := []string{
"Text with unicode: 你好世界 🎉",
"Text with newlines:\nLine 1\nLine 2",
"Text with tabs:\tColumn1\tColumn2",
"Text with quotes: \"quoted\" and 'single'",
"Text with code: func main() { fmt.Println(\"hello\") }",
}
for _, text := range texts {
t.Run(text[:20], func(t *testing.T) {
emb, err := svc.Embed(text)
require.NoError(t, err)
assert.Len(t, emb, EmbeddingDim)
})
}
}
// TestEmbed_LongText tests embedding long text.
func TestEmbed_LongText(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
// Create a long text (tokenizer should truncate appropriately)
longText := ""
for i := 0; i < 100; i++ {
longText += "This is a sentence to make the text very long. "
}
emb, err := svc.Embed(longText)
require.NoError(t, err)
assert.Len(t, emb, EmbeddingDim)
}
// TestClose tests closing the service.
func TestClose(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
err = svc.Close()
require.NoError(t, err)
// Session should be nil after close
assert.Nil(t, svc.session)
}
// TestEmbedBatch_SingleItem tests batch embedding with single item.
func TestEmbedBatch_SingleItem(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
texts := []string{"Single text for batch embedding."}
embeddings, err := svc.EmbedBatch(texts)
require.NoError(t, err)
assert.Len(t, embeddings, 1)
assert.Len(t, embeddings[0], EmbeddingDim)
}
// TestEmbed_Deterministic tests that embedding is deterministic.
func TestEmbed_Deterministic(t *testing.T) {
svc, err := NewService()
require.NoError(t, err)
defer svc.Close()
text := "Test text for deterministic embedding."
emb1, err := svc.Embed(text)
require.NoError(t, err)
emb2, err := svc.Embed(text)
require.NoError(t, err)
// Same text should produce same embedding
for i := 0; i < EmbeddingDim; i++ {
assert.Equal(t, emb1[i], emb2[i], "Embedding should be deterministic at index %d", i)
}
}
// Helper function to calculate cosine similarity
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) {
return 0
}
var dotProduct float64
var normA float64
var normB float64
for i := range a {
dotProduct += float64(a[i]) * float64(b[i])
normA += float64(a[i]) * float64(a[i])
normB += float64(b[i]) * float64(b[i])
}
if normA == 0 || normB == 0 {
return 0
}
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
+360
View File
@@ -597,3 +597,363 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
assert.True(t, hasType, "tool %s schema should have type", tool.Name)
}
}
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
func TestHandleToolsCall_UnknownTool(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
req := &Request{
JSONRPC: "2.0",
ID: 1,
Method: "tools/call",
Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`),
}
resp := server.handleToolsCall(ctx, req)
require.NotNil(t, resp.Error)
assert.Equal(t, -32000, resp.Error.Code)
assert.Contains(t, resp.Error.Data, "unknown tool")
}
// TestCallTool_ToolNameRecognition tests that valid tool names are recognized (not "unknown tool").
func TestCallTool_ToolNameRecognition(t *testing.T) {
// Note: This test verifies tool routing logic, not execution (which requires searchMgr)
// All valid tool names should be in the handleToolsList response
server := NewServer(nil, "1.0.0")
req := &Request{
JSONRPC: "2.0",
ID: 1,
Method: "tools/list",
}
resp := server.handleToolsList(req)
result := resp.Result.(map[string]any)
tools := result["tools"].([]Tool)
// Verify all expected tools are registered
expectedTools := map[string]bool{
"search": true,
"timeline": true,
"decisions": true,
"changes": true,
"how_it_works": true,
"find_by_concept": true,
"find_by_file": true,
"find_by_type": true,
"get_recent_context": true,
"get_context_timeline": true,
"get_timeline_by_query": true,
}
foundTools := make(map[string]bool)
for _, tool := range tools {
foundTools[tool.Name] = true
}
for name := range expectedTools {
assert.True(t, foundTools[name], "tool %s should be registered", name)
}
}
// TestRun_MultipleRequests tests Run with multiple sequential requests.
func TestRun_MultipleRequests(t *testing.T) {
var stdout bytes.Buffer
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
stdin := strings.NewReader(req1 + "\n" + req2 + "\n")
server := &Server{
stdin: stdin,
stdout: &stdout,
version: "1.0.0",
}
err := server.Run(context.Background())
require.NoError(t, err)
output := stdout.String()
// Should contain responses for both requests
assert.Contains(t, output, `"id":1`)
assert.Contains(t, output, `"id":2`)
}
// TestHandleTimeline_Defaults tests timeline default values.
func TestHandleTimeline_Defaults(t *testing.T) {
// Test that handleTimeline sets default before/after values
params := TimelineParams{
AnchorID: 0,
Query: "",
Before: 0,
After: 0,
}
// Simulate the default value assignment from handleTimeline
if params.Before <= 0 {
params.Before = 10
}
if params.After <= 0 {
params.After = 10
}
assert.Equal(t, 10, params.Before)
assert.Equal(t, 10, params.After)
}
// TestTimelineParams_Complete tests complete TimelineParams parsing.
func TestTimelineParams_Complete(t *testing.T) {
input := `{
"anchor_id": 100,
"query": "test query",
"before": 5,
"after": 15,
"project": "my-project",
"obs_type": "bugfix",
"concepts": "security,auth",
"files": "main.go,handler.go",
"dateStart": 1700000000000,
"dateEnd": 1700100000000,
"format": "full"
}`
var params TimelineParams
err := json.Unmarshal([]byte(input), &params)
require.NoError(t, err)
assert.Equal(t, int64(100), params.AnchorID)
assert.Equal(t, "test query", params.Query)
assert.Equal(t, 5, params.Before)
assert.Equal(t, 15, params.After)
assert.Equal(t, "my-project", params.Project)
assert.Equal(t, "bugfix", params.ObsType)
assert.Equal(t, "security,auth", params.Concepts)
assert.Equal(t, "main.go,handler.go", params.Files)
assert.Equal(t, int64(1700000000000), params.DateStart)
assert.Equal(t, int64(1700100000000), params.DateEnd)
assert.Equal(t, "full", params.Format)
}
// TestServerStdinStdoutConfig tests that server stdin/stdout can be configured.
func TestServerStdinStdoutConfig(t *testing.T) {
var stdout bytes.Buffer
var stdin bytes.Buffer
server := &Server{
stdin: &stdin,
stdout: &stdout,
version: "test-version",
}
assert.Equal(t, &stdin, server.stdin)
assert.Equal(t, &stdout, server.stdout)
assert.Equal(t, "test-version", server.version)
}
// TestResponseIDTypes tests that response IDs can be various types.
func TestResponseIDTypes(t *testing.T) {
tests := []struct {
name string
id any
}{
{"integer id", 1},
{"string id", "abc-123"},
{"float id", 1.5},
{"null id", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
server := &Server{stdout: &buf}
resp := &Response{
JSONRPC: "2.0",
ID: tt.id,
Result: "ok",
}
server.sendResponse(resp)
output := buf.String()
assert.Contains(t, output, `"jsonrpc":"2.0"`)
})
}
}
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
// Empty query should error
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`))
require.Error(t, err)
assert.Contains(t, err.Error(), "query is required")
}
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
func TestHandleTimeline_InvalidJSON(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
_, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid timeline params")
}
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid timeline params")
}
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
// No anchor_id and no query should return empty result
result, err := server.handleTimeline(ctx, json.RawMessage(`{}`))
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result.Results)
}
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
func TestHandleTimeline_WithDefaults(t *testing.T) {
server := NewServer(nil, "1.0.0")
ctx := context.Background()
// With anchor_id but no before/after, defaults should be applied
// However, since searchMgr is nil, this will fail after defaults are applied
result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`))
// Should return empty result since anchor_id is 0
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result.Results)
}
// TestServerFields tests Server struct fields.
func TestServerFields(t *testing.T) {
server := NewServer(nil, "2.0.0")
assert.Equal(t, "2.0.0", server.version)
assert.Nil(t, server.searchMgr)
assert.NotNil(t, server.stdin)
assert.NotNil(t, server.stdout)
}
// TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID.
func TestRequestUnmarshalWithNullID(t *testing.T) {
input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}`
var req Request
err := json.Unmarshal([]byte(input), &req)
require.NoError(t, err)
assert.Equal(t, "2.0", req.JSONRPC)
assert.Nil(t, req.ID)
assert.Equal(t, "initialize", req.Method)
}
// TestResponseWithNullError tests Response without error.
func TestResponseWithNullError(t *testing.T) {
resp := Response{
JSONRPC: "2.0",
ID: 1,
Result: "success",
Error: nil,
}
data, err := json.Marshal(resp)
require.NoError(t, err)
assert.Contains(t, string(data), `"result":"success"`)
assert.NotContains(t, string(data), `"error"`)
}
// TestErrorWithNilData tests Error without data.
func TestErrorWithNilData(t *testing.T) {
err := Error{
Code: -32600,
Message: "Invalid Request",
Data: nil,
}
data, errMarshal := json.Marshal(err)
require.NoError(t, errMarshal)
assert.Contains(t, string(data), `"code":-32600`)
assert.Contains(t, string(data), `"message":"Invalid Request"`)
assert.NotContains(t, string(data), `"data"`)
}
// TestToolInputSchema tests that tool input schemas have required fields.
func TestToolInputSchema(t *testing.T) {
server := NewServer(nil, "1.0.0")
req := &Request{
JSONRPC: "2.0",
ID: 1,
Method: "tools/list",
}
resp := server.handleToolsList(req)
result := resp.Result.(map[string]any)
tools := result["tools"].([]Tool)
for _, tool := range tools {
schema := tool.InputSchema
schemaType, ok := schema["type"]
assert.True(t, ok, "tool %s schema should have type", tool.Name)
assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name)
// All tools should have properties
_, hasProperties := schema["properties"]
assert.True(t, hasProperties, "tool %s should have properties", tool.Name)
}
}
// TestRunMixedRequests tests Run with mixed valid and invalid requests.
func TestRunMixedRequests(t *testing.T) {
var stdout bytes.Buffer
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
req2 := `invalid json`
req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`
stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n")
server := &Server{
stdin: stdin,
stdout: &stdout,
version: "1.0.0",
}
err := server.Run(context.Background())
require.NoError(t, err)
output := stdout.String()
// Should have responses for all three requests
assert.Contains(t, output, `"id":1`)
assert.Contains(t, output, `"error"`) // Parse error for invalid json
assert.Contains(t, output, `"id":3`)
}
// TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments.
func TestToolCallParamsWithComplexArgs(t *testing.T) {
input := `{
"name": "search",
"arguments": {
"query": "authentication bug",
"project": "my-project",
"limit": 10,
"type": "observations"
}
}`
var params ToolCallParams
err := json.Unmarshal([]byte(input), &params)
require.NoError(t, err)
assert.Equal(t, "search", params.Name)
assert.NotEmpty(t, params.Arguments)
}
+646
View File
@@ -0,0 +1,646 @@
// Package search provides unified search capabilities for claude-mnemonic.
package search
import (
"context"
"database/sql"
"os"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
// Import sqlite driver
_ "github.com/mattn/go-sqlite3"
)
// Ensure context is used (for later tests)
var _ = context.Background
// hasFTS5 checks if FTS5 is available in the SQLite build.
func hasFTS5(t *testing.T) bool {
t.Helper()
tmpDir, err := os.MkdirTemp("", "fts5-check-*")
if err != nil {
return false
}
defer func() { _ = os.RemoveAll(tmpDir) }()
dbPath := tmpDir + "/check.db"
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return false
}
defer func() { _ = db.Close() }()
_, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
if err != nil {
return false
}
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
return true
}
// testStore creates a sqlite.Store with a temporary database for testing.
func testStore(t *testing.T) (*sqlite.Store, func()) {
t.Helper()
if !hasFTS5(t) {
t.Skip("FTS5 not available in this SQLite build")
}
tmpDir, err := os.MkdirTemp("", "search-integration-test-*")
require.NoError(t, err)
dbPath := tmpDir + "/test.db"
store, err := sqlite.NewStore(sqlite.StoreConfig{
Path: dbPath,
MaxConns: 1,
WALMode: true,
})
require.NoError(t, err)
cleanup := func() {
_ = store.Close()
_ = os.RemoveAll(tmpDir)
}
return store, cleanup
}
// SearchIntegrationSuite tests search with real SQLite stores.
type SearchIntegrationSuite struct {
suite.Suite
store *sqlite.Store
cleanup func()
manager *Manager
obsStore *sqlite.ObservationStore
sumStore *sqlite.SummaryStore
prmStore *sqlite.PromptStore
}
func (s *SearchIntegrationSuite) SetupTest() {
if !hasFTS5(s.T()) {
s.T().Skip("FTS5 not available in this SQLite build")
}
s.store, s.cleanup = testStore(s.T())
// Create real stores backed by SQLite
s.obsStore = sqlite.NewObservationStore(s.store)
s.sumStore = sqlite.NewSummaryStore(s.store)
s.prmStore = sqlite.NewPromptStore(s.store)
// Create search manager with real stores (no vector client for now)
s.manager = NewManager(s.obsStore, s.sumStore, s.prmStore, nil)
}
func (s *SearchIntegrationSuite) TearDownTest() {
if s.cleanup != nil {
s.cleanup()
}
}
func TestSearchIntegrationSuite(t *testing.T) {
suite.Run(t, new(SearchIntegrationSuite))
}
// seedObservations inserts test observations into the database.
func (s *SearchIntegrationSuite) seedObservations(ctx context.Context) []int64 {
var ids []int64
// Observation 1: Authentication bug fix
obs1 := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
Scope: models.ScopeProject,
Title: "Fixed Authentication Bug",
Narrative: "Resolved JWT token validation issue that caused intermittent login failures",
Concepts: []string{"authentication", "jwt", "security"},
FilesRead: []string{"auth/handler.go", "auth/jwt.go"},
}
id1, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs1, 1, 100)
s.Require().NoError(err)
ids = append(ids, id1)
// Observation 2: Database optimization decision
obs2 := &models.ParsedObservation{
Type: models.ObsTypeDecision,
Scope: models.ScopeProject,
Title: "Database Query Optimization Decision",
Narrative: "Decided to add indexes on user_id and created_at columns for better performance",
Concepts: []string{"database", "performance", "decision"},
FilesRead: []string{"db/migrations/001.sql"},
}
id2, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs2, 2, 150)
s.Require().NoError(err)
ids = append(ids, id2)
// Observation 3: Global best practice
obs3 := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Scope: models.ScopeGlobal,
Title: "Error Handling Best Practice",
Narrative: "Use wrapped errors with context for better debugging: errors.Wrap(err, context)",
Concepts: []string{"best-practice", "errors", "patterns"},
FilesRead: []string{"pkg/errors/errors.go"},
}
id3, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-2", "other-project", obs3, 1, 80)
s.Require().NoError(err)
ids = append(ids, id3)
// Observation 4: Code change/refactoring
obs4 := &models.ParsedObservation{
Type: models.ObsTypeChange,
Scope: models.ScopeProject,
Title: "Refactored User Service",
Narrative: "Changed user service to use repository pattern, modified interfaces for better testability",
Concepts: []string{"refactoring", "architecture"},
FilesModified: []string{"services/user.go", "services/user_test.go"},
}
id4, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs4, 3, 200)
s.Require().NoError(err)
ids = append(ids, id4)
return ids
}
// seedSummaries inserts test session summaries into the database.
func (s *SearchIntegrationSuite) seedSummaries(ctx context.Context) []int64 {
var ids []int64
// Summary 1
sum1 := &models.ParsedSummary{
Request: "Fix authentication bug",
Investigated: "JWT token validation and session handling",
Learned: "JWT validation requires algorithm check to prevent alg:none attacks",
Completed: "Fixed JWT validation, added tests",
NextSteps: "Review other security endpoints",
}
id1, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum1, 1, 100)
s.Require().NoError(err)
ids = append(ids, id1)
// Summary 2
sum2 := &models.ParsedSummary{
Request: "Optimize database queries",
Investigated: "Query execution plans and index usage",
Learned: "Composite indexes work better for range queries",
Completed: "Added indexes, verified performance improvement",
NextSteps: "Monitor query times in production",
}
id2, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum2, 2, 150)
s.Require().NoError(err)
ids = append(ids, id2)
return ids
}
// TestFilterSearch_WithRealStores tests filterSearch with seeded data.
func (s *SearchIntegrationSuite) TestFilterSearch_WithRealStores() {
ctx := context.Background()
// Seed test data
obsIDs := s.seedObservations(ctx)
sumIDs := s.seedSummaries(ctx)
s.Require().Len(obsIDs, 4)
s.Require().Len(sumIDs, 2)
// Test filter search for observations only
result, err := s.manager.filterSearch(ctx, SearchParams{
Project: "test-project",
Type: "observations",
Limit: 10,
Format: "full",
})
s.Require().NoError(err)
s.NotNil(result)
// Should return project observations + global observation (4 total: 3 project + 1 global)
s.GreaterOrEqual(len(result.Results), 3)
// Verify result types
for _, r := range result.Results {
s.Equal("observation", r.Type)
}
}
// TestFilterSearch_SessionsOnly tests filterSearch for sessions.
func (s *SearchIntegrationSuite) TestFilterSearch_SessionsOnly() {
ctx := context.Background()
// Seed test data
_ = s.seedObservations(ctx)
sumIDs := s.seedSummaries(ctx)
s.Require().Len(sumIDs, 2)
// Test filter search for sessions only
result, err := s.manager.filterSearch(ctx, SearchParams{
Project: "test-project",
Type: "sessions",
Limit: 10,
Format: "full",
})
s.Require().NoError(err)
s.NotNil(result)
// Should return 2 summaries
s.Len(result.Results, 2)
// Verify result types
for _, r := range result.Results {
s.Equal("session", r.Type)
s.NotEmpty(r.Title) // Title should be populated from Request
}
}
// TestFilterSearch_AllTypes tests filterSearch for all types.
func (s *SearchIntegrationSuite) TestFilterSearch_AllTypes() {
ctx := context.Background()
// Seed test data
obsIDs := s.seedObservations(ctx)
sumIDs := s.seedSummaries(ctx)
s.Require().Len(obsIDs, 4)
s.Require().Len(sumIDs, 2)
// Test filter search for all types (Type = "")
result, err := s.manager.filterSearch(ctx, SearchParams{
Project: "test-project",
Type: "", // All types
Limit: 20,
Format: "full",
})
s.Require().NoError(err)
s.NotNil(result)
// Should return both observations and sessions
hasObservations := false
hasSessions := false
for _, r := range result.Results {
if r.Type == "observation" {
hasObservations = true
}
if r.Type == "session" {
hasSessions = true
}
}
s.True(hasObservations, "Should have observation results")
s.True(hasSessions, "Should have session results")
}
// TestUnifiedSearch_DefaultLimit tests UnifiedSearch with default limit.
func (s *SearchIntegrationSuite) TestUnifiedSearch_DefaultLimit() {
ctx := context.Background()
// Seed test data
s.seedObservations(ctx)
s.seedSummaries(ctx)
// Test with no limit specified (should default to 20)
result, err := s.manager.UnifiedSearch(ctx, SearchParams{
Project: "test-project",
})
s.Require().NoError(err)
s.NotNil(result)
s.LessOrEqual(len(result.Results), 20)
}
// TestUnifiedSearch_LimitCapping tests UnifiedSearch limit capping.
func (s *SearchIntegrationSuite) TestUnifiedSearch_LimitCapping() {
ctx := context.Background()
// Seed test data
s.seedObservations(ctx)
s.seedSummaries(ctx)
// Test with limit > 100 (should be capped to 100)
result, err := s.manager.UnifiedSearch(ctx, SearchParams{
Project: "test-project",
Limit: 500,
})
s.Require().NoError(err)
s.NotNil(result)
s.LessOrEqual(len(result.Results), 100)
}
// TestDecisions_WithRealStores tests the Decisions method falls back to filterSearch.
func (s *SearchIntegrationSuite) TestDecisions_WithRealStores() {
ctx := context.Background()
// Seed test data
s.seedObservations(ctx)
// Test Decisions search (without vector client, falls back to filterSearch)
result, err := s.manager.Decisions(ctx, SearchParams{
Project: "test-project",
Query: "database",
Limit: 10,
})
s.Require().NoError(err)
s.NotNil(result)
// Without vector client, falls back to filterSearch which returns observations
// All results should be observations (type is forced to "observations" in Decisions)
for _, r := range result.Results {
s.Equal("observation", r.Type)
}
}
// TestChanges_WithRealStores tests the Changes method falls back to filterSearch.
func (s *SearchIntegrationSuite) TestChanges_WithRealStores() {
ctx := context.Background()
// Seed test data
s.seedObservations(ctx)
// Test Changes search (without vector client, falls back to filterSearch)
result, err := s.manager.Changes(ctx, SearchParams{
Project: "test-project",
Query: "user service",
Limit: 10,
})
s.Require().NoError(err)
s.NotNil(result)
// Without vector client, falls back to filterSearch which returns observations
// All results should be observations (type is forced to "observations" in Changes)
for _, r := range result.Results {
s.Equal("observation", r.Type)
}
}
// TestHowItWorks_WithRealStores tests the HowItWorks method falls back to filterSearch.
func (s *SearchIntegrationSuite) TestHowItWorks_WithRealStores() {
ctx := context.Background()
// Seed test data
s.seedObservations(ctx)
// Test HowItWorks search (without vector client, falls back to filterSearch)
result, err := s.manager.HowItWorks(ctx, SearchParams{
Project: "test-project",
Query: "authentication",
Limit: 10,
})
s.Require().NoError(err)
s.NotNil(result)
// Without vector client, falls back to filterSearch which returns observations
// All results should be observations (type is forced to "observations" in HowItWorks)
for _, r := range result.Results {
s.Equal("observation", r.Type)
}
}
// TestObservationToResult tests observation to result conversion with full format.
func (s *SearchIntegrationSuite) TestObservationToResult_FullFormat() {
ctx := context.Background()
// Insert single observation
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Scope: models.ScopeProject,
Title: "Test Title",
Narrative: "Detailed narrative content for testing",
Concepts: []string{"testing", "content"},
}
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
s.Require().NoError(err)
// Retrieve and convert
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
s.Require().NoError(err)
s.Require().NotNil(retrieved)
result := s.manager.observationToResult(retrieved, "full")
s.Equal("observation", result.Type)
s.Equal(id, result.ID)
s.Equal("Test Title", result.Title)
s.Equal("Detailed narrative content for testing", result.Content)
s.Equal("test-project", result.Project)
s.Equal("project", result.Scope)
s.NotNil(result.Metadata)
s.Equal("discovery", result.Metadata["obs_type"])
}
// TestObservationToResult_IndexFormat tests index format (no content).
func (s *SearchIntegrationSuite) TestObservationToResult_IndexFormat() {
ctx := context.Background()
obs := &models.ParsedObservation{
Type: models.ObsTypeBugfix,
Scope: models.ScopeGlobal,
Title: "Bug Fix Title",
Narrative: "This should not appear in index format",
}
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
s.Require().NoError(err)
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
s.Require().NoError(err)
result := s.manager.observationToResult(retrieved, "index")
s.Equal("observation", result.Type)
s.Equal("Bug Fix Title", result.Title)
s.Empty(result.Content, "Index format should not include content")
s.Equal("global", result.Scope)
}
// TestSummaryToResult_FullFormat tests summary to result conversion.
func (s *SearchIntegrationSuite) TestSummaryToResult_FullFormat() {
ctx := context.Background()
sum := &models.ParsedSummary{
Request: "Implement new feature",
Learned: "Learned important lessons about testing",
}
id, _, err := s.sumStore.StoreSummary(ctx, "sdk-test", "test-project", sum, 1, 50)
s.Require().NoError(err)
// Retrieve via GetRecentSummaries since there's no GetByID
summaries, err := s.sumStore.GetRecentSummaries(ctx, "test-project", 10)
s.Require().NoError(err)
s.Require().NotEmpty(summaries)
var retrieved *models.SessionSummary
for _, s := range summaries {
if s.ID == id {
retrieved = s
break
}
}
s.Require().NotNil(retrieved)
result := s.manager.summaryToResult(retrieved, "full")
s.Equal("session", result.Type)
s.Equal(id, result.ID)
s.Contains(result.Title, "Implement new feature")
s.Equal("Learned important lessons about testing", result.Content)
s.Equal("test-project", result.Project)
}
// TestPromptToResult_FullFormat tests prompt to result conversion.
func (s *SearchIntegrationSuite) TestPromptToResult_FullFormat() {
// First create a session
ctx := context.Background()
sessionStore := sqlite.NewSessionStore(s.store)
_, err := sessionStore.CreateSDKSession(ctx, "sdk-prompt-test", "test-project", "initial prompt")
s.Require().NoError(err)
// Save a user prompt
promptID, err := s.prmStore.SaveUserPromptWithMatches(ctx, "sdk-prompt-test", 1, "Help me fix this authentication bug", 3)
s.Require().NoError(err)
// Retrieve prompts
prompts, err := s.prmStore.GetPromptsByIDs(ctx, []int64{promptID}, "date_desc", 10)
s.Require().NoError(err)
s.Require().NotEmpty(prompts)
result := s.manager.promptToResult(prompts[0], "full")
s.Equal("prompt", result.Type)
s.Equal(promptID, result.ID)
s.Contains(result.Title, "Help me fix")
s.Equal("Help me fix this authentication bug", result.Content)
}
// TestTruncate_TableDriven tests truncation with various inputs.
func TestTruncate_TableDriven(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{"short_string", "hello", 10, "hello"},
{"exact_length", "hello", 5, "hello"},
{"long_string", "hello world", 5, "hello..."},
{"empty_string", "", 10, ""},
{"whitespace_only", " ", 10, ""},
{"with_leading_space", " hello ", 10, "hello"},
{"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncate(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
// TestManagerWithNilStores tests that Manager handles nil stores gracefully.
func TestManagerWithNilStores(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
assert.NotNil(t, m)
assert.Nil(t, m.observationStore)
assert.Nil(t, m.summaryStore)
assert.Nil(t, m.promptStore)
assert.Nil(t, m.vectorClient)
}
// TestSearchResultMetadataFields tests all metadata fields with real data.
func (s *SearchIntegrationSuite) TestSearchResultMetadataFields() {
ctx := context.Background()
obs := &models.ParsedObservation{
Type: models.ObsTypeDecision,
Scope: models.ScopeGlobal,
Title: "Architecture Decision",
Concepts: []string{"auth", "security"},
FilesRead: []string{"handler.go", "auth.go"},
}
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-meta-test", "test-project", obs, 1, 50)
s.Require().NoError(err)
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
s.Require().NoError(err)
result := s.manager.observationToResult(retrieved, "full")
// Check metadata fields
s.NotNil(result.Metadata)
s.Equal("decision", result.Metadata["obs_type"])
s.Equal("global", result.Metadata["scope"])
s.Equal("global", result.Scope)
}
// TestObservationToResult_AllFormats tests different format options.
func TestObservationToResult_AllFormats(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
obs := &models.Observation{
ID: 1,
Project: "test",
Type: models.ObsTypeBugfix,
Scope: models.ScopeProject,
Title: sql.NullString{String: "Bug Fix Title", Valid: true},
Narrative: sql.NullString{String: "Detailed bug fix narrative", Valid: true},
CreatedAtEpoch: 1704067200000,
}
tests := []struct {
name string
format string
expectContent bool
}{
{"full_format", "full", true},
{"index_format", "index", false},
{"empty_format", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := m.observationToResult(obs, tt.format)
assert.Equal(t, "observation", result.Type)
assert.Equal(t, int64(1), result.ID)
if tt.expectContent {
assert.NotEmpty(t, result.Content)
}
})
}
}
// TestSummaryToResult_AllFormats tests different format options for summaries.
func TestSummaryToResult_AllFormats(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
summary := &models.SessionSummary{
ID: 1,
Project: "test",
Request: sql.NullString{String: "Test request", Valid: true},
Learned: sql.NullString{String: "Test learned", Valid: true},
Completed: sql.NullString{String: "Test completed", Valid: true},
NextSteps: sql.NullString{String: "Test next steps", Valid: true},
CreatedAtEpoch: 1704067200000,
}
tests := []struct {
name string
format string
expectContent bool
}{
{"full_format", "full", true},
{"index_format", "index", false},
{"empty_format", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := m.summaryToResult(summary, tt.format)
assert.Equal(t, "session", result.Type)
assert.Equal(t, int64(1), result.ID)
if tt.expectContent {
assert.NotEmpty(t, result.Content)
}
})
}
}
+480
View File
@@ -4,6 +4,7 @@ package search
import (
"database/sql"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
@@ -594,3 +595,482 @@ func TestSearchTypeMapping(t *testing.T) {
})
}
}
// TestFilterSearchWithObservations tests filter search when observations exist.
func TestFilterSearchWithObservations(t *testing.T) {
// Create mock observation
obs := &models.Observation{
ID: 1,
Project: "test-project",
Type: models.ObsTypeDiscovery,
Scope: models.ScopeProject,
Title: sql.NullString{String: "Test Title", Valid: true},
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
CreatedAtEpoch: 1704067200000,
}
m := NewManager(nil, nil, nil, nil)
result := m.observationToResult(obs, "full")
assert.Equal(t, "observation", result.Type)
assert.Equal(t, int64(1), result.ID)
assert.Equal(t, "Test Title", result.Title)
assert.Equal(t, "Test narrative content", result.Content)
assert.Equal(t, "test-project", result.Project)
assert.Equal(t, "project", result.Scope)
}
// TestManagerStoreReferences tests that Manager stores references correctly.
func TestManagerStoreReferences(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
assert.Nil(t, m.observationStore)
assert.Nil(t, m.summaryStore)
assert.Nil(t, m.promptStore)
assert.Nil(t, m.vectorClient)
}
// TestObservationToResultWithMetadata tests metadata inclusion in results.
func TestObservationToResultWithMetadata(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
tests := []struct {
name string
obsType models.ObservationType
scope models.ObservationScope
}{
{"bugfix_project", models.ObsTypeBugfix, models.ScopeProject},
{"feature_global", models.ObsTypeFeature, models.ScopeGlobal},
{"discovery_project", models.ObsTypeDiscovery, models.ScopeProject},
{"decision_global", models.ObsTypeDecision, models.ScopeGlobal},
{"refactor_project", models.ObsTypeRefactor, models.ScopeProject},
{"change_global", models.ObsTypeChange, models.ScopeGlobal},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
obs := &models.Observation{
ID: 1,
Project: "test-project",
Type: tt.obsType,
Scope: tt.scope,
Title: sql.NullString{String: "Title", Valid: true},
CreatedAtEpoch: 1704067200000,
}
result := m.observationToResult(obs, "full")
assert.Equal(t, string(tt.obsType), result.Metadata["obs_type"])
assert.Equal(t, string(tt.scope), result.Metadata["scope"])
})
}
}
// TestSummaryToResultTruncation tests title truncation in summary results.
func TestSummaryToResultTruncation(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
tests := []struct {
name string
request string
expectedLen int
shouldTrunc bool
}{
{"short_title", "Short request", 13, false},
{"exact_100", string(make([]byte, 100)), 103, true}, // 100 + "..."
{"over_100", string(make([]byte, 150)), 103, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
summary := &models.SessionSummary{
ID: 1,
Project: "test-project",
Request: sql.NullString{String: tt.request, Valid: true},
CreatedAtEpoch: 1704067200000,
}
result := m.summaryToResult(summary, "full")
if tt.shouldTrunc {
assert.LessOrEqual(t, len(result.Title), tt.expectedLen)
assert.True(t, len(result.Title) <= 103) // max 100 + "..."
} else {
assert.Equal(t, tt.request, result.Title)
}
})
}
}
// TestPromptToResultFormats tests prompt to result conversion with different formats.
func TestPromptToResultFormats(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
prompt := &models.UserPromptWithSession{
UserPrompt: models.UserPrompt{
ID: 123,
PromptText: "What is the meaning of life?",
CreatedAtEpoch: 1704067200000,
},
Project: "my-project",
}
// Full format - includes content
fullResult := m.promptToResult(prompt, "full")
assert.Equal(t, "What is the meaning of life?", fullResult.Content)
// Index format - no content
indexResult := m.promptToResult(prompt, "index")
assert.Equal(t, "", indexResult.Content)
// Both should have same title
assert.Equal(t, fullResult.Title, indexResult.Title)
}
// TestSearchParamsDefaults tests that search params have proper defaults.
func TestSearchParamsDefaults(t *testing.T) {
tests := []struct {
name string
initialLimit int
initialOrder string
expectedLimit int
expectedOrder string
}{
{"zero_limit", 0, "", 20, "date_desc"},
{"negative_limit", -5, "", 20, "date_desc"},
{"over_100_limit", 150, "", 100, "date_desc"},
{"valid_limit_50", 50, "relevance", 50, "relevance"},
{"custom_order", 30, "date_asc", 30, "date_asc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "project",
Limit: tt.initialLimit,
OrderBy: tt.initialOrder,
}
// Simulate the normalization that happens in UnifiedSearch
if params.Limit <= 0 {
params.Limit = 20
}
if params.Limit > 100 {
params.Limit = 100
}
if params.OrderBy == "" {
params.OrderBy = "date_desc"
}
assert.Equal(t, tt.expectedLimit, params.Limit)
assert.Equal(t, tt.expectedOrder, params.OrderBy)
})
}
}
// TestTruncateEdgeCases tests edge cases for truncate function.
func TestTruncateEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
// Unicode strings - uses byte length so ensure maxLen accommodates full string
{"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"},
{"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"},
// ASCII truncation
{"ascii_truncate", "Hello World", 5, "Hello..."},
{"only_whitespace", " ", 10, ""},
{"tabs_and_newlines", "\t\n \t", 10, ""},
{"newlines_with_content", "\n\nhello\n\n", 10, "hello"},
{"zero_max_len", "hello", 0, "..."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncate(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
// TestUnifiedSearchResultEmpty tests empty UnifiedSearchResult.
func TestUnifiedSearchResultEmpty(t *testing.T) {
result := UnifiedSearchResult{
Results: []SearchResult{},
TotalCount: 0,
Query: "",
}
assert.Empty(t, result.Results)
assert.Equal(t, 0, result.TotalCount)
assert.Equal(t, "", result.Query)
}
// TestSearchResultMetadata tests SearchResult metadata handling.
func TestSearchResultMetadata(t *testing.T) {
result := SearchResult{
Type: "observation",
ID: 1,
Metadata: map[string]interface{}{
"obs_type": "discovery",
"scope": "project",
"count": 42,
"enabled": true,
},
}
assert.Equal(t, "discovery", result.Metadata["obs_type"])
assert.Equal(t, "project", result.Metadata["scope"])
assert.Equal(t, 42, result.Metadata["count"])
assert.Equal(t, true, result.Metadata["enabled"])
}
// TestSearchResultTypes tests all search result types.
func TestSearchResultTypes(t *testing.T) {
types := []string{"observation", "session", "prompt"}
for _, typ := range types {
t.Run(typ, func(t *testing.T) {
result := SearchResult{
Type: typ,
ID: 1,
Project: "test",
CreatedAt: time.Now().UnixMilli(),
}
assert.Equal(t, typ, result.Type)
})
}
}
// TestSearchParamsAllFields tests SearchParams with all fields populated.
func TestSearchParamsAllFields(t *testing.T) {
params := SearchParams{
Query: "authentication bug",
Type: "observations",
Project: "my-project",
ObsType: "bugfix",
Concepts: "security,auth",
Files: "handler.go,auth.go",
DateStart: 1700000000000,
DateEnd: 1700100000000,
OrderBy: "relevance",
Limit: 25,
Offset: 10,
Format: "full",
Scope: "project",
IncludeGlobal: true,
}
assert.Equal(t, "authentication bug", params.Query)
assert.Equal(t, "observations", params.Type)
assert.Equal(t, "my-project", params.Project)
assert.Equal(t, "bugfix", params.ObsType)
assert.Equal(t, "security,auth", params.Concepts)
assert.Equal(t, "handler.go,auth.go", params.Files)
assert.Equal(t, int64(1700000000000), params.DateStart)
assert.Equal(t, int64(1700100000000), params.DateEnd)
assert.Equal(t, "relevance", params.OrderBy)
assert.Equal(t, 25, params.Limit)
assert.Equal(t, 10, params.Offset)
assert.Equal(t, "full", params.Format)
assert.Equal(t, "project", params.Scope)
assert.True(t, params.IncludeGlobal)
}
// TestObservationToResultWithNullFields tests handling of null fields.
func TestObservationToResultWithNullFields(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
obs := &models.Observation{
ID: 1,
Project: "test-project",
Type: models.ObsTypeDiscovery,
Scope: models.ScopeProject,
Title: sql.NullString{Valid: false},
Narrative: sql.NullString{Valid: false},
CreatedAtEpoch: 1704067200000,
}
result := m.observationToResult(obs, "full")
assert.Equal(t, "", result.Title)
assert.Equal(t, "", result.Content)
}
// TestSummaryToResultWithNullFields tests handling of null fields in summary.
func TestSummaryToResultWithNullFields(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
summary := &models.SessionSummary{
ID: 1,
Project: "test-project",
Request: sql.NullString{Valid: false},
Learned: sql.NullString{Valid: false},
CreatedAtEpoch: 1704067200000,
}
result := m.summaryToResult(summary, "full")
assert.Equal(t, "", result.Title)
assert.Equal(t, "", result.Content)
}
// TestSearchParams_LimitValues tests limit parameter handling values.
func TestSearchParams_LimitValues(t *testing.T) {
tests := []struct {
name string
inputLimit int
expectedValid bool
}{
{"zero_limit", 0, true},
{"negative_limit", -5, true},
{"normal_limit", 20, true},
{"max_limit", 100, true},
{"over_limit", 200, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "test",
Limit: tt.inputLimit,
}
assert.NotNil(t, params)
assert.Equal(t, tt.inputLimit, params.Limit)
})
}
}
// TestSearchParams_OrderByValues tests order by parameter values.
func TestSearchParams_OrderByValues(t *testing.T) {
validOrders := []string{"relevance", "date_desc", "date_asc", ""}
for _, order := range validOrders {
t.Run("order_"+order, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "test",
OrderBy: order,
}
assert.Equal(t, order, params.OrderBy)
})
}
}
// TestSearchParams_TypeValues tests type parameter values.
func TestSearchParams_TypeValues(t *testing.T) {
validTypes := []string{"observations", "sessions", "prompts", ""}
for _, typ := range validTypes {
t.Run("type_"+typ, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "test",
Type: typ,
}
assert.Equal(t, typ, params.Type)
})
}
}
// TestSearchParams_ScopeValues tests scope parameter values.
func TestSearchParams_ScopeValues(t *testing.T) {
validScopes := []string{"project", "global", ""}
for _, scope := range validScopes {
t.Run("scope_"+scope, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "test",
Scope: scope,
}
assert.Equal(t, scope, params.Scope)
})
}
}
// TestSearchParams_FormatValues tests format parameter values.
func TestSearchParams_FormatValues(t *testing.T) {
validFormats := []string{"index", "full", ""}
for _, format := range validFormats {
t.Run("format_"+format, func(t *testing.T) {
params := SearchParams{
Query: "test",
Project: "test",
Format: format,
}
assert.Equal(t, format, params.Format)
})
}
}
// TestUnifiedSearchResult_MultipleResults tests result with multiple items.
func TestUnifiedSearchResult_MultipleResults(t *testing.T) {
results := []SearchResult{
{Type: "observation", ID: 1, Title: "First", Project: "test"},
{Type: "session", ID: 2, Title: "Second", Project: "test"},
{Type: "prompt", ID: 3, Title: "Third", Project: "test"},
}
result := UnifiedSearchResult{
Results: results,
TotalCount: 3,
Query: "test query",
}
assert.Len(t, result.Results, 3)
assert.Equal(t, 3, result.TotalCount)
assert.Equal(t, "observation", result.Results[0].Type)
assert.Equal(t, "session", result.Results[1].Type)
assert.Equal(t, "prompt", result.Results[2].Type)
}
// TestSearchResult_Metadata tests metadata handling in SearchResult.
func TestSearchResult_Metadata(t *testing.T) {
metadata := map[string]interface{}{
"obs_type": "discovery",
"concepts": []string{"auth", "security"},
"files_count": 5,
"is_important": true,
}
result := SearchResult{
Type: "observation",
ID: 1,
Metadata: metadata,
}
assert.Equal(t, "discovery", result.Metadata["obs_type"])
assert.Equal(t, 5, result.Metadata["files_count"])
assert.Equal(t, true, result.Metadata["is_important"])
}
// TestSearchResult_Scores tests score handling in SearchResult.
func TestSearchResult_Scores(t *testing.T) {
tests := []struct {
name string
score float64
}{
{"perfect_score", 1.0},
{"high_score", 0.95},
{"medium_score", 0.5},
{"low_score", 0.1},
{"zero_score", 0.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SearchResult{
Type: "observation",
ID: 1,
Score: tt.score,
}
assert.Equal(t, tt.score, result.Score)
})
}
}
+502
View File
@@ -0,0 +1,502 @@
package sqlitevec
import (
"context"
"database/sql"
"os"
"path/filepath"
"testing"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testDB creates a test SQLite database with the vectors table.
func testDB(t *testing.T) (*sql.DB, func()) {
t.Helper()
// Create temp directory
tmpDir, err := os.MkdirTemp("", "sqlitevec-test-*")
require.NoError(t, err)
dbPath := filepath.Join(tmpDir, "test.db")
db, err := sql.Open("sqlite3", dbPath)
require.NoError(t, err)
// Enable sqlite-vec
sqlite_vec.Auto()
// Create vectors table (matches production schema)
_, err = db.Exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
doc_id TEXT PRIMARY KEY,
embedding float[384],
sqlite_id INTEGER,
doc_type TEXT,
field_type TEXT,
project TEXT,
scope TEXT
)
`)
require.NoError(t, err)
cleanup := func() {
db.Close()
os.RemoveAll(tmpDir)
}
return db, cleanup
}
// testEmbeddingService creates a test embedding service.
func testEmbeddingService(t *testing.T) (*embedding.Service, func()) {
t.Helper()
svc, err := embedding.NewService()
require.NoError(t, err)
cleanup := func() {
svc.Close()
}
return svc, cleanup
}
func TestNewClient_Success(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
assert.NotNil(t, client)
}
func TestNewClient_NilDB(t *testing.T) {
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: nil}, embedSvc)
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "database connection required")
}
func TestNewClient_NilEmbedding(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
client, err := NewClient(Config{DB: db}, nil)
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "embedding service required")
}
func TestClient_AddDocuments_Empty(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
err = client.AddDocuments(context.Background(), []Document{})
require.NoError(t, err)
}
func TestClient_AddDocuments_Single(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
docs := []Document{
{
ID: "obs-1-title",
Content: "This is a test observation about authentication.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"field_type": "title",
"project": "test-project",
"scope": "project",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Verify document was inserted
var count int
err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "obs-1-title").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestClient_AddDocuments_Multiple(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
docs := []Document{
{
ID: "obs-1-title",
Content: "Authentication flow implementation.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"field_type": "title",
"project": "test-project",
"scope": "project",
},
},
{
ID: "obs-1-narrative",
Content: "We implemented JWT-based authentication.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"field_type": "narrative",
"project": "test-project",
"scope": "project",
},
},
{
ID: "obs-2-title",
Content: "Database optimization.",
Metadata: map[string]any{
"sqlite_id": int64(2),
"doc_type": "observation",
"field_type": "title",
"project": "test-project",
"scope": "global",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Verify all documents were inserted
var count int
err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 3, count)
}
func TestClient_DeleteDocuments_Empty(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
err = client.DeleteDocuments(context.Background(), []string{})
require.NoError(t, err)
}
func TestClient_DeleteDocuments_Existing(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add documents first
docs := []Document{
{
ID: "doc-1",
Content: "First document.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
},
},
{
ID: "doc-2",
Content: "Second document.",
Metadata: map[string]any{
"sqlite_id": int64(2),
"doc_type": "observation",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Delete one document
err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
require.NoError(t, err)
// Verify only one remains
var count int
err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestClient_Query_Basic(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add some test documents
docs := []Document{
{
ID: "obs-1",
Content: "Authentication and login security implementation.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"project": "test-project",
"scope": "project",
},
},
{
ID: "obs-2",
Content: "Database query optimization techniques.",
Metadata: map[string]any{
"sqlite_id": int64(2),
"doc_type": "observation",
"project": "test-project",
"scope": "project",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Query for authentication-related content
results, err := client.Query(context.Background(), "login authentication", 10, nil)
require.NoError(t, err)
assert.NotEmpty(t, results)
assert.LessOrEqual(t, len(results), 10)
// First result should be the authentication document (higher similarity)
assert.Equal(t, "obs-1", results[0].ID)
}
func TestClient_Query_WithDocTypeFilter(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add documents of different types
docs := []Document{
{
ID: "obs-1",
Content: "Test content for observation.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"project": "test-project",
},
},
{
ID: "summary-1",
Content: "Test content for summary.",
Metadata: map[string]any{
"sqlite_id": int64(10),
"doc_type": "session_summary",
"project": "test-project",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Query with doc_type filter
where := map[string]any{"doc_type": "observation"}
results, err := client.Query(context.Background(), "test content", 10, where)
require.NoError(t, err)
// Should only return observation documents
for _, r := range results {
docType, _ := r.Metadata["doc_type"].(string)
assert.Equal(t, "observation", docType)
}
}
func TestClient_Query_WithProjectFilter(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add documents from different projects
docs := []Document{
{
ID: "obs-1",
Content: "Project A authentication content.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
"project": "project-a",
"scope": "project",
},
},
{
ID: "obs-2",
Content: "Project B database content.",
Metadata: map[string]any{
"sqlite_id": int64(2),
"doc_type": "observation",
"project": "project-b",
"scope": "project",
},
},
{
ID: "obs-3",
Content: "Global security best practices.",
Metadata: map[string]any{
"sqlite_id": int64(3),
"doc_type": "observation",
"project": "project-b",
"scope": "global",
},
},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Query without project filter to verify all docs are there
results, err := client.Query(context.Background(), "authentication security", 10, nil)
require.NoError(t, err)
assert.NotEmpty(t, results, "Should find some results")
}
func TestClient_IsConnected(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
assert.True(t, client.IsConnected())
}
func TestClient_Close(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
}
func TestConfig_Fields(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
cfg := Config{DB: db}
assert.Equal(t, db, cfg.DB)
}
func TestClient_UpdateDocument_DeleteThenAdd(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add document
docs1 := []Document{
{
ID: "doc-1",
Content: "Original content.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
},
},
}
err = client.AddDocuments(context.Background(), docs1)
require.NoError(t, err)
// Delete then add with new content (proper update pattern)
err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
require.NoError(t, err)
docs2 := []Document{
{
ID: "doc-1",
Content: "Updated content.",
Metadata: map[string]any{
"sqlite_id": int64(1),
"doc_type": "observation",
},
},
}
err = client.AddDocuments(context.Background(), docs2)
require.NoError(t, err)
// Should have exactly 1 document
var count int
err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-1").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestClient_DeleteDocuments_NonExistent(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Deleting non-existent document should not error
err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"})
require.NoError(t, err)
}
+574
View File
@@ -0,0 +1,574 @@
package sqlitevec
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDocTypes(t *testing.T) {
assert.Equal(t, DocType("observation"), DocTypeObservation)
assert.Equal(t, DocType("session_summary"), DocTypeSessionSummary)
assert.Equal(t, DocType("user_prompt"), DocTypeUserPrompt)
}
func TestDocument_Fields(t *testing.T) {
doc := Document{
ID: "doc-123",
Content: "test content",
Metadata: map[string]any{
"key": "value",
},
}
assert.Equal(t, "doc-123", doc.ID)
assert.Equal(t, "test content", doc.Content)
assert.Equal(t, "value", doc.Metadata["key"])
}
func TestQueryResult_Fields(t *testing.T) {
result := QueryResult{
ID: "result-123",
Distance: 0.5,
Metadata: map[string]any{
"sqlite_id": float64(42),
},
}
assert.Equal(t, "result-123", result.ID)
assert.Equal(t, 0.5, result.Distance)
assert.Equal(t, float64(42), result.Metadata["sqlite_id"])
}
func TestBuildWhereFilter(t *testing.T) {
tests := []struct {
name string
docType DocType
project string
expected map[string]interface{}
}{
{
name: "empty_filters",
docType: "",
project: "",
expected: map[string]interface{}{},
},
{
name: "doc_type_only",
docType: DocTypeObservation,
project: "",
expected: map[string]interface{}{
"doc_type": "observation",
},
},
{
name: "project_only",
docType: "",
project: "my-project",
expected: map[string]interface{}{
"project": "my-project",
},
},
{
name: "both_filters",
docType: DocTypeSessionSummary,
project: "test-project",
expected: map[string]interface{}{
"doc_type": "session_summary",
"project": "test-project",
},
},
{
name: "user_prompt_type",
docType: DocTypeUserPrompt,
project: "prompt-project",
expected: map[string]interface{}{
"doc_type": "user_prompt",
"project": "prompt-project",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildWhereFilter(tt.docType, tt.project)
assert.Equal(t, tt.expected, result)
})
}
}
func TestExtractIDsByDocType_Empty(t *testing.T) {
results := []QueryResult{}
ids := ExtractIDsByDocType(results)
assert.Empty(t, ids.ObservationIDs)
assert.Empty(t, ids.SummaryIDs)
assert.Empty(t, ids.PromptIDs)
}
func TestExtractIDsByDocType_AllTypes(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1",
Distance: 0.1,
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
},
},
{
ID: "obs-2",
Distance: 0.2,
Metadata: map[string]any{
"sqlite_id": float64(2),
"doc_type": "observation",
},
},
{
ID: "summary-1",
Distance: 0.3,
Metadata: map[string]any{
"sqlite_id": float64(10),
"doc_type": "session_summary",
},
},
{
ID: "prompt-1",
Distance: 0.4,
Metadata: map[string]any{
"sqlite_id": float64(20),
"doc_type": "user_prompt",
},
},
}
ids := ExtractIDsByDocType(results)
assert.Equal(t, []int64{1, 2}, ids.ObservationIDs)
assert.Equal(t, []int64{10}, ids.SummaryIDs)
assert.Equal(t, []int64{20}, ids.PromptIDs)
}
func TestExtractIDsByDocType_Deduplication(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1-field1",
Distance: 0.1,
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
},
},
{
ID: "obs-1-field2",
Distance: 0.2,
Metadata: map[string]any{
"sqlite_id": float64(1), // Same ID, different field
"doc_type": "observation",
},
},
{
ID: "obs-2",
Distance: 0.3,
Metadata: map[string]any{
"sqlite_id": float64(2),
"doc_type": "observation",
},
},
}
ids := ExtractIDsByDocType(results)
assert.Equal(t, []int64{1, 2}, ids.ObservationIDs) // Should be deduplicated
}
func TestExtractIDsByDocType_Int64Fallback(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1",
Distance: 0.1,
Metadata: map[string]any{
"sqlite_id": int64(42), // int64 instead of float64
"doc_type": "observation",
},
},
}
ids := ExtractIDsByDocType(results)
assert.Equal(t, []int64{42}, ids.ObservationIDs)
}
func TestExtractIDsByDocType_MissingSqliteID(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1",
Distance: 0.1,
Metadata: map[string]any{
"doc_type": "observation",
// Missing sqlite_id
},
},
}
ids := ExtractIDsByDocType(results)
assert.Empty(t, ids.ObservationIDs)
}
func TestExtractIDsByDocType_UnknownType(t *testing.T) {
results := []QueryResult{
{
ID: "unknown-1",
Distance: 0.1,
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "unknown_type",
},
},
}
ids := ExtractIDsByDocType(results)
// Should not be added to any category
assert.Empty(t, ids.ObservationIDs)
assert.Empty(t, ids.SummaryIDs)
assert.Empty(t, ids.PromptIDs)
}
func TestExtractObservationIDs_NoFilter(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1",
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
"project": "proj-a",
},
},
{
ID: "obs-2",
Metadata: map[string]any{
"sqlite_id": float64(2),
"doc_type": "observation",
"project": "proj-b",
},
},
{
ID: "summary-1",
Metadata: map[string]any{
"sqlite_id": float64(10),
"doc_type": "session_summary",
"project": "proj-a",
},
},
}
ids := ExtractObservationIDs(results, "")
assert.Equal(t, []int64{1, 2}, ids)
}
func TestExtractObservationIDs_WithProjectFilter(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1",
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
"project": "proj-a",
"scope": "project",
},
},
{
ID: "obs-2",
Metadata: map[string]any{
"sqlite_id": float64(2),
"doc_type": "observation",
"project": "proj-b",
"scope": "project",
},
},
{
ID: "obs-global",
Metadata: map[string]any{
"sqlite_id": float64(3),
"doc_type": "observation",
"project": "proj-b",
"scope": "global",
},
},
}
ids := ExtractObservationIDs(results, "proj-a")
// Should include proj-a and global scope observations
assert.Equal(t, []int64{1, 3}, ids)
}
func TestExtractObservationIDs_Deduplication(t *testing.T) {
results := []QueryResult{
{
ID: "obs-1-field1",
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
},
},
{
ID: "obs-1-field2",
Metadata: map[string]any{
"sqlite_id": float64(1), // Same ID
"doc_type": "observation",
},
},
}
ids := ExtractObservationIDs(results, "")
assert.Equal(t, []int64{1}, ids)
}
func TestExtractSummaryIDs_NoFilter(t *testing.T) {
results := []QueryResult{
{
ID: "summary-1",
Metadata: map[string]any{
"sqlite_id": float64(10),
"doc_type": "session_summary",
"project": "proj-a",
},
},
{
ID: "summary-2",
Metadata: map[string]any{
"sqlite_id": float64(20),
"doc_type": "session_summary",
"project": "proj-b",
},
},
{
ID: "obs-1",
Metadata: map[string]any{
"sqlite_id": float64(1),
"doc_type": "observation",
},
},
}
ids := ExtractSummaryIDs(results, "")
assert.Equal(t, []int64{10, 20}, ids)
}
func TestExtractSummaryIDs_WithProjectFilter(t *testing.T) {
results := []QueryResult{
{
ID: "summary-1",
Metadata: map[string]any{
"sqlite_id": float64(10),
"doc_type": "session_summary",
"project": "proj-a",
},
},
{
ID: "summary-2",
Metadata: map[string]any{
"sqlite_id": float64(20),
"doc_type": "session_summary",
"project": "proj-b",
},
},
}
ids := ExtractSummaryIDs(results, "proj-a")
assert.Equal(t, []int64{10}, ids)
}
func TestExtractPromptIDs_NoFilter(t *testing.T) {
results := []QueryResult{
{
ID: "prompt-1",
Metadata: map[string]any{
"sqlite_id": float64(100),
"doc_type": "user_prompt",
"project": "proj-a",
},
},
{
ID: "prompt-2",
Metadata: map[string]any{
"sqlite_id": float64(200),
"doc_type": "user_prompt",
"project": "proj-b",
},
},
}
ids := ExtractPromptIDs(results, "")
assert.Equal(t, []int64{100, 200}, ids)
}
func TestExtractPromptIDs_WithProjectFilter(t *testing.T) {
results := []QueryResult{
{
ID: "prompt-1",
Metadata: map[string]any{
"sqlite_id": float64(100),
"doc_type": "user_prompt",
"project": "proj-a",
},
},
{
ID: "prompt-2",
Metadata: map[string]any{
"sqlite_id": float64(200),
"doc_type": "user_prompt",
"project": "proj-b",
},
},
}
ids := ExtractPromptIDs(results, "proj-b")
assert.Equal(t, []int64{200}, ids)
}
func TestCopyMetadata(t *testing.T) {
base := map[string]any{
"key1": "value1",
"key2": 42,
}
result := copyMetadata(base, "key3", "value3")
// Original should be unchanged
assert.Len(t, base, 2)
// Result should have new key
assert.Len(t, result, 3)
assert.Equal(t, "value1", result["key1"])
assert.Equal(t, 42, result["key2"])
assert.Equal(t, "value3", result["key3"])
}
func TestCopyMetadataMulti(t *testing.T) {
base := map[string]any{
"key1": "value1",
}
extra := map[string]any{
"key2": "value2",
"key3": "value3",
}
result := copyMetadataMulti(base, extra)
assert.Len(t, result, 3)
assert.Equal(t, "value1", result["key1"])
assert.Equal(t, "value2", result["key2"])
assert.Equal(t, "value3", result["key3"])
}
func TestJoinStrings(t *testing.T) {
tests := []struct {
name string
strs []string
sep string
expected string
}{
{
name: "empty_slice",
strs: []string{},
sep: ", ",
expected: "",
},
{
name: "single_element",
strs: []string{"one"},
sep: ", ",
expected: "one",
},
{
name: "multiple_elements",
strs: []string{"one", "two", "three"},
sep: ", ",
expected: "one, two, three",
},
{
name: "different_separator",
strs: []string{"a", "b", "c"},
sep: "-",
expected: "a-b-c",
},
{
name: "empty_separator",
strs: []string{"a", "b", "c"},
sep: "",
expected: "abc",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := joinStrings(tt.strs, tt.sep)
assert.Equal(t, tt.expected, result)
})
}
}
func TestTruncateString(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{
name: "shorter_than_max",
input: "hello",
maxLen: 10,
expected: "hello",
},
{
name: "equal_to_max",
input: "hello",
maxLen: 5,
expected: "hello",
},
{
name: "longer_than_max",
input: "hello world",
maxLen: 5,
expected: "hello...",
},
{
name: "empty_string",
input: "",
maxLen: 5,
expected: "",
},
{
name: "zero_max_length",
input: "hello",
maxLen: 0,
expected: "...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateString(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
func TestExtractedIDs_Empty(t *testing.T) {
ids := &ExtractedIDs{}
assert.Nil(t, ids.ObservationIDs)
assert.Nil(t, ids.SummaryIDs)
assert.Nil(t, ids.PromptIDs)
}
+846
View File
@@ -1279,3 +1279,849 @@ func TestObservationRequest_Fields(t *testing.T) {
assert.Equal(t, "Read", req.ToolName)
assert.Equal(t, "/home/user/project", req.CWD)
}
// TestRetrievalStats_Fields tests RetrievalStats struct fields.
func TestRetrievalStats_Fields(t *testing.T) {
stats := RetrievalStats{
TotalRequests: 100,
ObservationsServed: 500,
VerifiedStale: 10,
DeletedInvalid: 5,
SearchRequests: 80,
ContextInjections: 20,
}
assert.Equal(t, int64(100), stats.TotalRequests)
assert.Equal(t, int64(500), stats.ObservationsServed)
assert.Equal(t, int64(10), stats.VerifiedStale)
assert.Equal(t, int64(5), stats.DeletedInvalid)
assert.Equal(t, int64(80), stats.SearchRequests)
assert.Equal(t, int64(20), stats.ContextInjections)
}
// TestServiceConstants tests service configuration constants.
func TestServiceConstants(t *testing.T) {
assert.Equal(t, 30*time.Second, DefaultHTTPTimeout)
assert.Equal(t, 50*time.Millisecond, ReadyPollInterval)
assert.Equal(t, 100, StaleQueueSize)
assert.Equal(t, 2*time.Second, QueueProcessInterval)
}
// TestClusterObservations_Empty tests clustering with empty slice.
func TestClusterObservations_Empty(t *testing.T) {
observations := []*models.Observation{}
clustered := clusterObservations(observations, 0.4)
assert.Empty(t, clustered)
}
// TestClusterObservations_Single tests clustering with single observation.
func TestClusterObservations_Single(t *testing.T) {
observations := []*models.Observation{
{
ID: 1,
Title: sql.NullString{String: "Test observation", Valid: true},
Narrative: sql.NullString{String: "Test content", Valid: true},
},
}
clustered := clusterObservations(observations, 0.4)
assert.Len(t, clustered, 1)
}
// TestClusterObservations_VeryDifferent tests clustering with very different observations.
func TestClusterObservations_VeryDifferent(t *testing.T) {
observations := []*models.Observation{
{
ID: 1,
Title: sql.NullString{String: "Database optimization", Valid: true},
Narrative: sql.NullString{String: "PostgreSQL index tuning", Valid: true},
},
{
ID: 2,
Title: sql.NullString{String: "Authentication flow", Valid: true},
Narrative: sql.NullString{String: "JWT token validation", Valid: true},
},
{
ID: 3,
Title: sql.NullString{String: "Logging setup", Valid: true},
Narrative: sql.NullString{String: "Zerolog configuration", Valid: true},
},
}
clustered := clusterObservations(observations, 0.4)
// Very different observations should not be clustered together
assert.GreaterOrEqual(t, len(clustered), 1)
}
// TestHandleContextInject_WithLimit tests context inject with custom limit.
func TestHandleContextInject_WithLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "inject-limit-test"
// Create some observations
for i := 0; i < 10; i++ {
createTestObservation(t, svc.observationStore, project,
"Observation "+strconv.Itoa(i),
"Content "+strconv.Itoa(i),
[]string{"test-" + strconv.Itoa(i)})
time.Sleep(time.Millisecond)
}
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
assert.LessOrEqual(t, len(observations), 5)
}
// TestHandleGetObservations tests getting observations list.
func TestHandleGetObservations(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create some observations
createTestObservation(t, svc.observationStore, "test-project",
"Test Observation 1",
"Test content 1",
[]string{"test"})
createTestObservation(t, svc.observationStore, "test-project",
"Test Observation 2",
"Test content 2",
[]string{"test"})
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=test-project", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(observations), 2)
}
// TestHandleGetObservations_Pagination tests observations pagination.
func TestHandleGetObservations_Pagination(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create some observations
for i := 0; i < 5; i++ {
createTestObservation(t, svc.observationStore, "page-test",
"Observation "+strconv.Itoa(i),
"Content "+strconv.Itoa(i),
[]string{"test"})
}
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=page-test&limit=2", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleGetObservations_NoProject tests observations without project.
func TestHandleGetObservations_NoProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should still return 200 with empty results or all observations
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
}
// TestHandleSearchByPrompt_EmptyQuery tests search with empty query parameter.
func TestHandleSearchByPrompt_EmptyQuery(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=test&query=", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Empty query should still be a bad request
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleGetSessionByClaudeID tests getting session by Claude ID.
func TestHandleGetSessionByClaudeID(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a session
ctx := context.Background()
svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1")
// Test with valid claudeSessionId
req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=claude-test-123", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleGetSessionByClaudeID_Missing tests session lookup with missing param.
func TestHandleGetSessionByClaudeID_Missing(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleGetSessionByClaudeID_NotFound tests session not found.
func TestHandleGetSessionByClaudeID_NotFound(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=nonexistent", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
}
// TestGetRetrievalStats tests the retrieval stats getter.
func TestGetRetrievalStats(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Initially all zeros
stats := svc.GetRetrievalStats()
assert.Equal(t, int64(0), stats.TotalRequests)
assert.Equal(t, int64(0), stats.SearchRequests)
// Make some requests to increment stats
project := "stats-test"
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Stats should be updated
stats = svc.GetRetrievalStats()
assert.GreaterOrEqual(t, stats.TotalRequests, int64(1))
}
// TestSelfCheckResponse_Fields tests SelfCheckResponse struct fields.
func TestSelfCheckResponse_Fields(t *testing.T) {
resp := SelfCheckResponse{
Overall: "healthy",
Version: "v1.0.0",
Uptime: "2h30m",
Components: []ComponentHealth{
{Name: "database", Status: "healthy", Message: "Connected"},
{Name: "vector", Status: "healthy", Message: "Ready"},
},
}
assert.Equal(t, "healthy", resp.Overall)
assert.Equal(t, "v1.0.0", resp.Version)
assert.Equal(t, "2h30m", resp.Uptime)
assert.Len(t, resp.Components, 2)
assert.Equal(t, "database", resp.Components[0].Name)
assert.Equal(t, "healthy", resp.Components[0].Status)
assert.Equal(t, "Connected", resp.Components[0].Message)
}
// TestComponentHealth_Fields tests ComponentHealth struct fields.
func TestComponentHealth_Fields(t *testing.T) {
tests := []struct {
name string
status string
message string
}{
{"healthy", "healthy", "All systems operational"},
{"degraded", "degraded", "Some features unavailable"},
{"unhealthy", "unhealthy", "Service is down"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
health := ComponentHealth{
Status: tt.status,
Message: tt.message,
}
assert.Equal(t, tt.status, health.Status)
assert.Equal(t, tt.message, health.Message)
})
}
}
// TestWriteJSON_Error tests writeJSON with values that can't be encoded.
func TestWriteJSON_Error(t *testing.T) {
rec := httptest.NewRecorder()
// channels can't be JSON encoded
ch := make(chan int)
writeJSON(rec, ch)
// Should still set content type but encoding will fail
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
}
// TestHandleSummarize_InvalidSessionID tests summarize with invalid session ID.
func TestHandleSummarize_InvalidSessionID(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/summarize", bytes.NewReader([]byte("{}")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Invalid session ID should return 400
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleSubagentComplete tests subagent completion endpoint.
func TestHandleSubagentComplete(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a session first
ctx := context.Background()
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "subagent-test-123", "test-project", "test prompt")
payload := `{"session_id": ` + strconv.FormatInt(sessionID, 10) + `, "parent_session_id": "parent-123"}`
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte(payload)))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should accept the request
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
}
// TestHandleContextSearch_Ordering tests search with different orderings.
func TestHandleContextSearch_Ordering(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "order-test"
for i := 0; i < 5; i++ {
createTestObservation(t, svc.observationStore, project,
"Obs "+strconv.Itoa(i),
"Content "+strconv.Itoa(i),
[]string{"test"})
time.Sleep(time.Millisecond)
}
tests := []struct {
name string
order string
}{
{"date_desc", "date_desc"},
{"date_asc", "date_asc"},
{"default", ""}, // Should default to date_desc
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := "/api/context/search?project=" + project + "&query=test"
if tt.order != "" {
url += "&order_by=" + tt.order
}
req := httptest.NewRequest(http.MethodGet, url, nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
})
}
}
// TestHandleContextCount_NoProject tests context count without project.
func TestHandleContextCount_NoProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleRetrievalStatsEndpoint tests retrieval stats endpoint.
func TestHandleRetrievalStatsEndpoint(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var stats RetrievalStats
err := json.Unmarshal(rec.Body.Bytes(), &stats)
require.NoError(t, err)
}
// TestHandleReadyEndpoint tests ready endpoint response.
func TestHandleReadyEndpoint(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Ready may return OK or ServiceUnavailable depending on state
assert.Contains(t, []int{http.StatusOK, http.StatusServiceUnavailable}, rec.Code)
}
// TestHandleSessionInit_EmptyBody tests session init with empty body.
func TestHandleSessionInit_EmptyBody(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
payload := `{}`
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte(payload)))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should accept empty body and create a session
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
}
// TestHandleObservation_MissingSession tests observation without session.
func TestHandleObservation_MissingSession(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
payload := `{"session_id": 99999, "tool_name": "Read", "tool_input": "{}", "tool_output": "test"}`
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte(payload)))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should still accept the observation
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
}
// TestHandleSummaries_Pagination tests summaries pagination.
func TestHandleSummaries_Pagination(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test&limit=10&offset=0", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandlePrompts_Pagination tests prompts pagination.
func TestHandlePrompts_Pagination(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test&limit=10&offset=0", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleGetStats_AllProjects tests stats without project filter.
func TestHandleGetStats_AllProjects(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create some data in multiple projects
createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"})
createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"})
req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleSubagentComplete_WithSession tests subagent completion with existing session.
func TestHandleSubagentComplete_WithSession(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a session first
ctx := context.Background()
svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt")
reqBody := SubagentCompleteRequest{
ClaudeSessionID: "subagent-claude-123",
Project: "test-project",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleSubagentComplete_NoSession tests subagent completion when session doesn't exist.
func TestHandleSubagentComplete_NoSession(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
reqBody := SubagentCompleteRequest{
ClaudeSessionID: "nonexistent-session",
Project: "test-project",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should still return 200 even if session not found
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleSubagentComplete_InvalidJSON tests subagent completion with invalid JSON.
func TestHandleSubagentComplete_InvalidJSON(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte("invalid json")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleSummarize_ValidSession tests summarize with valid session.
func TestHandleSummarize_ValidSession(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a session first
ctx := context.Background()
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-claude-test", "test-project", "test prompt")
reqBody := SummarizeRequest{
LastUserMessage: "Can you help me fix this bug?",
LastAssistantMessage: "I've analyzed the code and fixed the issue in the handler.",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleSummarize_InvalidJSON tests summarize with invalid JSON.
func TestHandleSummarize_InvalidJSON(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
ctx := context.Background()
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-invalid", "test-project", "test")
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader([]byte("not valid json")))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleSummarize_NonExistentSession tests summarize with non-existent session.
func TestHandleSummarize_NonExistentSession(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
reqBody := SummarizeRequest{
LastUserMessage: "test",
LastAssistantMessage: "test",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/sessions/999999/summarize", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should return error for non-existent session
assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError, http.StatusNotFound}, rec.Code)
}
// TestSubagentCompleteRequest_Fields tests SubagentCompleteRequest struct.
func TestSubagentCompleteRequest_Fields(t *testing.T) {
req := SubagentCompleteRequest{
ClaudeSessionID: "test-session-123",
Project: "my-project",
}
assert.Equal(t, "test-session-123", req.ClaudeSessionID)
assert.Equal(t, "my-project", req.Project)
}
// TestSummarizeRequest_Fields tests SummarizeRequest struct.
func TestSummarizeRequest_Fields(t *testing.T) {
req := SummarizeRequest{
LastUserMessage: "Help me fix this bug",
LastAssistantMessage: "I've fixed the authentication issue",
}
assert.Equal(t, "Help me fix this bug", req.LastUserMessage)
assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
}
// TestHandleHealth_NotReady tests health endpoint when not ready.
func TestHandleHealth_NotReady(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(false)
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
rec := httptest.NewRecorder()
svc.handleHealth(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "starting", response["status"])
}
// TestHandleContextInject_EmptyProject tests context inject with empty project.
func TestHandleContextInject_EmptyProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project=", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleSearchByPrompt_LargeLimit tests search with limit exceeding max.
func TestHandleSearchByPrompt_LargeLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "limit-test"
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&limit=999", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleObservation_WithFullData tests observation with all fields.
func TestHandleObservation_WithFullData(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a session first
ctx := context.Background()
svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt")
reqBody := ObservationRequest{
ClaudeSessionID: "obs-full-test",
Project: "test-project",
ToolName: "Edit",
ToolInput: map[string]interface{}{"file_path": "/test.go", "old_string": "foo", "new_string": "bar"},
ToolResponse: "Edit successful",
CWD: "/home/user/project",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
// TestHandleSelfCheck_WithObservations tests self-check with observations in DB.
func TestHandleSelfCheck_WithObservations(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(true)
// Create some observations
createTestObservation(t, svc.observationStore, "check-project", "Test", "Content", []string{"test"})
req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
rec := httptest.NewRecorder()
svc.handleSelfCheck(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response SelfCheckResponse
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
// Check components are populated
assert.NotEmpty(t, response.Components)
}
// TestHandleGetSummaries_NoProject tests getting summaries without project filter.
func TestHandleGetSummaries_NoProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create some summaries in different projects
ctx := context.Background()
for i := 0; i < 3; i++ {
parsed := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100)
}
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var summaries []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &summaries)
require.NoError(t, err)
// Should return all summaries
assert.GreaterOrEqual(t, len(summaries), 3)
}
// TestHandleGetPrompts_NoProject tests getting prompts without project filter.
func TestHandleGetPrompts_NoProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create sessions and prompts in different projects
ctx := context.Background()
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "")
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "")
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0)
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0)
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var prompts []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &prompts)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(prompts), 2)
}
// TestHandleSessionInit_MissingClaudeID tests session init without Claude ID.
func TestHandleSessionInit_MissingClaudeID(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
reqBody := SessionInitRequest{
Project: "test-project",
Prompt: "Help me",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
// Should accept even without Claude ID (may auto-generate)
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
}
// TestHandleContextInject_WithQuery tests context inject with query parameter.
func TestHandleContextInject_WithQuery(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "inject-query-test"
createTestObservation(t, svc.observationStore, project, "Authentication bug fix", "Fixed JWT validation", []string{"auth", "jwt"})
createTestObservation(t, svc.observationStore, project, "Database optimization", "Added indexes", []string{"db", "performance"})
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&query=authentication", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
assert.GreaterOrEqual(t, len(observations), 1)
}
+381
View File
@@ -1,9 +1,12 @@
package sdk
import (
"os"
"path/filepath"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
)
func TestIsSelfReferentialSummary(t *testing.T) {
@@ -124,3 +127,381 @@ No substantive work performed yet.`,
})
}
}
func TestShouldSkipTool(t *testing.T) {
tests := []struct {
name string
toolName string
expected bool
}{
// Tools that should be skipped
{"TodoWrite", "TodoWrite", true},
{"Task", "Task", true},
{"TaskOutput", "TaskOutput", true},
{"Glob", "Glob", true},
{"ListDir", "ListDir", true},
{"LS", "LS", true},
{"KillShell", "KillShell", true},
{"AskUserQuestion", "AskUserQuestion", true},
{"EnterPlanMode", "EnterPlanMode", true},
{"ExitPlanMode", "ExitPlanMode", true},
{"Skill", "Skill", true},
{"SlashCommand", "SlashCommand", true},
// Tools that should NOT be skipped
{"Read", "Read", false},
{"Edit", "Edit", false},
{"Write", "Write", false},
{"Grep", "Grep", false},
{"Bash", "Bash", false},
{"WebFetch", "WebFetch", false},
{"WebSearch", "WebSearch", false},
{"NotebookEdit", "NotebookEdit", false},
// Unknown tool (should not be skipped)
{"UnknownTool", "SomeUnknownTool", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldSkipTool(tt.toolName)
assert.Equal(t, tt.expected, result)
})
}
}
func TestShouldSkipTrivialOperation(t *testing.T) {
tests := []struct {
name string
toolName string
inputStr string
outputStr string
expected bool
}{
// Short output (should be skipped)
{
name: "output_too_short",
toolName: "Read",
inputStr: `{"file_path": "/some/file.go"}`,
outputStr: "short",
expected: true,
},
// Trivial outputs
{
name: "no_matches_found",
toolName: "Grep",
inputStr: `{"pattern": "foo"}`,
outputStr: "No matches found in the codebase",
expected: true,
},
{
name: "file_not_found",
toolName: "Read",
inputStr: `{"file_path": "/nonexistent.go"}`,
outputStr: "Error: File not found at specified path",
expected: true,
},
{
name: "empty_array",
toolName: "Grep",
inputStr: `{"pattern": "foo"}`,
outputStr: "[]",
expected: true,
},
// Boring files
{
name: "package_lock_json",
toolName: "Read",
inputStr: `{"file_path": "/project/package-lock.json"}`,
outputStr: "This is a very long package-lock.json content that has more than 50 characters",
expected: true,
},
{
name: "go_sum",
toolName: "Read",
inputStr: `{"file_path": "/project/go.sum"}`,
outputStr: "This is a very long go.sum file content that has more than 50 characters",
expected: true,
},
// Grep with too many matches
{
name: "grep_too_many_matches",
toolName: "Grep",
inputStr: `{"pattern": "import"}`,
outputStr: func() string {
s := ""
for i := 0; i < 55; i++ {
s += "match line\n"
}
return s
}(),
expected: true,
},
// Boring Bash commands
{
name: "git_status",
toolName: "Bash",
inputStr: `{"command": "git status"}`,
outputStr: "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean",
expected: true,
},
{
name: "ls_command",
toolName: "Bash",
inputStr: `{"command": "ls -la /some/directory"}`,
outputStr: "total 123\ndrwxr-xr-x some long listing that is at least 50 chars",
expected: true,
},
// Valid operations that should NOT be skipped
{
name: "valid_read",
toolName: "Read",
inputStr: `{"file_path": "/project/main.go"}`,
outputStr: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}",
expected: false,
},
{
name: "valid_edit",
toolName: "Edit",
inputStr: `{"file_path": "/project/handler.go", "old_string": "foo", "new_string": "bar"}`,
outputStr: "Edit applied successfully. File /project/handler.go has been modified with the requested changes.",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr)
assert.Equal(t, tt.expected, result)
})
}
}
func TestTruncateForLog(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{
name: "shorter_than_max",
input: "hello",
maxLen: 10,
expected: "hello",
},
{
name: "equal_to_max",
input: "hello",
maxLen: 5,
expected: "hello",
},
{
name: "longer_than_max",
input: "hello world",
maxLen: 5,
expected: "hello...",
},
{
name: "empty_string",
input: "",
maxLen: 5,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateForLog(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
func TestToJSONString(t *testing.T) {
tests := []struct {
name string
input interface{}
expected string
}{
{
name: "nil_value",
input: nil,
expected: "",
},
{
name: "string_value",
input: "hello",
expected: "hello",
},
{
name: "int_value",
input: 42,
expected: "42",
},
{
name: "map_value",
input: map[string]string{"key": "value"},
expected: `{"key":"value"}`,
},
{
name: "slice_value",
input: []string{"a", "b", "c"},
expected: `["a","b","c"]`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toJSONString(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestCaptureFileMtimes(t *testing.T) {
// Create a temp directory with test files
tmpDir, err := os.MkdirTemp("", "mtime-test-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
// Create test files
file1 := filepath.Join(tmpDir, "file1.txt")
file2 := filepath.Join(tmpDir, "file2.txt")
err = os.WriteFile(file1, []byte("content1"), 0644)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(file2, []byte("content2"), 0644)
if err != nil {
t.Fatal(err)
}
t.Run("captures_mtimes_for_existing_files", func(t *testing.T) {
mtimes := captureFileMtimes([]string{file1}, []string{file2}, "")
assert.Len(t, mtimes, 2)
assert.Contains(t, mtimes, file1)
assert.Contains(t, mtimes, file2)
assert.Greater(t, mtimes[file1], int64(0))
assert.Greater(t, mtimes[file2], int64(0))
})
t.Run("handles_nonexistent_files", func(t *testing.T) {
mtimes := captureFileMtimes([]string{"/nonexistent/file.txt"}, nil, "")
assert.Empty(t, mtimes)
})
t.Run("handles_relative_paths_with_cwd", func(t *testing.T) {
mtimes := captureFileMtimes([]string{"file1.txt"}, nil, tmpDir)
assert.Len(t, mtimes, 1)
assert.Contains(t, mtimes, "file1.txt")
})
t.Run("empty_inputs", func(t *testing.T) {
mtimes := captureFileMtimes(nil, nil, "")
assert.Empty(t, mtimes)
})
}
func TestGetFileMtimes(t *testing.T) {
// Create a temp file
tmpDir, err := os.MkdirTemp("", "getmtime-test-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
testFile := filepath.Join(tmpDir, "test.txt")
err = os.WriteFile(testFile, []byte("content"), 0644)
if err != nil {
t.Fatal(err)
}
mtimes := GetFileMtimes([]string{testFile}, "")
assert.Len(t, mtimes, 1)
assert.Contains(t, mtimes, testFile)
assert.Greater(t, mtimes[testFile], int64(0))
}
func TestGetFileContent(t *testing.T) {
// Create a temp directory with test files
tmpDir, err := os.MkdirTemp("", "content-test-*")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
t.Run("reads_existing_file", func(t *testing.T) {
testFile := filepath.Join(tmpDir, "test.txt")
content := "test content"
err := os.WriteFile(testFile, []byte(content), 0644)
if err != nil {
t.Fatal(err)
}
result, ok := GetFileContent(testFile, "")
assert.True(t, ok)
assert.Equal(t, content, result)
})
t.Run("returns_false_for_nonexistent_file", func(t *testing.T) {
result, ok := GetFileContent("/nonexistent/file.txt", "")
assert.False(t, ok)
assert.Empty(t, result)
})
t.Run("truncates_long_content", func(t *testing.T) {
testFile := filepath.Join(tmpDir, "long.txt")
longContent := ""
for i := 0; i < 3000; i++ {
longContent += "x"
}
err := os.WriteFile(testFile, []byte(longContent), 0644)
if err != nil {
t.Fatal(err)
}
result, ok := GetFileContent(testFile, "")
assert.True(t, ok)
assert.Contains(t, result, "[truncated]")
assert.LessOrEqual(t, len(result), 2100)
})
t.Run("resolves_relative_path_with_cwd", func(t *testing.T) {
testFile := filepath.Join(tmpDir, "relative.txt")
content := "relative content"
err := os.WriteFile(testFile, []byte(content), 0644)
if err != nil {
t.Fatal(err)
}
result, ok := GetFileContent("relative.txt", tmpDir)
assert.True(t, ok)
assert.Equal(t, content, result)
})
}
func TestMaxConcurrentCLICalls(t *testing.T) {
assert.Equal(t, 4, MaxConcurrentCLICalls)
}
func TestObservationTypes(t *testing.T) {
expected := []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"}
assert.Equal(t, expected, ObservationTypes)
}
func TestObservationConcepts(t *testing.T) {
expectedConcepts := []string{
"how-it-works",
"why-it-exists",
"what-changed",
"problem-solution",
"gotcha",
"pattern",
"trade-off",
}
assert.Equal(t, expectedConcepts, ObservationConcepts)
}
+276
View File
@@ -0,0 +1,276 @@
package sdk
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{
name: "shorter_than_max",
input: "hello",
maxLen: 10,
expected: "hello",
},
{
name: "equal_to_max",
input: "hello",
maxLen: 5,
expected: "hello",
},
{
name: "longer_than_max",
input: "hello world",
maxLen: 5,
expected: "hello... (truncated)",
},
{
name: "empty_string",
input: "",
maxLen: 5,
expected: "",
},
{
name: "zero_max_length",
input: "hello",
maxLen: 0,
expected: "... (truncated)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncate(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
func TestBuildObservationPrompt(t *testing.T) {
now := time.Now().UnixMilli()
tests := []struct {
name string
exec ToolExecution
contains []string
}{
{
name: "basic_read_tool",
exec: ToolExecution{
ID: 1,
ToolName: "Read",
ToolInput: `{"file_path": "/path/to/file.go"}`,
ToolOutput: `package main\nfunc main() {}`,
CreatedAtEpoch: now,
CWD: "/project",
},
contains: []string{
"<observed_from_primary_session>",
"<what_happened>Read</what_happened>",
"<working_directory>/project</working_directory>",
"<parameters>",
"file_path",
"<outcome>",
"</observed_from_primary_session>",
},
},
{
name: "edit_tool_with_json_input",
exec: ToolExecution{
ID: 2,
ToolName: "Edit",
ToolInput: `{"file_path": "/file.go", "old_string": "foo", "new_string": "bar"}`,
ToolOutput: "Edit applied successfully",
CreatedAtEpoch: now,
CWD: "",
},
contains: []string{
"<what_happened>Edit</what_happened>",
"file_path",
"old_string",
"new_string",
},
},
{
name: "no_cwd",
exec: ToolExecution{
ID: 3,
ToolName: "Bash",
ToolInput: `{"command": "go test"}`,
ToolOutput: "ok",
CreatedAtEpoch: now,
CWD: "",
},
contains: []string{
"<what_happened>Bash</what_happened>",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildObservationPrompt(tt.exec)
for _, s := range tt.contains {
assert.Contains(t, result, s, "Expected result to contain: %s", s)
}
// Check CWD only appears when set
if tt.exec.CWD == "" {
assert.NotContains(t, result, "<working_directory>")
}
})
}
}
func TestBuildObservationPrompt_TruncatesLongContent(t *testing.T) {
longInput := strings.Repeat("x", 5000)
longOutput := strings.Repeat("y", 7000)
exec := ToolExecution{
ID: 1,
ToolName: "Read",
ToolInput: longInput,
ToolOutput: longOutput,
CreatedAtEpoch: time.Now().UnixMilli(),
CWD: "/project",
}
result := BuildObservationPrompt(exec)
// Input should be truncated to ~3000
assert.Contains(t, result, "truncated")
// The result should not be excessively long
assert.Less(t, len(result), 10000)
}
func TestBuildSummaryPrompt(t *testing.T) {
tests := []struct {
name string
req SummaryRequest
contains []string
}{
{
name: "basic_request",
req: SummaryRequest{
SessionDBID: 1,
SDKSessionID: "sdk-123",
Project: "test-project",
},
contains: []string{
"PROGRESS SUMMARY CHECKPOINT",
"<summary>",
"<request>",
"<investigated>",
"<learned>",
"<completed>",
"<next_steps>",
"<notes>",
"</summary>",
},
},
{
name: "with_assistant_message",
req: SummaryRequest{
SessionDBID: 2,
SDKSessionID: "sdk-456",
Project: "project-b",
LastAssistantMessage: "I fixed the authentication bug by updating the JWT validation.",
},
contains: []string{
"Claude's Full Response to User:",
"fixed the authentication",
},
},
{
name: "empty_assistant_message",
req: SummaryRequest{
SessionDBID: 3,
SDKSessionID: "sdk-789",
Project: "project-c",
LastAssistantMessage: "",
},
contains: []string{
"PROGRESS SUMMARY CHECKPOINT",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildSummaryPrompt(tt.req)
for _, s := range tt.contains {
assert.Contains(t, result, s, "Expected result to contain: %s", s)
}
// Check assistant message only appears when set
if tt.req.LastAssistantMessage == "" {
assert.NotContains(t, result, "Claude's Full Response")
}
})
}
}
func TestBuildSummaryPrompt_TruncatesLongAssistantMessage(t *testing.T) {
longMessage := strings.Repeat("a", 5000)
req := SummaryRequest{
SessionDBID: 1,
SDKSessionID: "sdk-123",
Project: "test",
LastAssistantMessage: longMessage,
}
result := BuildSummaryPrompt(req)
// Should contain truncation indicator
assert.Contains(t, result, "truncated")
// Result should be reasonable length (less than full 5000 + overhead)
assert.Less(t, len(result), 6000)
}
func TestToolExecution_Struct(t *testing.T) {
exec := ToolExecution{
ID: 42,
ToolName: "Write",
ToolInput: `{"file_path": "/test.go"}`,
ToolOutput: "File written",
CreatedAtEpoch: 1234567890000,
CWD: "/workspace",
}
assert.Equal(t, int64(42), exec.ID)
assert.Equal(t, "Write", exec.ToolName)
assert.Equal(t, `{"file_path": "/test.go"}`, exec.ToolInput)
assert.Equal(t, "File written", exec.ToolOutput)
assert.Equal(t, int64(1234567890000), exec.CreatedAtEpoch)
assert.Equal(t, "/workspace", exec.CWD)
}
func TestSummaryRequest_Struct(t *testing.T) {
req := SummaryRequest{
SessionDBID: 100,
SDKSessionID: "sdk-abc",
Project: "my-project",
UserPrompt: "Fix the bug",
LastUserMessage: "Please fix the auth bug",
LastAssistantMessage: "I've fixed the authentication issue",
}
assert.Equal(t, int64(100), req.SessionDBID)
assert.Equal(t, "sdk-abc", req.SDKSessionID)
assert.Equal(t, "my-project", req.Project)
assert.Equal(t, "Fix the bug", req.UserPrompt)
assert.Equal(t, "Please fix the auth bug", req.LastUserMessage)
assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
}
+492
View File
@@ -710,3 +710,495 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
})
}
}
// TestGetWorkerPort_EdgeCases tests GetWorkerPort with various edge cases.
func TestGetWorkerPort_EdgeCases(t *testing.T) {
tests := []struct {
name string
envValue string
expectedPort int
shouldSetEnv bool
}{
{
name: "zero port uses default",
envValue: "0",
expectedPort: DefaultWorkerPort,
shouldSetEnv: true,
},
{
name: "negative port uses default",
envValue: "-1",
expectedPort: DefaultWorkerPort,
shouldSetEnv: true,
},
{
name: "empty string uses default",
envValue: "",
expectedPort: DefaultWorkerPort,
shouldSetEnv: true,
},
{
name: "whitespace uses default",
envValue: " ",
expectedPort: DefaultWorkerPort,
shouldSetEnv: true,
},
{
name: "large valid port",
envValue: "65535",
expectedPort: 65535,
shouldSetEnv: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.shouldSetEnv {
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", tt.envValue)
}
port := GetWorkerPort()
assert.Equal(t, tt.expectedPort, port)
})
}
}
// TestVersionVariable tests the Version variable.
func TestVersionVariable(t *testing.T) {
// Version is set at build time, but defaults to "dev"
assert.NotEmpty(t, Version)
}
// TestProjectIDWithName_RootPath tests ProjectIDWithName with root path.
func TestProjectIDWithName_RootPath(t *testing.T) {
result := ProjectIDWithName("/")
// Should handle root path gracefully
assert.NotEmpty(t, result)
assert.Contains(t, result, "_") // Should still have underscore separator
}
// TestProjectIDWithName_SameDirname tests that same dirname with different paths get different IDs.
func TestProjectIDWithName_SameDirname(t *testing.T) {
id1 := ProjectIDWithName("/home/user1/project")
id2 := ProjectIDWithName("/home/user2/project")
// Both have same dirname "project" but different full paths
assert.Contains(t, id1, "project_")
assert.Contains(t, id2, "project_")
// But different hashes due to different full paths
assert.NotEqual(t, id1, id2)
}
// TestBaseInput_PartialFields tests BaseInput with partial fields.
func TestBaseInput_PartialFields(t *testing.T) {
tests := []struct {
name string
input string
expected BaseInput
}{
{
name: "only session_id",
input: `{"session_id":"test-123"}`,
expected: BaseInput{SessionID: "test-123"},
},
{
name: "only cwd",
input: `{"cwd":"/tmp/test"}`,
expected: BaseInput{CWD: "/tmp/test"},
},
{
name: "empty object",
input: `{}`,
expected: BaseInput{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var base BaseInput
err := json.Unmarshal([]byte(tt.input), &base)
require.NoError(t, err)
assert.Equal(t, tt.expected.SessionID, base.SessionID)
assert.Equal(t, tt.expected.CWD, base.CWD)
})
}
}
// TestHookResponse_Marshal tests HookResponse JSON marshaling.
func TestHookResponse_Marshal(t *testing.T) {
tests := []struct {
name string
response HookResponse
contains []string
}{
{
name: "continue true",
response: HookResponse{Continue: true},
contains: []string{`"continue":true`},
},
{
name: "continue false",
response: HookResponse{Continue: false},
contains: []string{`"continue":false`},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
require.NoError(t, err)
for _, s := range tt.contains {
assert.Contains(t, string(data), s)
}
})
}
}
// TestHookResponse_Unmarshal tests HookResponse JSON unmarshaling.
func TestHookResponse_Unmarshal(t *testing.T) {
tests := []struct {
name string
input string
expected HookResponse
}{
{
name: "continue true",
input: `{"continue":true}`,
expected: HookResponse{Continue: true},
},
{
name: "continue false",
input: `{"continue":false}`,
expected: HookResponse{Continue: false},
},
{
name: "missing continue defaults to false",
input: `{}`,
expected: HookResponse{Continue: false},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var resp HookResponse
err := json.Unmarshal([]byte(tt.input), &resp)
require.NoError(t, err)
assert.Equal(t, tt.expected.Continue, resp.Continue)
})
}
}
// TestHookContext_Initialization tests HookContext struct initialization.
func TestHookContext_Initialization(t *testing.T) {
tests := []struct {
name string
ctx HookContext
}{
{
name: "full context",
ctx: HookContext{
HookName: "session-start",
Port: 37777,
Project: "my-project_abc123",
SessionID: "session-123",
CWD: "/home/user/project",
RawInput: []byte(`{"key":"value"}`),
},
},
{
name: "minimal context",
ctx: HookContext{
HookName: "stop",
},
},
{
name: "empty context",
ctx: HookContext{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Just verify the struct can be created and accessed
assert.Equal(t, tt.ctx.HookName, tt.ctx.HookName)
assert.Equal(t, tt.ctx.Port, tt.ctx.Port)
assert.Equal(t, tt.ctx.Project, tt.ctx.Project)
})
}
}
// TestPOST_MarshalError tests POST with unmarshalable body.
func TestPOST_MarshalError(t *testing.T) {
// Create a value that can't be marshaled
badValue := make(chan int)
_, err := POST(99999, "/test", badValue)
require.Error(t, err)
}
// TestPOST_Timeout tests POST with timeout.
func TestPOST_Timeout(t *testing.T) {
// Try to connect to a port that's not listening
_, err := POST(99998, "/test", map[string]string{"key": "value"})
require.Error(t, err)
}
// TestGET_Timeout tests GET with timeout.
func TestGET_Timeout(t *testing.T) {
// Try to connect to a port that's not listening
_, err := GET(99998, "/test")
require.Error(t, err)
}
// TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout.
func TestIsWorkerRunning_Timeout(t *testing.T) {
// Non-existent port should quickly return false
start := time.Now()
result := IsWorkerRunning(99997)
elapsed := time.Since(start)
assert.False(t, result)
assert.Less(t, elapsed, 5*time.Second) // Should not hang
}
// TestIsPortInUse_Timeout tests IsPortInUse with timeout.
func TestIsPortInUse_Timeout(t *testing.T) {
// Non-existent port should quickly return false
start := time.Now()
result := IsPortInUse(99996)
elapsed := time.Since(start)
assert.False(t, result)
assert.Less(t, elapsed, 2*time.Second) // Should not hang
}
// TestGetWorkerVersion_Timeout tests GetWorkerVersion with timeout.
func TestGetWorkerVersion_Timeout(t *testing.T) {
// Non-existent port should quickly return empty
start := time.Now()
result := GetWorkerVersion(99995)
elapsed := time.Since(start)
assert.Empty(t, result)
assert.Less(t, elapsed, 5*time.Second) // Should not hang
}
// TestVersionsCompatible_EdgeCases tests versionsCompatible edge cases.
func TestVersionsCompatible_EdgeCases(t *testing.T) {
tests := []struct {
name string
v1 string
v2 string
expected bool
}{
{
name: "empty versions",
v1: "",
v2: "",
expected: true, // Same base (empty)
},
{
name: "one empty one dev",
v1: "",
v2: "dev",
expected: true, // dev is compatible with anything
},
{
name: "prerelease versions same base",
v1: "v1.0.0-alpha",
v2: "v1.0.0-beta",
expected: true, // Same base 1.0.0
},
{
name: "version with rc suffix",
v1: "v2.0.0-rc1",
v2: "v2.0.0",
expected: true, // Same base 2.0.0
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := versionsCompatible(tt.v1, tt.v2)
assert.Equal(t, tt.expected, result)
})
}
}
// TestExtractBaseVersion_EdgeCases tests extractBaseVersion edge cases.
func TestExtractBaseVersion_EdgeCases(t *testing.T) {
tests := []struct {
name string
version string
expected string
}{
{
name: "version starting with hyphen",
version: "-dirty",
expected: "-dirty", // hyphen at index 0 is not > 0, so no truncation
},
{
name: "just v",
version: "v",
expected: "",
},
{
name: "multiple hyphens",
version: "v1.0.0-alpha-beta-gamma",
expected: "1.0.0",
},
{
name: "no hyphen at all",
version: "v2.0.0",
expected: "2.0.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractBaseVersion(tt.version)
assert.Equal(t, tt.expected, result)
})
}
}
// TestProjectIDWithName_RelativePath tests ProjectIDWithName with relative paths.
func TestProjectIDWithName_RelativePath(t *testing.T) {
// Relative paths should be converted to absolute
result := ProjectIDWithName(".")
assert.NotEmpty(t, result)
assert.Contains(t, result, "_")
}
// TestProjectIDWithName_DeepPath tests ProjectIDWithName with deep paths.
func TestProjectIDWithName_DeepPath(t *testing.T) {
result := ProjectIDWithName("/a/very/deep/nested/path/to/project")
assert.Contains(t, result, "project_")
assert.NotEmpty(t, result)
}
// TestPOST_EmptyBody tests POST with empty body.
func TestPOST_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
result, err := POST(port, "/test", map[string]string{})
require.NoError(t, err)
assert.NotNil(t, result)
}
// TestGET_WithQueryParams tests GET with query parameters.
func TestGET_WithQueryParams(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/test?foo=bar", r.URL.String())
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
result, err := GET(port, "/test?foo=bar")
require.NoError(t, err)
assert.NotNil(t, result)
}
// TestHookResponse_RoundTrip tests JSON marshal/unmarshal round-trip.
func TestHookResponse_RoundTrip(t *testing.T) {
original := HookResponse{Continue: true}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded HookResponse
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, original.Continue, decoded.Continue)
}
// TestBaseInput_RoundTrip tests BaseInput JSON round-trip.
func TestBaseInput_RoundTrip(t *testing.T) {
original := BaseInput{
SessionID: "test-session",
CWD: "/home/user/project",
PermissionMode: "standard",
HookEventName: "session-start",
}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded BaseInput
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, original.SessionID, decoded.SessionID)
assert.Equal(t, original.CWD, decoded.CWD)
assert.Equal(t, original.PermissionMode, decoded.PermissionMode)
assert.Equal(t, original.HookEventName, decoded.HookEventName)
}
// TestHookContext_RawInput tests HookContext with different raw input types.
func TestHookContext_RawInput(t *testing.T) {
tests := []struct {
name string
rawInput []byte
}{
{
name: "json object",
rawInput: []byte(`{"key":"value"}`),
},
{
name: "json array",
rawInput: []byte(`[1,2,3]`),
},
{
name: "empty object",
rawInput: []byte(`{}`),
},
{
name: "nil input",
rawInput: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := HookContext{
HookName: "test",
RawInput: tt.rawInput,
}
assert.Equal(t, tt.rawInput, ctx.RawInput)
})
}
}
// TestDefaultWorkerPort tests that the default port constant is valid.
func TestDefaultWorkerPort(t *testing.T) {
assert.Greater(t, DefaultWorkerPort, 1024, "Default port should be above privileged port range")
assert.Less(t, DefaultWorkerPort, 65535, "Default port should be valid TCP port")
}
// TestHealthCheckTimeout tests the health check timeout is reasonable.
func TestHealthCheckTimeout(t *testing.T) {
assert.Greater(t, HealthCheckTimeout, 100*time.Millisecond)
assert.Less(t, HealthCheckTimeout, 10*time.Second)
}
// TestStartupTimeout tests the startup timeout is reasonable.
func TestStartupTimeout(t *testing.T) {
assert.Greater(t, StartupTimeout, 5*time.Second)
assert.LessOrEqual(t, StartupTimeout, time.Minute)
}