mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Increase test coverage to 45.6%
This commit is contained in:
@@ -0,0 +1,254 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"empty_string", "", "", false},
|
||||
{"non_empty_string", "hello", "hello", true},
|
||||
{"whitespace", " ", " ", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := nullString(tt.input)
|
||||
assert.Equal(t, tt.expected, result.String)
|
||||
assert.Equal(t, tt.valid, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input int
|
||||
expected int64
|
||||
valid bool
|
||||
}{
|
||||
{"zero", 0, 0, false},
|
||||
{"positive", 42, 42, true},
|
||||
{"negative", -1, -1, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := nullInt(tt.input)
|
||||
assert.Equal(t, tt.expected, result.Int64)
|
||||
assert.Equal(t, tt.valid, result.Valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepeatPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
n int
|
||||
expected string
|
||||
}{
|
||||
{"zero", 0, ""},
|
||||
{"negative", -1, ""},
|
||||
{"one", 1, ", ?"},
|
||||
{"two", 2, ", ?, ?"},
|
||||
{"three", 3, ", ?, ?, ?"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := repeatPlaceholders(tt.n)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64SliceToInterface(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []int64
|
||||
expected []interface{}
|
||||
}{
|
||||
{"empty", []int64{}, []interface{}{}},
|
||||
{"single", []int64{42}, []interface{}{int64(42)}},
|
||||
{"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := int64SliceToInterface(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLimitParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
defaultLimit int
|
||||
expected int
|
||||
}{
|
||||
{"no_param_uses_default", "", 10, 10},
|
||||
{"valid_limit", "limit=20", 10, 20},
|
||||
{"invalid_limit_uses_default", "limit=abc", 10, 10},
|
||||
{"zero_limit_uses_default", "limit=0", 10, 10},
|
||||
{"negative_limit_uses_default", "limit=-5", 10, 10},
|
||||
{"large_limit", "limit=1000", 10, 1000},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?"+tt.query, nil)
|
||||
result := ParseLimitParam(req, tt.defaultLimit)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSummary(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert a test summary
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query and scan
|
||||
row := db.QueryRow(`
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries WHERE sdk_session_id = ?
|
||||
`, "sdk-123")
|
||||
|
||||
summary, err := scanSummary(row)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, summary)
|
||||
assert.Equal(t, "sdk-123", summary.SDKSessionID)
|
||||
assert.Equal(t, "test-project", summary.Project)
|
||||
assert.Equal(t, "test request", summary.Request.String)
|
||||
assert.Equal(t, "test investigated", summary.Investigated.String)
|
||||
assert.Equal(t, "test learned", summary.Learned.String)
|
||||
assert.Equal(t, "test completed", summary.Completed.String)
|
||||
assert.Equal(t, "test next steps", summary.NextSteps.String)
|
||||
assert.Equal(t, "test notes", summary.Notes.String)
|
||||
}
|
||||
|
||||
func TestScanSummaryRows(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert multiple summaries
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
|
||||
VALUES
|
||||
('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000),
|
||||
('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries WHERE sdk_session_id = ? ORDER BY id
|
||||
`, "sdk-123")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
summaries, err := scanSummaryRows(rows)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, summaries, 2)
|
||||
assert.Equal(t, "request 1", summaries[0].Request.String)
|
||||
assert.Equal(t, "request 2", summaries[1].Request.String)
|
||||
}
|
||||
|
||||
func TestScanPromptWithSession(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert a test prompt
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
|
||||
VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query with session join
|
||||
row := db.QueryRow(`
|
||||
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
|
||||
FROM user_prompts p
|
||||
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
|
||||
WHERE p.claude_session_id = ?
|
||||
`, "claude-123")
|
||||
|
||||
prompt, err := scanPromptWithSession(row)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, prompt)
|
||||
assert.Equal(t, "claude-123", prompt.ClaudeSessionID)
|
||||
assert.Equal(t, 1, prompt.PromptNumber)
|
||||
assert.Equal(t, "test prompt", prompt.PromptText)
|
||||
assert.Equal(t, 5, prompt.MatchedObservations)
|
||||
assert.Equal(t, "test-project", prompt.Project)
|
||||
assert.Equal(t, "sdk-123", prompt.SDKSessionID)
|
||||
}
|
||||
|
||||
func TestScanPromptWithSessionRows(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
createBaseTables(t, db)
|
||||
seedSession(t, db, "claude-123", "sdk-123", "test-project")
|
||||
|
||||
// Insert multiple prompts
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
|
||||
VALUES
|
||||
('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000),
|
||||
('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
|
||||
FROM user_prompts p
|
||||
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
|
||||
WHERE p.claude_session_id = ? ORDER BY p.id
|
||||
`, "claude-123")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
prompts, err := scanPromptWithSessionRows(rows)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, prompts, 2)
|
||||
assert.Equal(t, "prompt one", prompts[0].PromptText)
|
||||
assert.Equal(t, "prompt two", prompts[1].PromptText)
|
||||
}
|
||||
|
||||
func TestParseLimitParam_HTTPRequest(t *testing.T) {
|
||||
// Test with an actual HTTP request
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
limit := ParseLimitParam(r, 25)
|
||||
if limit != 50 {
|
||||
t.Errorf("Expected limit 50, got %d", limit)
|
||||
}
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMigrationManager(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
require.NotNil(t, manager)
|
||||
assert.Equal(t, db, manager.db)
|
||||
}
|
||||
|
||||
func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
|
||||
// Should create table without error
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Table should exist
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count) // Empty table
|
||||
|
||||
// Calling again should not error (IF NOT EXISTS)
|
||||
err = manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, versions)
|
||||
}
|
||||
|
||||
func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert some versions
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, versions, 2)
|
||||
assert.True(t, versions[1])
|
||||
assert.True(t, versions[2])
|
||||
assert.False(t, versions[3]) // Not applied
|
||||
}
|
||||
|
||||
func TestMigrationManager_ApplyMigration(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply a simple migration
|
||||
migration := Migration{
|
||||
Version: 100,
|
||||
Name: "test_migration",
|
||||
SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)",
|
||||
}
|
||||
|
||||
err = manager.ApplyMigration(migration)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify table was created
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
// Verify migration was recorded
|
||||
var version int
|
||||
err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, version)
|
||||
}
|
||||
|
||||
func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to apply invalid migration
|
||||
migration := Migration{
|
||||
Version: 100,
|
||||
Name: "invalid_migration",
|
||||
SQL: "INVALID SQL SYNTAX",
|
||||
}
|
||||
|
||||
err = manager.ApplyMigration(migration)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "execute migration 100")
|
||||
}
|
||||
|
||||
func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a test migration manager with a subset of migrations
|
||||
manager := NewMigrationManager(db)
|
||||
|
||||
// First ensure schema versions table exists
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply first migration manually
|
||||
err = manager.ApplyMigration(Migrations[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the first migration version was recorded
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, versions[Migrations[0].Version])
|
||||
}
|
||||
|
||||
func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) {
|
||||
db, _, cleanup := testDB(t)
|
||||
defer cleanup()
|
||||
|
||||
manager := NewMigrationManager(db)
|
||||
err := manager.EnsureSchemaVersionsTable()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark some migrations as already applied
|
||||
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get applied versions
|
||||
versions, err := manager.GetAppliedVersions()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, versions[4])
|
||||
}
|
||||
|
||||
func TestMigration_Struct(t *testing.T) {
|
||||
m := Migration{
|
||||
Version: 1,
|
||||
Name: "test",
|
||||
SQL: "SELECT 1",
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, m.Version)
|
||||
assert.Equal(t, "test", m.Name)
|
||||
assert.Equal(t, "SELECT 1", m.SQL)
|
||||
}
|
||||
|
||||
func TestMigrations_List(t *testing.T) {
|
||||
// Verify migrations are ordered correctly
|
||||
assert.NotEmpty(t, Migrations)
|
||||
|
||||
// Verify all migrations have required fields
|
||||
for i, m := range Migrations {
|
||||
assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i)
|
||||
assert.NotEmpty(t, m.Name, "Migration %d has empty name", i)
|
||||
assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i)
|
||||
}
|
||||
|
||||
// Verify key migrations exist
|
||||
versionSet := make(map[int]bool)
|
||||
for _, m := range Migrations {
|
||||
versionSet[m.Version] = true
|
||||
}
|
||||
|
||||
assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration")
|
||||
assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration")
|
||||
}
|
||||
@@ -800,3 +800,148 @@ func TestExtractKeywords(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create project-scoped observation for project-a
|
||||
projectObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project A specific",
|
||||
Narrative: "Only for project-a",
|
||||
Concepts: []string{"local-concept"},
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create global observation from project-a
|
||||
globalObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Global security practice",
|
||||
Narrative: "Best practice for all",
|
||||
Concepts: []string{"security", "best-practice"},
|
||||
}
|
||||
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create observation for project-b
|
||||
projectBObs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Project B specific",
|
||||
Narrative: "Only for project-b",
|
||||
}
|
||||
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GetObservationsByProjectStrict for project-a should only return project-a observations
|
||||
// This is different from GetRecentObservations which includes globals from other projects
|
||||
results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only observations created in project-a
|
||||
|
||||
// Verify both are from project-a
|
||||
for _, obs := range results {
|
||||
assert.Equal(t, "project-a", obs.Project)
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrict for project-b should only return project-b observations
|
||||
results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "Project B specific", results[0].Title.String)
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create an observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Test observation",
|
||||
Narrative: "Some content here",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Search with only stop words (should return nil)
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, results)
|
||||
|
||||
// Search with empty query
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, results)
|
||||
}
|
||||
|
||||
func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations
|
||||
for i := 0; i < 15; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: "Authentication test " + string(rune('A'+i)),
|
||||
Narrative: "Auth related content",
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Search with limit 0 (should default to 10)
|
||||
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 10)
|
||||
|
||||
// Search with negative limit (should default to 10)
|
||||
results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 10)
|
||||
}
|
||||
|
||||
func TestObservationStore_GetAllRecentObservations(t *testing.T) {
|
||||
obsStore, _, cleanup := testObservationStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create observations across different projects
|
||||
projects := []string{"project-a", "project-b", "project-c"}
|
||||
for _, proj := range projects {
|
||||
for i := 0; i < 3; i++ {
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: proj + " observation " + string(rune('A'+i)),
|
||||
Narrative: "Content for " + proj,
|
||||
}
|
||||
_, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all recent observations
|
||||
results, err := obsStore.GetAllRecentObservations(ctx, 100)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 9) // 3 projects * 3 observations
|
||||
|
||||
// Verify they are in descending order by epoch
|
||||
for i := 1; i < len(results); i++ {
|
||||
assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch)
|
||||
}
|
||||
|
||||
// Test with limit
|
||||
results, err = obsStore.GetAllRecentObservations(ctx, 5)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 5)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEmbeddingDim verifies the embedding dimension constant.
|
||||
func TestEmbeddingDim(t *testing.T) {
|
||||
assert.Equal(t, 384, EmbeddingDim)
|
||||
}
|
||||
|
||||
// TestNewService tests creating a new embedding service.
|
||||
func TestNewService(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, svc)
|
||||
|
||||
defer svc.Close()
|
||||
|
||||
assert.NotNil(t, svc.tk)
|
||||
assert.NotNil(t, svc.session)
|
||||
}
|
||||
|
||||
// TestEmbed_SingleText tests embedding a single text.
|
||||
func TestEmbed_SingleText(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
embedding, err := svc.Embed("Hello, world!")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embedding, EmbeddingDim)
|
||||
|
||||
// Verify non-zero embedding
|
||||
var sum float32
|
||||
for _, v := range embedding {
|
||||
sum += v * v
|
||||
}
|
||||
assert.Greater(t, sum, float32(0), "Embedding should not be all zeros")
|
||||
}
|
||||
|
||||
// TestEmbed_EmptyText tests embedding an empty string.
|
||||
func TestEmbed_EmptyText(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
embedding, err := svc.Embed("")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embedding, EmbeddingDim)
|
||||
|
||||
// Empty text should return zero vector
|
||||
for _, v := range embedding {
|
||||
assert.Equal(t, float32(0), v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbed_SimilarTexts tests that similar texts produce similar embeddings.
|
||||
func TestEmbed_SimilarTexts(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
emb1, err := svc.Embed("The quick brown fox jumps over the lazy dog.")
|
||||
require.NoError(t, err)
|
||||
|
||||
emb2, err := svc.Embed("A fast brown fox leaps over a sleepy dog.")
|
||||
require.NoError(t, err)
|
||||
|
||||
emb3, err := svc.Embed("Go programming language concurrency patterns.")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Calculate cosine similarity
|
||||
sim12 := cosineSimilarity(emb1, emb2)
|
||||
sim13 := cosineSimilarity(emb1, emb3)
|
||||
|
||||
// Similar texts should have higher similarity
|
||||
assert.Greater(t, sim12, sim13, "Similar sentences should have higher similarity than dissimilar ones")
|
||||
assert.Greater(t, sim12, float64(0.7), "Similar sentences should have high similarity")
|
||||
}
|
||||
|
||||
// TestEmbedBatch_MultipleTexts tests batch embedding.
|
||||
func TestEmbedBatch_MultipleTexts(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
texts := []string{
|
||||
"First text about programming.",
|
||||
"Second text about databases.",
|
||||
"Third text about machine learning.",
|
||||
}
|
||||
|
||||
embeddings, err := svc.EmbedBatch(texts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embeddings, len(texts))
|
||||
for i, emb := range embeddings {
|
||||
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedBatch_EmptySlice tests batch embedding with empty slice.
|
||||
func TestEmbedBatch_EmptySlice(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
embeddings, err := svc.EmbedBatch([]string{})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, embeddings)
|
||||
}
|
||||
|
||||
// TestEmbedBatch_WithEmptyTexts tests batch embedding with some empty texts.
|
||||
func TestEmbedBatch_WithEmptyTexts(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
texts := []string{
|
||||
"Valid text one.",
|
||||
"",
|
||||
"Valid text two.",
|
||||
"",
|
||||
}
|
||||
|
||||
embeddings, err := svc.EmbedBatch(texts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embeddings, 4)
|
||||
|
||||
// Non-empty texts should have non-zero embeddings
|
||||
var sum0 float32
|
||||
for _, v := range embeddings[0] {
|
||||
sum0 += v * v
|
||||
}
|
||||
assert.Greater(t, sum0, float32(0))
|
||||
|
||||
// Empty texts should have zero embeddings
|
||||
for _, v := range embeddings[1] {
|
||||
assert.Equal(t, float32(0), v)
|
||||
}
|
||||
|
||||
var sum2 float32
|
||||
for _, v := range embeddings[2] {
|
||||
sum2 += v * v
|
||||
}
|
||||
assert.Greater(t, sum2, float32(0))
|
||||
|
||||
for _, v := range embeddings[3] {
|
||||
assert.Equal(t, float32(0), v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedBatch_AllEmpty tests batch embedding with all empty texts.
|
||||
func TestEmbedBatch_AllEmpty(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
texts := []string{"", "", ""}
|
||||
|
||||
embeddings, err := svc.EmbedBatch(texts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embeddings, 3)
|
||||
for i, emb := range embeddings {
|
||||
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
|
||||
for j, v := range emb {
|
||||
assert.Equal(t, float32(0), v, "Embedding %d[%d] should be zero", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbed_Concurrent tests concurrent embedding calls.
|
||||
func TestEmbed_Concurrent(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
|
||||
errors := make(chan error, numGoroutines)
|
||||
embeddings := make([][]float32, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
text := "Test text for concurrent embedding test"
|
||||
emb, err := svc.Embed(text)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
embeddings[idx] = emb
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent embedding error: %v", err)
|
||||
}
|
||||
|
||||
// All embeddings should be valid
|
||||
for i, emb := range embeddings {
|
||||
if emb != nil {
|
||||
assert.Len(t, emb, EmbeddingDim, "Embedding %d should have correct dimension", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbed_SpecialCharacters tests embedding text with special characters.
|
||||
func TestEmbed_SpecialCharacters(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
texts := []string{
|
||||
"Text with unicode: 你好世界 🎉",
|
||||
"Text with newlines:\nLine 1\nLine 2",
|
||||
"Text with tabs:\tColumn1\tColumn2",
|
||||
"Text with quotes: \"quoted\" and 'single'",
|
||||
"Text with code: func main() { fmt.Println(\"hello\") }",
|
||||
}
|
||||
|
||||
for _, text := range texts {
|
||||
t.Run(text[:20], func(t *testing.T) {
|
||||
emb, err := svc.Embed(text)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, emb, EmbeddingDim)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbed_LongText tests embedding long text.
|
||||
func TestEmbed_LongText(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
// Create a long text (tokenizer should truncate appropriately)
|
||||
longText := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
longText += "This is a sentence to make the text very long. "
|
||||
}
|
||||
|
||||
emb, err := svc.Embed(longText)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, emb, EmbeddingDim)
|
||||
}
|
||||
|
||||
// TestClose tests closing the service.
|
||||
func TestClose(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Session should be nil after close
|
||||
assert.Nil(t, svc.session)
|
||||
}
|
||||
|
||||
// TestEmbedBatch_SingleItem tests batch embedding with single item.
|
||||
func TestEmbedBatch_SingleItem(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
texts := []string{"Single text for batch embedding."}
|
||||
|
||||
embeddings, err := svc.EmbedBatch(texts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, embeddings, 1)
|
||||
assert.Len(t, embeddings[0], EmbeddingDim)
|
||||
}
|
||||
|
||||
// TestEmbed_Deterministic tests that embedding is deterministic.
|
||||
func TestEmbed_Deterministic(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
text := "Test text for deterministic embedding."
|
||||
|
||||
emb1, err := svc.Embed(text)
|
||||
require.NoError(t, err)
|
||||
|
||||
emb2, err := svc.Embed(text)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Same text should produce same embedding
|
||||
for i := 0; i < EmbeddingDim; i++ {
|
||||
assert.Equal(t, emb1[i], emb2[i], "Embedding should be deterministic at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to calculate cosine similarity
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) {
|
||||
return 0
|
||||
}
|
||||
|
||||
var dotProduct float64
|
||||
var normA float64
|
||||
var normB float64
|
||||
|
||||
for i := range a {
|
||||
dotProduct += float64(a[i]) * float64(b[i])
|
||||
normA += float64(a[i]) * float64(a[i])
|
||||
normB += float64(b[i]) * float64(b[i])
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
@@ -597,3 +597,363 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||||
assert.True(t, hasType, "tool %s schema should have type", tool.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
|
||||
func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`),
|
||||
}
|
||||
|
||||
resp := server.handleToolsCall(ctx, req)
|
||||
require.NotNil(t, resp.Error)
|
||||
assert.Equal(t, -32000, resp.Error.Code)
|
||||
assert.Contains(t, resp.Error.Data, "unknown tool")
|
||||
}
|
||||
|
||||
// TestCallTool_ToolNameRecognition tests that valid tool names are recognized (not "unknown tool").
|
||||
func TestCallTool_ToolNameRecognition(t *testing.T) {
|
||||
// Note: This test verifies tool routing logic, not execution (which requires searchMgr)
|
||||
// All valid tool names should be in the handleToolsList response
|
||||
server := NewServer(nil, "1.0.0")
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "tools/list",
|
||||
}
|
||||
|
||||
resp := server.handleToolsList(req)
|
||||
result := resp.Result.(map[string]any)
|
||||
tools := result["tools"].([]Tool)
|
||||
|
||||
// Verify all expected tools are registered
|
||||
expectedTools := map[string]bool{
|
||||
"search": true,
|
||||
"timeline": true,
|
||||
"decisions": true,
|
||||
"changes": true,
|
||||
"how_it_works": true,
|
||||
"find_by_concept": true,
|
||||
"find_by_file": true,
|
||||
"find_by_type": true,
|
||||
"get_recent_context": true,
|
||||
"get_context_timeline": true,
|
||||
"get_timeline_by_query": true,
|
||||
}
|
||||
|
||||
foundTools := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
foundTools[tool.Name] = true
|
||||
}
|
||||
|
||||
for name := range expectedTools {
|
||||
assert.True(t, foundTools[name], "tool %s should be registered", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRun_MultipleRequests tests Run with multiple sequential requests.
|
||||
func TestRun_MultipleRequests(t *testing.T) {
|
||||
var stdout bytes.Buffer
|
||||
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
|
||||
stdin := strings.NewReader(req1 + "\n" + req2 + "\n")
|
||||
|
||||
server := &Server{
|
||||
stdin: stdin,
|
||||
stdout: &stdout,
|
||||
version: "1.0.0",
|
||||
}
|
||||
|
||||
err := server.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
output := stdout.String()
|
||||
// Should contain responses for both requests
|
||||
assert.Contains(t, output, `"id":1`)
|
||||
assert.Contains(t, output, `"id":2`)
|
||||
}
|
||||
|
||||
// TestHandleTimeline_Defaults tests timeline default values.
|
||||
func TestHandleTimeline_Defaults(t *testing.T) {
|
||||
// Test that handleTimeline sets default before/after values
|
||||
params := TimelineParams{
|
||||
AnchorID: 0,
|
||||
Query: "",
|
||||
Before: 0,
|
||||
After: 0,
|
||||
}
|
||||
|
||||
// Simulate the default value assignment from handleTimeline
|
||||
if params.Before <= 0 {
|
||||
params.Before = 10
|
||||
}
|
||||
if params.After <= 0 {
|
||||
params.After = 10
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, params.Before)
|
||||
assert.Equal(t, 10, params.After)
|
||||
}
|
||||
|
||||
// TestTimelineParams_Complete tests complete TimelineParams parsing.
|
||||
func TestTimelineParams_Complete(t *testing.T) {
|
||||
input := `{
|
||||
"anchor_id": 100,
|
||||
"query": "test query",
|
||||
"before": 5,
|
||||
"after": 15,
|
||||
"project": "my-project",
|
||||
"obs_type": "bugfix",
|
||||
"concepts": "security,auth",
|
||||
"files": "main.go,handler.go",
|
||||
"dateStart": 1700000000000,
|
||||
"dateEnd": 1700100000000,
|
||||
"format": "full"
|
||||
}`
|
||||
|
||||
var params TimelineParams
|
||||
err := json.Unmarshal([]byte(input), ¶ms)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(100), params.AnchorID)
|
||||
assert.Equal(t, "test query", params.Query)
|
||||
assert.Equal(t, 5, params.Before)
|
||||
assert.Equal(t, 15, params.After)
|
||||
assert.Equal(t, "my-project", params.Project)
|
||||
assert.Equal(t, "bugfix", params.ObsType)
|
||||
assert.Equal(t, "security,auth", params.Concepts)
|
||||
assert.Equal(t, "main.go,handler.go", params.Files)
|
||||
assert.Equal(t, int64(1700000000000), params.DateStart)
|
||||
assert.Equal(t, int64(1700100000000), params.DateEnd)
|
||||
assert.Equal(t, "full", params.Format)
|
||||
}
|
||||
|
||||
// TestServerStdinStdoutConfig tests that server stdin/stdout can be configured.
|
||||
func TestServerStdinStdoutConfig(t *testing.T) {
|
||||
var stdout bytes.Buffer
|
||||
var stdin bytes.Buffer
|
||||
|
||||
server := &Server{
|
||||
stdin: &stdin,
|
||||
stdout: &stdout,
|
||||
version: "test-version",
|
||||
}
|
||||
|
||||
assert.Equal(t, &stdin, server.stdin)
|
||||
assert.Equal(t, &stdout, server.stdout)
|
||||
assert.Equal(t, "test-version", server.version)
|
||||
}
|
||||
|
||||
// TestResponseIDTypes tests that response IDs can be various types.
|
||||
func TestResponseIDTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
}{
|
||||
{"integer id", 1},
|
||||
{"string id", "abc-123"},
|
||||
{"float id", 1.5},
|
||||
{"null id", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
server := &Server{stdout: &buf}
|
||||
|
||||
resp := &Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: tt.id,
|
||||
Result: "ok",
|
||||
}
|
||||
|
||||
server.sendResponse(resp)
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, `"jsonrpc":"2.0"`)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
|
||||
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty query should error
|
||||
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "query is required")
|
||||
}
|
||||
|
||||
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
|
||||
func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid timeline params")
|
||||
}
|
||||
|
||||
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
|
||||
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid timeline params")
|
||||
}
|
||||
|
||||
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
|
||||
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
// No anchor_id and no query should return empty result
|
||||
result, err := server.handleTimeline(ctx, json.RawMessage(`{}`))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result.Results)
|
||||
}
|
||||
|
||||
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
|
||||
func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
ctx := context.Background()
|
||||
|
||||
// With anchor_id but no before/after, defaults should be applied
|
||||
// However, since searchMgr is nil, this will fail after defaults are applied
|
||||
result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`))
|
||||
// Should return empty result since anchor_id is 0
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result.Results)
|
||||
}
|
||||
|
||||
// TestServerFields tests Server struct fields.
|
||||
func TestServerFields(t *testing.T) {
|
||||
server := NewServer(nil, "2.0.0")
|
||||
|
||||
assert.Equal(t, "2.0.0", server.version)
|
||||
assert.Nil(t, server.searchMgr)
|
||||
assert.NotNil(t, server.stdin)
|
||||
assert.NotNil(t, server.stdout)
|
||||
}
|
||||
|
||||
// TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID.
|
||||
func TestRequestUnmarshalWithNullID(t *testing.T) {
|
||||
input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}`
|
||||
|
||||
var req Request
|
||||
err := json.Unmarshal([]byte(input), &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JSONRPC)
|
||||
assert.Nil(t, req.ID)
|
||||
assert.Equal(t, "initialize", req.Method)
|
||||
}
|
||||
|
||||
// TestResponseWithNullError tests Response without error.
|
||||
func TestResponseWithNullError(t *testing.T) {
|
||||
resp := Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Result: "success",
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"result":"success"`)
|
||||
assert.NotContains(t, string(data), `"error"`)
|
||||
}
|
||||
|
||||
// TestErrorWithNilData tests Error without data.
|
||||
func TestErrorWithNilData(t *testing.T) {
|
||||
err := Error{
|
||||
Code: -32600,
|
||||
Message: "Invalid Request",
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
data, errMarshal := json.Marshal(err)
|
||||
require.NoError(t, errMarshal)
|
||||
assert.Contains(t, string(data), `"code":-32600`)
|
||||
assert.Contains(t, string(data), `"message":"Invalid Request"`)
|
||||
assert.NotContains(t, string(data), `"data"`)
|
||||
}
|
||||
|
||||
// TestToolInputSchema tests that tool input schemas have required fields.
|
||||
func TestToolInputSchema(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0")
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "tools/list",
|
||||
}
|
||||
|
||||
resp := server.handleToolsList(req)
|
||||
result := resp.Result.(map[string]any)
|
||||
tools := result["tools"].([]Tool)
|
||||
|
||||
for _, tool := range tools {
|
||||
schema := tool.InputSchema
|
||||
schemaType, ok := schema["type"]
|
||||
assert.True(t, ok, "tool %s schema should have type", tool.Name)
|
||||
assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name)
|
||||
|
||||
// All tools should have properties
|
||||
_, hasProperties := schema["properties"]
|
||||
assert.True(t, hasProperties, "tool %s should have properties", tool.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunMixedRequests tests Run with mixed valid and invalid requests.
|
||||
func TestRunMixedRequests(t *testing.T) {
|
||||
var stdout bytes.Buffer
|
||||
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
req2 := `invalid json`
|
||||
req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`
|
||||
stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n")
|
||||
|
||||
server := &Server{
|
||||
stdin: stdin,
|
||||
stdout: &stdout,
|
||||
version: "1.0.0",
|
||||
}
|
||||
|
||||
err := server.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
output := stdout.String()
|
||||
// Should have responses for all three requests
|
||||
assert.Contains(t, output, `"id":1`)
|
||||
assert.Contains(t, output, `"error"`) // Parse error for invalid json
|
||||
assert.Contains(t, output, `"id":3`)
|
||||
}
|
||||
|
||||
// TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments.
|
||||
func TestToolCallParamsWithComplexArgs(t *testing.T) {
|
||||
input := `{
|
||||
"name": "search",
|
||||
"arguments": {
|
||||
"query": "authentication bug",
|
||||
"project": "my-project",
|
||||
"limit": 10,
|
||||
"type": "observations"
|
||||
}
|
||||
}`
|
||||
|
||||
var params ToolCallParams
|
||||
err := json.Unmarshal([]byte(input), ¶ms)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "search", params.Name)
|
||||
assert.NotEmpty(t, params.Arguments)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,646 @@
|
||||
// Package search provides unified search capabilities for claude-mnemonic.
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
// Import sqlite driver
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Ensure context is used (for later tests)
|
||||
var _ = context.Background
|
||||
|
||||
// hasFTS5 checks if FTS5 is available in the SQLite build.
|
||||
func hasFTS5(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "fts5-check-*")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tmpDir) }()
|
||||
|
||||
dbPath := tmpDir + "/check.db"
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
_, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
|
||||
return true
|
||||
}
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
if !hasFTS5(t) {
|
||||
t.Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "search-integration-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
_ = store.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// SearchIntegrationSuite tests search with real SQLite stores.
|
||||
type SearchIntegrationSuite struct {
|
||||
suite.Suite
|
||||
store *sqlite.Store
|
||||
cleanup func()
|
||||
manager *Manager
|
||||
obsStore *sqlite.ObservationStore
|
||||
sumStore *sqlite.SummaryStore
|
||||
prmStore *sqlite.PromptStore
|
||||
}
|
||||
|
||||
func (s *SearchIntegrationSuite) SetupTest() {
|
||||
if !hasFTS5(s.T()) {
|
||||
s.T().Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
s.store, s.cleanup = testStore(s.T())
|
||||
|
||||
// Create real stores backed by SQLite
|
||||
s.obsStore = sqlite.NewObservationStore(s.store)
|
||||
s.sumStore = sqlite.NewSummaryStore(s.store)
|
||||
s.prmStore = sqlite.NewPromptStore(s.store)
|
||||
|
||||
// Create search manager with real stores (no vector client for now)
|
||||
s.manager = NewManager(s.obsStore, s.sumStore, s.prmStore, nil)
|
||||
}
|
||||
|
||||
func (s *SearchIntegrationSuite) TearDownTest() {
|
||||
if s.cleanup != nil {
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchIntegrationSuite(t *testing.T) {
|
||||
suite.Run(t, new(SearchIntegrationSuite))
|
||||
}
|
||||
|
||||
// seedObservations inserts test observations into the database.
|
||||
func (s *SearchIntegrationSuite) seedObservations(ctx context.Context) []int64 {
|
||||
var ids []int64
|
||||
|
||||
// Observation 1: Authentication bug fix
|
||||
obs1 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Scope: models.ScopeProject,
|
||||
Title: "Fixed Authentication Bug",
|
||||
Narrative: "Resolved JWT token validation issue that caused intermittent login failures",
|
||||
Concepts: []string{"authentication", "jwt", "security"},
|
||||
FilesRead: []string{"auth/handler.go", "auth/jwt.go"},
|
||||
}
|
||||
id1, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs1, 1, 100)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id1)
|
||||
|
||||
// Observation 2: Database optimization decision
|
||||
obs2 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDecision,
|
||||
Scope: models.ScopeProject,
|
||||
Title: "Database Query Optimization Decision",
|
||||
Narrative: "Decided to add indexes on user_id and created_at columns for better performance",
|
||||
Concepts: []string{"database", "performance", "decision"},
|
||||
FilesRead: []string{"db/migrations/001.sql"},
|
||||
}
|
||||
id2, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs2, 2, 150)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id2)
|
||||
|
||||
// Observation 3: Global best practice
|
||||
obs3 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Scope: models.ScopeGlobal,
|
||||
Title: "Error Handling Best Practice",
|
||||
Narrative: "Use wrapped errors with context for better debugging: errors.Wrap(err, context)",
|
||||
Concepts: []string{"best-practice", "errors", "patterns"},
|
||||
FilesRead: []string{"pkg/errors/errors.go"},
|
||||
}
|
||||
id3, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-2", "other-project", obs3, 1, 80)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id3)
|
||||
|
||||
// Observation 4: Code change/refactoring
|
||||
obs4 := &models.ParsedObservation{
|
||||
Type: models.ObsTypeChange,
|
||||
Scope: models.ScopeProject,
|
||||
Title: "Refactored User Service",
|
||||
Narrative: "Changed user service to use repository pattern, modified interfaces for better testability",
|
||||
Concepts: []string{"refactoring", "architecture"},
|
||||
FilesModified: []string{"services/user.go", "services/user_test.go"},
|
||||
}
|
||||
id4, _, err := s.obsStore.StoreObservation(ctx, "sdk-sess-1", "test-project", obs4, 3, 200)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id4)
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// seedSummaries inserts test session summaries into the database.
|
||||
func (s *SearchIntegrationSuite) seedSummaries(ctx context.Context) []int64 {
|
||||
var ids []int64
|
||||
|
||||
// Summary 1
|
||||
sum1 := &models.ParsedSummary{
|
||||
Request: "Fix authentication bug",
|
||||
Investigated: "JWT token validation and session handling",
|
||||
Learned: "JWT validation requires algorithm check to prevent alg:none attacks",
|
||||
Completed: "Fixed JWT validation, added tests",
|
||||
NextSteps: "Review other security endpoints",
|
||||
}
|
||||
id1, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum1, 1, 100)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id1)
|
||||
|
||||
// Summary 2
|
||||
sum2 := &models.ParsedSummary{
|
||||
Request: "Optimize database queries",
|
||||
Investigated: "Query execution plans and index usage",
|
||||
Learned: "Composite indexes work better for range queries",
|
||||
Completed: "Added indexes, verified performance improvement",
|
||||
NextSteps: "Monitor query times in production",
|
||||
}
|
||||
id2, _, err := s.sumStore.StoreSummary(ctx, "sdk-sess-1", "test-project", sum2, 2, 150)
|
||||
s.Require().NoError(err)
|
||||
ids = append(ids, id2)
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// TestFilterSearch_WithRealStores tests filterSearch with seeded data.
|
||||
func (s *SearchIntegrationSuite) TestFilterSearch_WithRealStores() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
obsIDs := s.seedObservations(ctx)
|
||||
sumIDs := s.seedSummaries(ctx)
|
||||
s.Require().Len(obsIDs, 4)
|
||||
s.Require().Len(sumIDs, 2)
|
||||
|
||||
// Test filter search for observations only
|
||||
result, err := s.manager.filterSearch(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Type: "observations",
|
||||
Limit: 10,
|
||||
Format: "full",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Should return project observations + global observation (4 total: 3 project + 1 global)
|
||||
s.GreaterOrEqual(len(result.Results), 3)
|
||||
|
||||
// Verify result types
|
||||
for _, r := range result.Results {
|
||||
s.Equal("observation", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSearch_SessionsOnly tests filterSearch for sessions.
|
||||
func (s *SearchIntegrationSuite) TestFilterSearch_SessionsOnly() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
_ = s.seedObservations(ctx)
|
||||
sumIDs := s.seedSummaries(ctx)
|
||||
s.Require().Len(sumIDs, 2)
|
||||
|
||||
// Test filter search for sessions only
|
||||
result, err := s.manager.filterSearch(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Type: "sessions",
|
||||
Limit: 10,
|
||||
Format: "full",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Should return 2 summaries
|
||||
s.Len(result.Results, 2)
|
||||
|
||||
// Verify result types
|
||||
for _, r := range result.Results {
|
||||
s.Equal("session", r.Type)
|
||||
s.NotEmpty(r.Title) // Title should be populated from Request
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSearch_AllTypes tests filterSearch for all types.
|
||||
func (s *SearchIntegrationSuite) TestFilterSearch_AllTypes() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
obsIDs := s.seedObservations(ctx)
|
||||
sumIDs := s.seedSummaries(ctx)
|
||||
s.Require().Len(obsIDs, 4)
|
||||
s.Require().Len(sumIDs, 2)
|
||||
|
||||
// Test filter search for all types (Type = "")
|
||||
result, err := s.manager.filterSearch(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Type: "", // All types
|
||||
Limit: 20,
|
||||
Format: "full",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Should return both observations and sessions
|
||||
hasObservations := false
|
||||
hasSessions := false
|
||||
for _, r := range result.Results {
|
||||
if r.Type == "observation" {
|
||||
hasObservations = true
|
||||
}
|
||||
if r.Type == "session" {
|
||||
hasSessions = true
|
||||
}
|
||||
}
|
||||
s.True(hasObservations, "Should have observation results")
|
||||
s.True(hasSessions, "Should have session results")
|
||||
}
|
||||
|
||||
// TestUnifiedSearch_DefaultLimit tests UnifiedSearch with default limit.
|
||||
func (s *SearchIntegrationSuite) TestUnifiedSearch_DefaultLimit() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
s.seedObservations(ctx)
|
||||
s.seedSummaries(ctx)
|
||||
|
||||
// Test with no limit specified (should default to 20)
|
||||
result, err := s.manager.UnifiedSearch(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
s.LessOrEqual(len(result.Results), 20)
|
||||
}
|
||||
|
||||
// TestUnifiedSearch_LimitCapping tests UnifiedSearch limit capping.
|
||||
func (s *SearchIntegrationSuite) TestUnifiedSearch_LimitCapping() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
s.seedObservations(ctx)
|
||||
s.seedSummaries(ctx)
|
||||
|
||||
// Test with limit > 100 (should be capped to 100)
|
||||
result, err := s.manager.UnifiedSearch(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Limit: 500,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
s.LessOrEqual(len(result.Results), 100)
|
||||
}
|
||||
|
||||
// TestDecisions_WithRealStores tests the Decisions method falls back to filterSearch.
|
||||
func (s *SearchIntegrationSuite) TestDecisions_WithRealStores() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
s.seedObservations(ctx)
|
||||
|
||||
// Test Decisions search (without vector client, falls back to filterSearch)
|
||||
result, err := s.manager.Decisions(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Query: "database",
|
||||
Limit: 10,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Without vector client, falls back to filterSearch which returns observations
|
||||
// All results should be observations (type is forced to "observations" in Decisions)
|
||||
for _, r := range result.Results {
|
||||
s.Equal("observation", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChanges_WithRealStores tests the Changes method falls back to filterSearch.
|
||||
func (s *SearchIntegrationSuite) TestChanges_WithRealStores() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
s.seedObservations(ctx)
|
||||
|
||||
// Test Changes search (without vector client, falls back to filterSearch)
|
||||
result, err := s.manager.Changes(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Query: "user service",
|
||||
Limit: 10,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Without vector client, falls back to filterSearch which returns observations
|
||||
// All results should be observations (type is forced to "observations" in Changes)
|
||||
for _, r := range result.Results {
|
||||
s.Equal("observation", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHowItWorks_WithRealStores tests the HowItWorks method falls back to filterSearch.
|
||||
func (s *SearchIntegrationSuite) TestHowItWorks_WithRealStores() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed test data
|
||||
s.seedObservations(ctx)
|
||||
|
||||
// Test HowItWorks search (without vector client, falls back to filterSearch)
|
||||
result, err := s.manager.HowItWorks(ctx, SearchParams{
|
||||
Project: "test-project",
|
||||
Query: "authentication",
|
||||
Limit: 10,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Without vector client, falls back to filterSearch which returns observations
|
||||
// All results should be observations (type is forced to "observations" in HowItWorks)
|
||||
for _, r := range result.Results {
|
||||
s.Equal("observation", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestObservationToResult tests observation to result conversion with full format.
|
||||
func (s *SearchIntegrationSuite) TestObservationToResult_FullFormat() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert single observation
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Scope: models.ScopeProject,
|
||||
Title: "Test Title",
|
||||
Narrative: "Detailed narrative content for testing",
|
||||
Concepts: []string{"testing", "content"},
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Retrieve and convert
|
||||
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(retrieved)
|
||||
|
||||
result := s.manager.observationToResult(retrieved, "full")
|
||||
|
||||
s.Equal("observation", result.Type)
|
||||
s.Equal(id, result.ID)
|
||||
s.Equal("Test Title", result.Title)
|
||||
s.Equal("Detailed narrative content for testing", result.Content)
|
||||
s.Equal("test-project", result.Project)
|
||||
s.Equal("project", result.Scope)
|
||||
s.NotNil(result.Metadata)
|
||||
s.Equal("discovery", result.Metadata["obs_type"])
|
||||
}
|
||||
|
||||
// TestObservationToResult_IndexFormat tests index format (no content).
|
||||
func (s *SearchIntegrationSuite) TestObservationToResult_IndexFormat() {
|
||||
ctx := context.Background()
|
||||
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeBugfix,
|
||||
Scope: models.ScopeGlobal,
|
||||
Title: "Bug Fix Title",
|
||||
Narrative: "This should not appear in index format",
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-test", "test-project", obs, 1, 50)
|
||||
s.Require().NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
|
||||
s.Require().NoError(err)
|
||||
|
||||
result := s.manager.observationToResult(retrieved, "index")
|
||||
|
||||
s.Equal("observation", result.Type)
|
||||
s.Equal("Bug Fix Title", result.Title)
|
||||
s.Empty(result.Content, "Index format should not include content")
|
||||
s.Equal("global", result.Scope)
|
||||
}
|
||||
|
||||
// TestSummaryToResult_FullFormat tests summary to result conversion.
|
||||
func (s *SearchIntegrationSuite) TestSummaryToResult_FullFormat() {
|
||||
ctx := context.Background()
|
||||
|
||||
sum := &models.ParsedSummary{
|
||||
Request: "Implement new feature",
|
||||
Learned: "Learned important lessons about testing",
|
||||
}
|
||||
id, _, err := s.sumStore.StoreSummary(ctx, "sdk-test", "test-project", sum, 1, 50)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Retrieve via GetRecentSummaries since there's no GetByID
|
||||
summaries, err := s.sumStore.GetRecentSummaries(ctx, "test-project", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(summaries)
|
||||
|
||||
var retrieved *models.SessionSummary
|
||||
for _, s := range summaries {
|
||||
if s.ID == id {
|
||||
retrieved = s
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().NotNil(retrieved)
|
||||
|
||||
result := s.manager.summaryToResult(retrieved, "full")
|
||||
|
||||
s.Equal("session", result.Type)
|
||||
s.Equal(id, result.ID)
|
||||
s.Contains(result.Title, "Implement new feature")
|
||||
s.Equal("Learned important lessons about testing", result.Content)
|
||||
s.Equal("test-project", result.Project)
|
||||
}
|
||||
|
||||
// TestPromptToResult_FullFormat tests prompt to result conversion.
|
||||
func (s *SearchIntegrationSuite) TestPromptToResult_FullFormat() {
|
||||
// First create a session
|
||||
ctx := context.Background()
|
||||
sessionStore := sqlite.NewSessionStore(s.store)
|
||||
_, err := sessionStore.CreateSDKSession(ctx, "sdk-prompt-test", "test-project", "initial prompt")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Save a user prompt
|
||||
promptID, err := s.prmStore.SaveUserPromptWithMatches(ctx, "sdk-prompt-test", 1, "Help me fix this authentication bug", 3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Retrieve prompts
|
||||
prompts, err := s.prmStore.GetPromptsByIDs(ctx, []int64{promptID}, "date_desc", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotEmpty(prompts)
|
||||
|
||||
result := s.manager.promptToResult(prompts[0], "full")
|
||||
|
||||
s.Equal("prompt", result.Type)
|
||||
s.Equal(promptID, result.ID)
|
||||
s.Contains(result.Title, "Help me fix")
|
||||
s.Equal("Help me fix this authentication bug", result.Content)
|
||||
}
|
||||
|
||||
// TestTruncate_TableDriven tests truncation with various inputs.
|
||||
func TestTruncate_TableDriven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{"short_string", "hello", 10, "hello"},
|
||||
{"exact_length", "hello", 5, "hello"},
|
||||
{"long_string", "hello world", 5, "hello..."},
|
||||
{"empty_string", "", 10, ""},
|
||||
{"whitespace_only", " ", 10, ""},
|
||||
{"with_leading_space", " hello ", 10, "hello"},
|
||||
{"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncate(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerWithNilStores tests that Manager handles nil stores gracefully.
|
||||
func TestManagerWithNilStores(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
assert.NotNil(t, m)
|
||||
assert.Nil(t, m.observationStore)
|
||||
assert.Nil(t, m.summaryStore)
|
||||
assert.Nil(t, m.promptStore)
|
||||
assert.Nil(t, m.vectorClient)
|
||||
}
|
||||
|
||||
// TestSearchResultMetadataFields tests all metadata fields with real data.
|
||||
func (s *SearchIntegrationSuite) TestSearchResultMetadataFields() {
|
||||
ctx := context.Background()
|
||||
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDecision,
|
||||
Scope: models.ScopeGlobal,
|
||||
Title: "Architecture Decision",
|
||||
Concepts: []string{"auth", "security"},
|
||||
FilesRead: []string{"handler.go", "auth.go"},
|
||||
}
|
||||
id, _, err := s.obsStore.StoreObservation(ctx, "sdk-meta-test", "test-project", obs, 1, 50)
|
||||
s.Require().NoError(err)
|
||||
|
||||
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
|
||||
s.Require().NoError(err)
|
||||
|
||||
result := s.manager.observationToResult(retrieved, "full")
|
||||
|
||||
// Check metadata fields
|
||||
s.NotNil(result.Metadata)
|
||||
s.Equal("decision", result.Metadata["obs_type"])
|
||||
s.Equal("global", result.Metadata["scope"])
|
||||
s.Equal("global", result.Scope)
|
||||
}
|
||||
|
||||
// TestObservationToResult_AllFormats tests different format options.
|
||||
func TestObservationToResult_AllFormats(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Project: "test",
|
||||
Type: models.ObsTypeBugfix,
|
||||
Scope: models.ScopeProject,
|
||||
Title: sql.NullString{String: "Bug Fix Title", Valid: true},
|
||||
Narrative: sql.NullString{String: "Detailed bug fix narrative", Valid: true},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
expectContent bool
|
||||
}{
|
||||
{"full_format", "full", true},
|
||||
{"index_format", "index", false},
|
||||
{"empty_format", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := m.observationToResult(obs, tt.format)
|
||||
assert.Equal(t, "observation", result.Type)
|
||||
assert.Equal(t, int64(1), result.ID)
|
||||
if tt.expectContent {
|
||||
assert.NotEmpty(t, result.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSummaryToResult_AllFormats tests different format options for summaries.
|
||||
func TestSummaryToResult_AllFormats(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 1,
|
||||
Project: "test",
|
||||
Request: sql.NullString{String: "Test request", Valid: true},
|
||||
Learned: sql.NullString{String: "Test learned", Valid: true},
|
||||
Completed: sql.NullString{String: "Test completed", Valid: true},
|
||||
NextSteps: sql.NullString{String: "Test next steps", Valid: true},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
expectContent bool
|
||||
}{
|
||||
{"full_format", "full", true},
|
||||
{"index_format", "index", false},
|
||||
{"empty_format", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := m.summaryToResult(summary, tt.format)
|
||||
assert.Equal(t, "session", result.Type)
|
||||
assert.Equal(t, int64(1), result.ID)
|
||||
if tt.expectContent {
|
||||
assert.NotEmpty(t, result.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package search
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -594,3 +595,482 @@ func TestSearchTypeMapping(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSearchWithObservations tests filter search when observations exist.
|
||||
func TestFilterSearchWithObservations(t *testing.T) {
|
||||
// Create mock observation
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Project: "test-project",
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Scope: models.ScopeProject,
|
||||
Title: sql.NullString{String: "Test Title", Valid: true},
|
||||
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
result := m.observationToResult(obs, "full")
|
||||
|
||||
assert.Equal(t, "observation", result.Type)
|
||||
assert.Equal(t, int64(1), result.ID)
|
||||
assert.Equal(t, "Test Title", result.Title)
|
||||
assert.Equal(t, "Test narrative content", result.Content)
|
||||
assert.Equal(t, "test-project", result.Project)
|
||||
assert.Equal(t, "project", result.Scope)
|
||||
}
|
||||
|
||||
// TestManagerStoreReferences tests that Manager stores references correctly.
|
||||
func TestManagerStoreReferences(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
assert.Nil(t, m.observationStore)
|
||||
assert.Nil(t, m.summaryStore)
|
||||
assert.Nil(t, m.promptStore)
|
||||
assert.Nil(t, m.vectorClient)
|
||||
}
|
||||
|
||||
// TestObservationToResultWithMetadata tests metadata inclusion in results.
|
||||
func TestObservationToResultWithMetadata(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
obsType models.ObservationType
|
||||
scope models.ObservationScope
|
||||
}{
|
||||
{"bugfix_project", models.ObsTypeBugfix, models.ScopeProject},
|
||||
{"feature_global", models.ObsTypeFeature, models.ScopeGlobal},
|
||||
{"discovery_project", models.ObsTypeDiscovery, models.ScopeProject},
|
||||
{"decision_global", models.ObsTypeDecision, models.ScopeGlobal},
|
||||
{"refactor_project", models.ObsTypeRefactor, models.ScopeProject},
|
||||
{"change_global", models.ObsTypeChange, models.ScopeGlobal},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Project: "test-project",
|
||||
Type: tt.obsType,
|
||||
Scope: tt.scope,
|
||||
Title: sql.NullString{String: "Title", Valid: true},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
result := m.observationToResult(obs, "full")
|
||||
|
||||
assert.Equal(t, string(tt.obsType), result.Metadata["obs_type"])
|
||||
assert.Equal(t, string(tt.scope), result.Metadata["scope"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSummaryToResultTruncation tests title truncation in summary results.
|
||||
func TestSummaryToResultTruncation(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request string
|
||||
expectedLen int
|
||||
shouldTrunc bool
|
||||
}{
|
||||
{"short_title", "Short request", 13, false},
|
||||
{"exact_100", string(make([]byte, 100)), 103, true}, // 100 + "..."
|
||||
{"over_100", string(make([]byte, 150)), 103, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
summary := &models.SessionSummary{
|
||||
ID: 1,
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: tt.request, Valid: true},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
result := m.summaryToResult(summary, "full")
|
||||
|
||||
if tt.shouldTrunc {
|
||||
assert.LessOrEqual(t, len(result.Title), tt.expectedLen)
|
||||
assert.True(t, len(result.Title) <= 103) // max 100 + "..."
|
||||
} else {
|
||||
assert.Equal(t, tt.request, result.Title)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPromptToResultFormats tests prompt to result conversion with different formats.
|
||||
func TestPromptToResultFormats(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
prompt := &models.UserPromptWithSession{
|
||||
UserPrompt: models.UserPrompt{
|
||||
ID: 123,
|
||||
PromptText: "What is the meaning of life?",
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
},
|
||||
Project: "my-project",
|
||||
}
|
||||
|
||||
// Full format - includes content
|
||||
fullResult := m.promptToResult(prompt, "full")
|
||||
assert.Equal(t, "What is the meaning of life?", fullResult.Content)
|
||||
|
||||
// Index format - no content
|
||||
indexResult := m.promptToResult(prompt, "index")
|
||||
assert.Equal(t, "", indexResult.Content)
|
||||
|
||||
// Both should have same title
|
||||
assert.Equal(t, fullResult.Title, indexResult.Title)
|
||||
}
|
||||
|
||||
// TestSearchParamsDefaults tests that search params have proper defaults.
|
||||
func TestSearchParamsDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialLimit int
|
||||
initialOrder string
|
||||
expectedLimit int
|
||||
expectedOrder string
|
||||
}{
|
||||
{"zero_limit", 0, "", 20, "date_desc"},
|
||||
{"negative_limit", -5, "", 20, "date_desc"},
|
||||
{"over_100_limit", 150, "", 100, "date_desc"},
|
||||
{"valid_limit_50", 50, "relevance", 50, "relevance"},
|
||||
{"custom_order", 30, "date_asc", 30, "date_asc"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "project",
|
||||
Limit: tt.initialLimit,
|
||||
OrderBy: tt.initialOrder,
|
||||
}
|
||||
|
||||
// Simulate the normalization that happens in UnifiedSearch
|
||||
if params.Limit <= 0 {
|
||||
params.Limit = 20
|
||||
}
|
||||
if params.Limit > 100 {
|
||||
params.Limit = 100
|
||||
}
|
||||
if params.OrderBy == "" {
|
||||
params.OrderBy = "date_desc"
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedLimit, params.Limit)
|
||||
assert.Equal(t, tt.expectedOrder, params.OrderBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTruncateEdgeCases tests edge cases for truncate function.
|
||||
func TestTruncateEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
// Unicode strings - uses byte length so ensure maxLen accommodates full string
|
||||
{"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"},
|
||||
{"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"},
|
||||
// ASCII truncation
|
||||
{"ascii_truncate", "Hello World", 5, "Hello..."},
|
||||
{"only_whitespace", " ", 10, ""},
|
||||
{"tabs_and_newlines", "\t\n \t", 10, ""},
|
||||
{"newlines_with_content", "\n\nhello\n\n", 10, "hello"},
|
||||
{"zero_max_len", "hello", 0, "..."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncate(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedSearchResultEmpty tests empty UnifiedSearchResult.
|
||||
func TestUnifiedSearchResultEmpty(t *testing.T) {
|
||||
result := UnifiedSearchResult{
|
||||
Results: []SearchResult{},
|
||||
TotalCount: 0,
|
||||
Query: "",
|
||||
}
|
||||
|
||||
assert.Empty(t, result.Results)
|
||||
assert.Equal(t, 0, result.TotalCount)
|
||||
assert.Equal(t, "", result.Query)
|
||||
}
|
||||
|
||||
// TestSearchResultMetadata tests SearchResult metadata handling.
|
||||
func TestSearchResultMetadata(t *testing.T) {
|
||||
result := SearchResult{
|
||||
Type: "observation",
|
||||
ID: 1,
|
||||
Metadata: map[string]interface{}{
|
||||
"obs_type": "discovery",
|
||||
"scope": "project",
|
||||
"count": 42,
|
||||
"enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "discovery", result.Metadata["obs_type"])
|
||||
assert.Equal(t, "project", result.Metadata["scope"])
|
||||
assert.Equal(t, 42, result.Metadata["count"])
|
||||
assert.Equal(t, true, result.Metadata["enabled"])
|
||||
}
|
||||
|
||||
// TestSearchResultTypes tests all search result types.
|
||||
func TestSearchResultTypes(t *testing.T) {
|
||||
types := []string{"observation", "session", "prompt"}
|
||||
|
||||
for _, typ := range types {
|
||||
t.Run(typ, func(t *testing.T) {
|
||||
result := SearchResult{
|
||||
Type: typ,
|
||||
ID: 1,
|
||||
Project: "test",
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
assert.Equal(t, typ, result.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchParamsAllFields tests SearchParams with all fields populated.
|
||||
func TestSearchParamsAllFields(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "authentication bug",
|
||||
Type: "observations",
|
||||
Project: "my-project",
|
||||
ObsType: "bugfix",
|
||||
Concepts: "security,auth",
|
||||
Files: "handler.go,auth.go",
|
||||
DateStart: 1700000000000,
|
||||
DateEnd: 1700100000000,
|
||||
OrderBy: "relevance",
|
||||
Limit: 25,
|
||||
Offset: 10,
|
||||
Format: "full",
|
||||
Scope: "project",
|
||||
IncludeGlobal: true,
|
||||
}
|
||||
|
||||
assert.Equal(t, "authentication bug", params.Query)
|
||||
assert.Equal(t, "observations", params.Type)
|
||||
assert.Equal(t, "my-project", params.Project)
|
||||
assert.Equal(t, "bugfix", params.ObsType)
|
||||
assert.Equal(t, "security,auth", params.Concepts)
|
||||
assert.Equal(t, "handler.go,auth.go", params.Files)
|
||||
assert.Equal(t, int64(1700000000000), params.DateStart)
|
||||
assert.Equal(t, int64(1700100000000), params.DateEnd)
|
||||
assert.Equal(t, "relevance", params.OrderBy)
|
||||
assert.Equal(t, 25, params.Limit)
|
||||
assert.Equal(t, 10, params.Offset)
|
||||
assert.Equal(t, "full", params.Format)
|
||||
assert.Equal(t, "project", params.Scope)
|
||||
assert.True(t, params.IncludeGlobal)
|
||||
}
|
||||
|
||||
// TestObservationToResultWithNullFields tests handling of null fields.
|
||||
func TestObservationToResultWithNullFields(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Project: "test-project",
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Scope: models.ScopeProject,
|
||||
Title: sql.NullString{Valid: false},
|
||||
Narrative: sql.NullString{Valid: false},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
result := m.observationToResult(obs, "full")
|
||||
|
||||
assert.Equal(t, "", result.Title)
|
||||
assert.Equal(t, "", result.Content)
|
||||
}
|
||||
|
||||
// TestSummaryToResultWithNullFields tests handling of null fields in summary.
|
||||
func TestSummaryToResultWithNullFields(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil, nil)
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 1,
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{Valid: false},
|
||||
Learned: sql.NullString{Valid: false},
|
||||
CreatedAtEpoch: 1704067200000,
|
||||
}
|
||||
|
||||
result := m.summaryToResult(summary, "full")
|
||||
|
||||
assert.Equal(t, "", result.Title)
|
||||
assert.Equal(t, "", result.Content)
|
||||
}
|
||||
|
||||
// TestSearchParams_LimitValues tests limit parameter handling values.
|
||||
func TestSearchParams_LimitValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputLimit int
|
||||
expectedValid bool
|
||||
}{
|
||||
{"zero_limit", 0, true},
|
||||
{"negative_limit", -5, true},
|
||||
{"normal_limit", 20, true},
|
||||
{"max_limit", 100, true},
|
||||
{"over_limit", 200, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "test",
|
||||
Limit: tt.inputLimit,
|
||||
}
|
||||
assert.NotNil(t, params)
|
||||
assert.Equal(t, tt.inputLimit, params.Limit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchParams_OrderByValues tests order by parameter values.
|
||||
func TestSearchParams_OrderByValues(t *testing.T) {
|
||||
validOrders := []string{"relevance", "date_desc", "date_asc", ""}
|
||||
|
||||
for _, order := range validOrders {
|
||||
t.Run("order_"+order, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "test",
|
||||
OrderBy: order,
|
||||
}
|
||||
assert.Equal(t, order, params.OrderBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchParams_TypeValues tests type parameter values.
|
||||
func TestSearchParams_TypeValues(t *testing.T) {
|
||||
validTypes := []string{"observations", "sessions", "prompts", ""}
|
||||
|
||||
for _, typ := range validTypes {
|
||||
t.Run("type_"+typ, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "test",
|
||||
Type: typ,
|
||||
}
|
||||
assert.Equal(t, typ, params.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchParams_ScopeValues tests scope parameter values.
|
||||
func TestSearchParams_ScopeValues(t *testing.T) {
|
||||
validScopes := []string{"project", "global", ""}
|
||||
|
||||
for _, scope := range validScopes {
|
||||
t.Run("scope_"+scope, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "test",
|
||||
Scope: scope,
|
||||
}
|
||||
assert.Equal(t, scope, params.Scope)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchParams_FormatValues tests format parameter values.
|
||||
func TestSearchParams_FormatValues(t *testing.T) {
|
||||
validFormats := []string{"index", "full", ""}
|
||||
|
||||
for _, format := range validFormats {
|
||||
t.Run("format_"+format, func(t *testing.T) {
|
||||
params := SearchParams{
|
||||
Query: "test",
|
||||
Project: "test",
|
||||
Format: format,
|
||||
}
|
||||
assert.Equal(t, format, params.Format)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedSearchResult_MultipleResults tests result with multiple items.
|
||||
func TestUnifiedSearchResult_MultipleResults(t *testing.T) {
|
||||
results := []SearchResult{
|
||||
{Type: "observation", ID: 1, Title: "First", Project: "test"},
|
||||
{Type: "session", ID: 2, Title: "Second", Project: "test"},
|
||||
{Type: "prompt", ID: 3, Title: "Third", Project: "test"},
|
||||
}
|
||||
|
||||
result := UnifiedSearchResult{
|
||||
Results: results,
|
||||
TotalCount: 3,
|
||||
Query: "test query",
|
||||
}
|
||||
|
||||
assert.Len(t, result.Results, 3)
|
||||
assert.Equal(t, 3, result.TotalCount)
|
||||
assert.Equal(t, "observation", result.Results[0].Type)
|
||||
assert.Equal(t, "session", result.Results[1].Type)
|
||||
assert.Equal(t, "prompt", result.Results[2].Type)
|
||||
}
|
||||
|
||||
// TestSearchResult_Metadata tests metadata handling in SearchResult.
|
||||
func TestSearchResult_Metadata(t *testing.T) {
|
||||
metadata := map[string]interface{}{
|
||||
"obs_type": "discovery",
|
||||
"concepts": []string{"auth", "security"},
|
||||
"files_count": 5,
|
||||
"is_important": true,
|
||||
}
|
||||
|
||||
result := SearchResult{
|
||||
Type: "observation",
|
||||
ID: 1,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
assert.Equal(t, "discovery", result.Metadata["obs_type"])
|
||||
assert.Equal(t, 5, result.Metadata["files_count"])
|
||||
assert.Equal(t, true, result.Metadata["is_important"])
|
||||
}
|
||||
|
||||
// TestSearchResult_Scores tests score handling in SearchResult.
|
||||
func TestSearchResult_Scores(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
score float64
|
||||
}{
|
||||
{"perfect_score", 1.0},
|
||||
{"high_score", 0.95},
|
||||
{"medium_score", 0.5},
|
||||
{"low_score", 0.1},
|
||||
{"zero_score", 0.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SearchResult{
|
||||
Type: "observation",
|
||||
ID: 1,
|
||||
Score: tt.score,
|
||||
}
|
||||
assert.Equal(t, tt.score, result.Score)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,502 @@
|
||||
package sqlitevec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testDB creates a test SQLite database with the vectors table.
|
||||
func testDB(t *testing.T) (*sql.DB, func()) {
|
||||
t.Helper()
|
||||
|
||||
// Create temp directory
|
||||
tmpDir, err := os.MkdirTemp("", "sqlitevec-test-*")
|
||||
require.NoError(t, err)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Enable sqlite-vec
|
||||
sqlite_vec.Auto()
|
||||
|
||||
// Create vectors table (matches production schema)
|
||||
_, err = db.Exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
// testEmbeddingService creates a test embedding service.
|
||||
func testEmbeddingService(t *testing.T) (*embedding.Service, func()) {
|
||||
t.Helper()
|
||||
|
||||
svc, err := embedding.NewService()
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
svc.Close()
|
||||
}
|
||||
|
||||
return svc, cleanup
|
||||
}
|
||||
|
||||
func TestNewClient_Success(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
}
|
||||
|
||||
func TestNewClient_NilDB(t *testing.T) {
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: nil}, embedSvc)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, client)
|
||||
assert.Contains(t, err.Error(), "database connection required")
|
||||
}
|
||||
|
||||
func TestNewClient_NilEmbedding(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, client)
|
||||
assert.Contains(t, err.Error(), "embedding service required")
|
||||
}
|
||||
|
||||
func TestClient_AddDocuments_Empty(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.AddDocuments(context.Background(), []Document{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_AddDocuments_Single(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "obs-1-title",
|
||||
Content: "This is a test observation about authentication.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"field_type": "title",
|
||||
"project": "test-project",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify document was inserted
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "obs-1-title").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestClient_AddDocuments_Multiple(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "obs-1-title",
|
||||
Content: "Authentication flow implementation.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"field_type": "title",
|
||||
"project": "test-project",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-1-narrative",
|
||||
Content: "We implemented JWT-based authentication.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"field_type": "narrative",
|
||||
"project": "test-project",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2-title",
|
||||
Content: "Database optimization.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(2),
|
||||
"doc_type": "observation",
|
||||
"field_type": "title",
|
||||
"project": "test-project",
|
||||
"scope": "global",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all documents were inserted
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
}
|
||||
|
||||
func TestClient_DeleteDocuments_Empty(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteDocuments(context.Background(), []string{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_DeleteDocuments_Existing(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add documents first
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "doc-1",
|
||||
Content: "First document.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "doc-2",
|
||||
Content: "Second document.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(2),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one document
|
||||
err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one remains
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestClient_Query_Basic(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add some test documents
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Content: "Authentication and login security implementation.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"project": "test-project",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Content: "Database query optimization techniques.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(2),
|
||||
"doc_type": "observation",
|
||||
"project": "test-project",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query for authentication-related content
|
||||
results, err := client.Query(context.Background(), "login authentication", 10, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, results)
|
||||
assert.LessOrEqual(t, len(results), 10)
|
||||
|
||||
// First result should be the authentication document (higher similarity)
|
||||
assert.Equal(t, "obs-1", results[0].ID)
|
||||
}
|
||||
|
||||
func TestClient_Query_WithDocTypeFilter(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add documents of different types
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Content: "Test content for observation.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"project": "test-project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "summary-1",
|
||||
Content: "Test content for summary.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(10),
|
||||
"doc_type": "session_summary",
|
||||
"project": "test-project",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query with doc_type filter
|
||||
where := map[string]any{"doc_type": "observation"}
|
||||
results, err := client.Query(context.Background(), "test content", 10, where)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should only return observation documents
|
||||
for _, r := range results {
|
||||
docType, _ := r.Metadata["doc_type"].(string)
|
||||
assert.Equal(t, "observation", docType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Query_WithProjectFilter(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add documents from different projects
|
||||
docs := []Document{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Content: "Project A authentication content.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
"project": "project-a",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Content: "Project B database content.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(2),
|
||||
"doc_type": "observation",
|
||||
"project": "project-b",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-3",
|
||||
Content: "Global security best practices.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(3),
|
||||
"doc_type": "observation",
|
||||
"project": "project-b",
|
||||
"scope": "global",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query without project filter to verify all docs are there
|
||||
results, err := client.Query(context.Background(), "authentication security", 10, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, results, "Should find some results")
|
||||
}
|
||||
|
||||
func TestClient_IsConnected(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, client.IsConnected())
|
||||
}
|
||||
|
||||
func TestClient_Close(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConfig_Fields(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
cfg := Config{DB: db}
|
||||
assert.Equal(t, db, cfg.DB)
|
||||
}
|
||||
|
||||
func TestClient_UpdateDocument_DeleteThenAdd(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add document
|
||||
docs1 := []Document{
|
||||
{
|
||||
ID: "doc-1",
|
||||
Content: "Original content.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete then add with new content (proper update pattern)
|
||||
err = client.DeleteDocuments(context.Background(), []string{"doc-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
docs2 := []Document{
|
||||
{
|
||||
ID: "doc-1",
|
||||
Content: "Updated content.",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.AddDocuments(context.Background(), docs2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have exactly 1 document
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id = ?", "doc-1").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestClient_DeleteDocuments_NonExistent(t *testing.T) {
|
||||
db, dbCleanup := testDB(t)
|
||||
defer dbCleanup()
|
||||
|
||||
embedSvc, embedCleanup := testEmbeddingService(t)
|
||||
defer embedCleanup()
|
||||
|
||||
client, err := NewClient(Config{DB: db}, embedSvc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleting non-existent document should not error
|
||||
err = client.DeleteDocuments(context.Background(), []string{"non-existent-id"})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
package sqlitevec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDocTypes(t *testing.T) {
|
||||
assert.Equal(t, DocType("observation"), DocTypeObservation)
|
||||
assert.Equal(t, DocType("session_summary"), DocTypeSessionSummary)
|
||||
assert.Equal(t, DocType("user_prompt"), DocTypeUserPrompt)
|
||||
}
|
||||
|
||||
func TestDocument_Fields(t *testing.T) {
|
||||
doc := Document{
|
||||
ID: "doc-123",
|
||||
Content: "test content",
|
||||
Metadata: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "doc-123", doc.ID)
|
||||
assert.Equal(t, "test content", doc.Content)
|
||||
assert.Equal(t, "value", doc.Metadata["key"])
|
||||
}
|
||||
|
||||
func TestQueryResult_Fields(t *testing.T) {
|
||||
result := QueryResult{
|
||||
ID: "result-123",
|
||||
Distance: 0.5,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(42),
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "result-123", result.ID)
|
||||
assert.Equal(t, 0.5, result.Distance)
|
||||
assert.Equal(t, float64(42), result.Metadata["sqlite_id"])
|
||||
}
|
||||
|
||||
func TestBuildWhereFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docType DocType
|
||||
project string
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty_filters",
|
||||
docType: "",
|
||||
project: "",
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "doc_type_only",
|
||||
docType: DocTypeObservation,
|
||||
project: "",
|
||||
expected: map[string]interface{}{
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project_only",
|
||||
docType: "",
|
||||
project: "my-project",
|
||||
expected: map[string]interface{}{
|
||||
"project": "my-project",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both_filters",
|
||||
docType: DocTypeSessionSummary,
|
||||
project: "test-project",
|
||||
expected: map[string]interface{}{
|
||||
"doc_type": "session_summary",
|
||||
"project": "test-project",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user_prompt_type",
|
||||
docType: DocTypeUserPrompt,
|
||||
project: "prompt-project",
|
||||
expected: map[string]interface{}{
|
||||
"doc_type": "user_prompt",
|
||||
"project": "prompt-project",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := BuildWhereFilter(tt.docType, tt.project)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_Empty(t *testing.T) {
|
||||
results := []QueryResult{}
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
assert.Empty(t, ids.ObservationIDs)
|
||||
assert.Empty(t, ids.SummaryIDs)
|
||||
assert.Empty(t, ids.PromptIDs)
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_AllTypes(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Distance: 0.1,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Distance: 0.2,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(2),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "summary-1",
|
||||
Distance: 0.3,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(10),
|
||||
"doc_type": "session_summary",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "prompt-1",
|
||||
Distance: 0.4,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(20),
|
||||
"doc_type": "user_prompt",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
assert.Equal(t, []int64{1, 2}, ids.ObservationIDs)
|
||||
assert.Equal(t, []int64{10}, ids.SummaryIDs)
|
||||
assert.Equal(t, []int64{20}, ids.PromptIDs)
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_Deduplication(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1-field1",
|
||||
Distance: 0.1,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-1-field2",
|
||||
Distance: 0.2,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1), // Same ID, different field
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Distance: 0.3,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(2),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
assert.Equal(t, []int64{1, 2}, ids.ObservationIDs) // Should be deduplicated
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_Int64Fallback(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Distance: 0.1,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": int64(42), // int64 instead of float64
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
assert.Equal(t, []int64{42}, ids.ObservationIDs)
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_MissingSqliteID(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Distance: 0.1,
|
||||
Metadata: map[string]any{
|
||||
"doc_type": "observation",
|
||||
// Missing sqlite_id
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
assert.Empty(t, ids.ObservationIDs)
|
||||
}
|
||||
|
||||
func TestExtractIDsByDocType_UnknownType(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "unknown-1",
|
||||
Distance: 0.1,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "unknown_type",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractIDsByDocType(results)
|
||||
|
||||
// Should not be added to any category
|
||||
assert.Empty(t, ids.ObservationIDs)
|
||||
assert.Empty(t, ids.SummaryIDs)
|
||||
assert.Empty(t, ids.PromptIDs)
|
||||
}
|
||||
|
||||
func TestExtractObservationIDs_NoFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(2),
|
||||
"doc_type": "observation",
|
||||
"project": "proj-b",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "summary-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(10),
|
||||
"doc_type": "session_summary",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractObservationIDs(results, "")
|
||||
|
||||
assert.Equal(t, []int64{1, 2}, ids)
|
||||
}
|
||||
|
||||
func TestExtractObservationIDs_WithProjectFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
"project": "proj-a",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(2),
|
||||
"doc_type": "observation",
|
||||
"project": "proj-b",
|
||||
"scope": "project",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-global",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(3),
|
||||
"doc_type": "observation",
|
||||
"project": "proj-b",
|
||||
"scope": "global",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractObservationIDs(results, "proj-a")
|
||||
|
||||
// Should include proj-a and global scope observations
|
||||
assert.Equal(t, []int64{1, 3}, ids)
|
||||
}
|
||||
|
||||
func TestExtractObservationIDs_Deduplication(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "obs-1-field1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-1-field2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1), // Same ID
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractObservationIDs(results, "")
|
||||
|
||||
assert.Equal(t, []int64{1}, ids)
|
||||
}
|
||||
|
||||
func TestExtractSummaryIDs_NoFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "summary-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(10),
|
||||
"doc_type": "session_summary",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "summary-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(20),
|
||||
"doc_type": "session_summary",
|
||||
"project": "proj-b",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "obs-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(1),
|
||||
"doc_type": "observation",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractSummaryIDs(results, "")
|
||||
|
||||
assert.Equal(t, []int64{10, 20}, ids)
|
||||
}
|
||||
|
||||
func TestExtractSummaryIDs_WithProjectFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "summary-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(10),
|
||||
"doc_type": "session_summary",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "summary-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(20),
|
||||
"doc_type": "session_summary",
|
||||
"project": "proj-b",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractSummaryIDs(results, "proj-a")
|
||||
|
||||
assert.Equal(t, []int64{10}, ids)
|
||||
}
|
||||
|
||||
func TestExtractPromptIDs_NoFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "prompt-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(100),
|
||||
"doc_type": "user_prompt",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "prompt-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(200),
|
||||
"doc_type": "user_prompt",
|
||||
"project": "proj-b",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractPromptIDs(results, "")
|
||||
|
||||
assert.Equal(t, []int64{100, 200}, ids)
|
||||
}
|
||||
|
||||
func TestExtractPromptIDs_WithProjectFilter(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
ID: "prompt-1",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(100),
|
||||
"doc_type": "user_prompt",
|
||||
"project": "proj-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "prompt-2",
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": float64(200),
|
||||
"doc_type": "user_prompt",
|
||||
"project": "proj-b",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := ExtractPromptIDs(results, "proj-b")
|
||||
|
||||
assert.Equal(t, []int64{200}, ids)
|
||||
}
|
||||
|
||||
func TestCopyMetadata(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
result := copyMetadata(base, "key3", "value3")
|
||||
|
||||
// Original should be unchanged
|
||||
assert.Len(t, base, 2)
|
||||
|
||||
// Result should have new key
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, 42, result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
func TestCopyMetadataMulti(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
}
|
||||
extra := map[string]any{
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
}
|
||||
|
||||
result := copyMetadataMulti(base, extra)
|
||||
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, "value2", result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
func TestJoinStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
strs []string
|
||||
sep string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty_slice",
|
||||
strs: []string{},
|
||||
sep: ", ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single_element",
|
||||
strs: []string{"one"},
|
||||
sep: ", ",
|
||||
expected: "one",
|
||||
},
|
||||
{
|
||||
name: "multiple_elements",
|
||||
strs: []string{"one", "two", "three"},
|
||||
sep: ", ",
|
||||
expected: "one, two, three",
|
||||
},
|
||||
{
|
||||
name: "different_separator",
|
||||
strs: []string{"a", "b", "c"},
|
||||
sep: "-",
|
||||
expected: "a-b-c",
|
||||
},
|
||||
{
|
||||
name: "empty_separator",
|
||||
strs: []string{"a", "b", "c"},
|
||||
sep: "",
|
||||
expected: "abc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := joinStrings(tt.strs, tt.sep)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
input: "hello",
|
||||
maxLen: 10,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "equal_to_max",
|
||||
input: "hello",
|
||||
maxLen: 5,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "longer_than_max",
|
||||
input: "hello world",
|
||||
maxLen: 5,
|
||||
expected: "hello...",
|
||||
},
|
||||
{
|
||||
name: "empty_string",
|
||||
input: "",
|
||||
maxLen: 5,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "zero_max_length",
|
||||
input: "hello",
|
||||
maxLen: 0,
|
||||
expected: "...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncateString(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractedIDs_Empty(t *testing.T) {
|
||||
ids := &ExtractedIDs{}
|
||||
|
||||
assert.Nil(t, ids.ObservationIDs)
|
||||
assert.Nil(t, ids.SummaryIDs)
|
||||
assert.Nil(t, ids.PromptIDs)
|
||||
}
|
||||
@@ -1279,3 +1279,849 @@ func TestObservationRequest_Fields(t *testing.T) {
|
||||
assert.Equal(t, "Read", req.ToolName)
|
||||
assert.Equal(t, "/home/user/project", req.CWD)
|
||||
}
|
||||
|
||||
// TestRetrievalStats_Fields tests RetrievalStats struct fields.
|
||||
func TestRetrievalStats_Fields(t *testing.T) {
|
||||
stats := RetrievalStats{
|
||||
TotalRequests: 100,
|
||||
ObservationsServed: 500,
|
||||
VerifiedStale: 10,
|
||||
DeletedInvalid: 5,
|
||||
SearchRequests: 80,
|
||||
ContextInjections: 20,
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(100), stats.TotalRequests)
|
||||
assert.Equal(t, int64(500), stats.ObservationsServed)
|
||||
assert.Equal(t, int64(10), stats.VerifiedStale)
|
||||
assert.Equal(t, int64(5), stats.DeletedInvalid)
|
||||
assert.Equal(t, int64(80), stats.SearchRequests)
|
||||
assert.Equal(t, int64(20), stats.ContextInjections)
|
||||
}
|
||||
|
||||
// TestServiceConstants tests service configuration constants.
|
||||
func TestServiceConstants(t *testing.T) {
|
||||
assert.Equal(t, 30*time.Second, DefaultHTTPTimeout)
|
||||
assert.Equal(t, 50*time.Millisecond, ReadyPollInterval)
|
||||
assert.Equal(t, 100, StaleQueueSize)
|
||||
assert.Equal(t, 2*time.Second, QueueProcessInterval)
|
||||
}
|
||||
|
||||
// TestClusterObservations_Empty tests clustering with empty slice.
|
||||
func TestClusterObservations_Empty(t *testing.T) {
|
||||
observations := []*models.Observation{}
|
||||
clustered := clusterObservations(observations, 0.4)
|
||||
assert.Empty(t, clustered)
|
||||
}
|
||||
|
||||
// TestClusterObservations_Single tests clustering with single observation.
|
||||
func TestClusterObservations_Single(t *testing.T) {
|
||||
observations := []*models.Observation{
|
||||
{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Test observation", Valid: true},
|
||||
Narrative: sql.NullString{String: "Test content", Valid: true},
|
||||
},
|
||||
}
|
||||
clustered := clusterObservations(observations, 0.4)
|
||||
assert.Len(t, clustered, 1)
|
||||
}
|
||||
|
||||
// TestClusterObservations_VeryDifferent tests clustering with very different observations.
|
||||
func TestClusterObservations_VeryDifferent(t *testing.T) {
|
||||
observations := []*models.Observation{
|
||||
{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Database optimization", Valid: true},
|
||||
Narrative: sql.NullString{String: "PostgreSQL index tuning", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Title: sql.NullString{String: "Authentication flow", Valid: true},
|
||||
Narrative: sql.NullString{String: "JWT token validation", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Title: sql.NullString{String: "Logging setup", Valid: true},
|
||||
Narrative: sql.NullString{String: "Zerolog configuration", Valid: true},
|
||||
},
|
||||
}
|
||||
clustered := clusterObservations(observations, 0.4)
|
||||
// Very different observations should not be clustered together
|
||||
assert.GreaterOrEqual(t, len(clustered), 1)
|
||||
}
|
||||
|
||||
// TestHandleContextInject_WithLimit tests context inject with custom limit.
|
||||
func TestHandleContextInject_WithLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "inject-limit-test"
|
||||
|
||||
// Create some observations
|
||||
for i := 0; i < 10; i++ {
|
||||
createTestObservation(t, svc.observationStore, project,
|
||||
"Observation "+strconv.Itoa(i),
|
||||
"Content "+strconv.Itoa(i),
|
||||
[]string{"test-" + strconv.Itoa(i)})
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&limit=5", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.LessOrEqual(t, len(observations), 5)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations tests getting observations list.
|
||||
func TestHandleGetObservations(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create some observations
|
||||
createTestObservation(t, svc.observationStore, "test-project",
|
||||
"Test Observation 1",
|
||||
"Test content 1",
|
||||
[]string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "test-project",
|
||||
"Test Observation 2",
|
||||
"Test content 2",
|
||||
[]string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=test-project", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, len(observations), 2)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_Pagination tests observations pagination.
|
||||
func TestHandleGetObservations_Pagination(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create some observations
|
||||
for i := 0; i < 5; i++ {
|
||||
createTestObservation(t, svc.observationStore, "page-test",
|
||||
"Observation "+strconv.Itoa(i),
|
||||
"Content "+strconv.Itoa(i),
|
||||
[]string{"test"})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?project=page-test&limit=2", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetObservations_NoProject tests observations without project.
|
||||
func TestHandleGetObservations_NoProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should still return 200 with empty results or all observations
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSearchByPrompt_EmptyQuery tests search with empty query parameter.
|
||||
func TestHandleSearchByPrompt_EmptyQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=test&query=", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Empty query should still be a bad request
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSessionByClaudeID tests getting session by Claude ID.
|
||||
func TestHandleGetSessionByClaudeID(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-test-123", "project-a", "prompt 1")
|
||||
|
||||
// Test with valid claudeSessionId
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=claude-test-123", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSessionByClaudeID_Missing tests session lookup with missing param.
|
||||
func TestHandleGetSessionByClaudeID_Missing(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetSessionByClaudeID_NotFound tests session not found.
|
||||
func TestHandleGetSessionByClaudeID_NotFound(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions?claudeSessionId=nonexistent", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
// TestGetRetrievalStats tests the retrieval stats getter.
|
||||
func TestGetRetrievalStats(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Initially all zeros
|
||||
stats := svc.GetRetrievalStats()
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
assert.Equal(t, int64(0), stats.SearchRequests)
|
||||
|
||||
// Make some requests to increment stats
|
||||
project := "stats-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Stats should be updated
|
||||
stats = svc.GetRetrievalStats()
|
||||
assert.GreaterOrEqual(t, stats.TotalRequests, int64(1))
|
||||
}
|
||||
|
||||
// TestSelfCheckResponse_Fields tests SelfCheckResponse struct fields.
|
||||
func TestSelfCheckResponse_Fields(t *testing.T) {
|
||||
resp := SelfCheckResponse{
|
||||
Overall: "healthy",
|
||||
Version: "v1.0.0",
|
||||
Uptime: "2h30m",
|
||||
Components: []ComponentHealth{
|
||||
{Name: "database", Status: "healthy", Message: "Connected"},
|
||||
{Name: "vector", Status: "healthy", Message: "Ready"},
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "healthy", resp.Overall)
|
||||
assert.Equal(t, "v1.0.0", resp.Version)
|
||||
assert.Equal(t, "2h30m", resp.Uptime)
|
||||
assert.Len(t, resp.Components, 2)
|
||||
assert.Equal(t, "database", resp.Components[0].Name)
|
||||
assert.Equal(t, "healthy", resp.Components[0].Status)
|
||||
assert.Equal(t, "Connected", resp.Components[0].Message)
|
||||
}
|
||||
|
||||
// TestComponentHealth_Fields tests ComponentHealth struct fields.
|
||||
func TestComponentHealth_Fields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
message string
|
||||
}{
|
||||
{"healthy", "healthy", "All systems operational"},
|
||||
{"degraded", "degraded", "Some features unavailable"},
|
||||
{"unhealthy", "unhealthy", "Service is down"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
health := ComponentHealth{
|
||||
Status: tt.status,
|
||||
Message: tt.message,
|
||||
}
|
||||
assert.Equal(t, tt.status, health.Status)
|
||||
assert.Equal(t, tt.message, health.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteJSON_Error tests writeJSON with values that can't be encoded.
|
||||
func TestWriteJSON_Error(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// channels can't be JSON encoded
|
||||
ch := make(chan int)
|
||||
writeJSON(rec, ch)
|
||||
|
||||
// Should still set content type but encoding will fail
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
// TestHandleSummarize_InvalidSessionID tests summarize with invalid session ID.
|
||||
func TestHandleSummarize_InvalidSessionID(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/invalid/summarize", bytes.NewReader([]byte("{}")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Invalid session ID should return 400
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSubagentComplete tests subagent completion endpoint.
|
||||
func TestHandleSubagentComplete(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "subagent-test-123", "test-project", "test prompt")
|
||||
|
||||
payload := `{"session_id": ` + strconv.FormatInt(sessionID, 10) + `, "parent_session_id": "parent-123"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte(payload)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should accept the request
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleContextSearch_Ordering tests search with different orderings.
|
||||
func TestHandleContextSearch_Ordering(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "order-test"
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
createTestObservation(t, svc.observationStore, project,
|
||||
"Obs "+strconv.Itoa(i),
|
||||
"Content "+strconv.Itoa(i),
|
||||
[]string{"test"})
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
order string
|
||||
}{
|
||||
{"date_desc", "date_desc"},
|
||||
{"date_asc", "date_asc"},
|
||||
{"default", ""}, // Should default to date_desc
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := "/api/context/search?project=" + project + "&query=test"
|
||||
if tt.order != "" {
|
||||
url += "&order_by=" + tt.order
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleContextCount_NoProject tests context count without project.
|
||||
func TestHandleContextCount_NoProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/count", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleRetrievalStatsEndpoint tests retrieval stats endpoint.
|
||||
func TestHandleRetrievalStatsEndpoint(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/stats/retrieval", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var stats RetrievalStats
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &stats)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestHandleReadyEndpoint tests ready endpoint response.
|
||||
func TestHandleReadyEndpoint(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Ready may return OK or ServiceUnavailable depending on state
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusServiceUnavailable}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSessionInit_EmptyBody tests session init with empty body.
|
||||
func TestHandleSessionInit_EmptyBody(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
payload := `{}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader([]byte(payload)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should accept empty body and create a session
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleObservation_MissingSession tests observation without session.
|
||||
func TestHandleObservation_MissingSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
payload := `{"session_id": 99999, "tool_name": "Read", "tool_input": "{}", "tool_output": "test"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader([]byte(payload)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should still accept the observation
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusNotFound, http.StatusBadRequest}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSummaries_Pagination tests summaries pagination.
|
||||
func TestHandleSummaries_Pagination(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries?project=test&limit=10&offset=0", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandlePrompts_Pagination tests prompts pagination.
|
||||
func TestHandlePrompts_Pagination(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts?project=test&limit=10&offset=0", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleGetStats_AllProjects tests stats without project filter.
|
||||
func TestHandleGetStats_AllProjects(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create some data in multiple projects
|
||||
createTestObservation(t, svc.observationStore, "proj-a", "Test A", "Content", []string{"test"})
|
||||
createTestObservation(t, svc.observationStore, "proj-b", "Test B", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSubagentComplete_WithSession tests subagent completion with existing session.
|
||||
func TestHandleSubagentComplete_WithSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "subagent-claude-123", "test-project", "test prompt")
|
||||
|
||||
reqBody := SubagentCompleteRequest{
|
||||
ClaudeSessionID: "subagent-claude-123",
|
||||
Project: "test-project",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSubagentComplete_NoSession tests subagent completion when session doesn't exist.
|
||||
func TestHandleSubagentComplete_NoSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SubagentCompleteRequest{
|
||||
ClaudeSessionID: "nonexistent-session",
|
||||
Project: "test-project",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should still return 200 even if session not found
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSubagentComplete_InvalidJSON tests subagent completion with invalid JSON.
|
||||
func TestHandleSubagentComplete_InvalidJSON(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/subagent-complete", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSummarize_ValidSession tests summarize with valid session.
|
||||
func TestHandleSummarize_ValidSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-claude-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := SummarizeRequest{
|
||||
LastUserMessage: "Can you help me fix this bug?",
|
||||
LastAssistantMessage: "I've analyzed the code and fixed the issue in the handler.",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSummarize_InvalidJSON tests summarize with invalid JSON.
|
||||
func TestHandleSummarize_InvalidJSON(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
sessionID, _ := svc.sessionStore.CreateSDKSession(ctx, "summarize-invalid", "test-project", "test")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/"+strconv.FormatInt(sessionID, 10)+"/summarize", bytes.NewReader([]byte("not valid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSummarize_NonExistentSession tests summarize with non-existent session.
|
||||
func TestHandleSummarize_NonExistentSession(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SummarizeRequest{
|
||||
LastUserMessage: "test",
|
||||
LastAssistantMessage: "test",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/sessions/999999/summarize", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should return error for non-existent session
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusInternalServerError, http.StatusNotFound}, rec.Code)
|
||||
}
|
||||
|
||||
// TestSubagentCompleteRequest_Fields tests SubagentCompleteRequest struct.
|
||||
func TestSubagentCompleteRequest_Fields(t *testing.T) {
|
||||
req := SubagentCompleteRequest{
|
||||
ClaudeSessionID: "test-session-123",
|
||||
Project: "my-project",
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-session-123", req.ClaudeSessionID)
|
||||
assert.Equal(t, "my-project", req.Project)
|
||||
}
|
||||
|
||||
// TestSummarizeRequest_Fields tests SummarizeRequest struct.
|
||||
func TestSummarizeRequest_Fields(t *testing.T) {
|
||||
req := SummarizeRequest{
|
||||
LastUserMessage: "Help me fix this bug",
|
||||
LastAssistantMessage: "I've fixed the authentication issue",
|
||||
}
|
||||
|
||||
assert.Equal(t, "Help me fix this bug", req.LastUserMessage)
|
||||
assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
|
||||
}
|
||||
|
||||
// TestHandleHealth_NotReady tests health endpoint when not ready.
|
||||
func TestHandleHealth_NotReady(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleHealth(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "starting", response["status"])
|
||||
}
|
||||
|
||||
// TestHandleContextInject_EmptyProject tests context inject with empty project.
|
||||
func TestHandleContextInject_EmptyProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project=", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSearchByPrompt_LargeLimit tests search with limit exceeding max.
|
||||
func TestHandleSearchByPrompt_LargeLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "limit-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Test", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test&limit=999", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleObservation_WithFullData tests observation with all fields.
|
||||
func TestHandleObservation_WithFullData(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session first
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "obs-full-test", "test-project", "test prompt")
|
||||
|
||||
reqBody := ObservationRequest{
|
||||
ClaudeSessionID: "obs-full-test",
|
||||
Project: "test-project",
|
||||
ToolName: "Edit",
|
||||
ToolInput: map[string]interface{}{"file_path": "/test.go", "old_string": "foo", "new_string": "bar"},
|
||||
ToolResponse: "Edit successful",
|
||||
CWD: "/home/user/project",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/observations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleSelfCheck_WithObservations tests self-check with observations in DB.
|
||||
func TestHandleSelfCheck_WithObservations(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(true)
|
||||
|
||||
// Create some observations
|
||||
createTestObservation(t, svc.observationStore, "check-project", "Test", "Content", []string{"test"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/self-check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleSelfCheck(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response SelfCheckResponse
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check components are populated
|
||||
assert.NotEmpty(t, response.Components)
|
||||
}
|
||||
|
||||
// TestHandleGetSummaries_NoProject tests getting summaries without project filter.
|
||||
func TestHandleGetSummaries_NoProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create some summaries in different projects
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
parsed := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
|
||||
svc.summaryStore.StoreSummary(ctx, "sdk-"+string(rune('a'+i)), "project-"+string(rune('a'+i)), parsed, i+1, 100)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/summaries", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var summaries []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &summaries)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return all summaries
|
||||
assert.GreaterOrEqual(t, len(summaries), 3)
|
||||
}
|
||||
|
||||
// TestHandleGetPrompts_NoProject tests getting prompts without project filter.
|
||||
func TestHandleGetPrompts_NoProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create sessions and prompts in different projects
|
||||
ctx := context.Background()
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-a", "project-a", "")
|
||||
svc.sessionStore.CreateSDKSession(ctx, "claude-prompts-b", "project-b", "")
|
||||
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-a", 1, "Prompt A", 0)
|
||||
svc.promptStore.SaveUserPromptWithMatches(ctx, "claude-prompts-b", 1, "Prompt B", 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/prompts", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var prompts []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &prompts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, len(prompts), 2)
|
||||
}
|
||||
|
||||
// TestHandleSessionInit_MissingClaudeID tests session init without Claude ID.
|
||||
func TestHandleSessionInit_MissingClaudeID(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := SessionInitRequest{
|
||||
Project: "test-project",
|
||||
Prompt: "Help me",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions/init", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
// Should accept even without Claude ID (may auto-generate)
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, rec.Code)
|
||||
}
|
||||
|
||||
// TestHandleContextInject_WithQuery tests context inject with query parameter.
|
||||
func TestHandleContextInject_WithQuery(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "inject-query-test"
|
||||
createTestObservation(t, svc.observationStore, project, "Authentication bug fix", "Fixed JWT validation", []string{"auth", "jwt"})
|
||||
createTestObservation(t, svc.observationStore, project, "Database optimization", "Added indexes", []string{"db", "performance"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project+"&query=authentication", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.GreaterOrEqual(t, len(observations), 1)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsSelfReferentialSummary(t *testing.T) {
|
||||
@@ -124,3 +127,381 @@ No substantive work performed yet.`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
expected bool
|
||||
}{
|
||||
// Tools that should be skipped
|
||||
{"TodoWrite", "TodoWrite", true},
|
||||
{"Task", "Task", true},
|
||||
{"TaskOutput", "TaskOutput", true},
|
||||
{"Glob", "Glob", true},
|
||||
{"ListDir", "ListDir", true},
|
||||
{"LS", "LS", true},
|
||||
{"KillShell", "KillShell", true},
|
||||
{"AskUserQuestion", "AskUserQuestion", true},
|
||||
{"EnterPlanMode", "EnterPlanMode", true},
|
||||
{"ExitPlanMode", "ExitPlanMode", true},
|
||||
{"Skill", "Skill", true},
|
||||
{"SlashCommand", "SlashCommand", true},
|
||||
|
||||
// Tools that should NOT be skipped
|
||||
{"Read", "Read", false},
|
||||
{"Edit", "Edit", false},
|
||||
{"Write", "Write", false},
|
||||
{"Grep", "Grep", false},
|
||||
{"Bash", "Bash", false},
|
||||
{"WebFetch", "WebFetch", false},
|
||||
{"WebSearch", "WebSearch", false},
|
||||
{"NotebookEdit", "NotebookEdit", false},
|
||||
|
||||
// Unknown tool (should not be skipped)
|
||||
{"UnknownTool", "SomeUnknownTool", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := shouldSkipTool(tt.toolName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTrivialOperation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
inputStr string
|
||||
outputStr string
|
||||
expected bool
|
||||
}{
|
||||
// Short output (should be skipped)
|
||||
{
|
||||
name: "output_too_short",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/some/file.go"}`,
|
||||
outputStr: "short",
|
||||
expected: true,
|
||||
},
|
||||
// Trivial outputs
|
||||
{
|
||||
name: "no_matches_found",
|
||||
toolName: "Grep",
|
||||
inputStr: `{"pattern": "foo"}`,
|
||||
outputStr: "No matches found in the codebase",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "file_not_found",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/nonexistent.go"}`,
|
||||
outputStr: "Error: File not found at specified path",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty_array",
|
||||
toolName: "Grep",
|
||||
inputStr: `{"pattern": "foo"}`,
|
||||
outputStr: "[]",
|
||||
expected: true,
|
||||
},
|
||||
// Boring files
|
||||
{
|
||||
name: "package_lock_json",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/package-lock.json"}`,
|
||||
outputStr: "This is a very long package-lock.json content that has more than 50 characters",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "go_sum",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/go.sum"}`,
|
||||
outputStr: "This is a very long go.sum file content that has more than 50 characters",
|
||||
expected: true,
|
||||
},
|
||||
// Grep with too many matches
|
||||
{
|
||||
name: "grep_too_many_matches",
|
||||
toolName: "Grep",
|
||||
inputStr: `{"pattern": "import"}`,
|
||||
outputStr: func() string {
|
||||
s := ""
|
||||
for i := 0; i < 55; i++ {
|
||||
s += "match line\n"
|
||||
}
|
||||
return s
|
||||
}(),
|
||||
expected: true,
|
||||
},
|
||||
// Boring Bash commands
|
||||
{
|
||||
name: "git_status",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "git status"}`,
|
||||
outputStr: "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ls_command",
|
||||
toolName: "Bash",
|
||||
inputStr: `{"command": "ls -la /some/directory"}`,
|
||||
outputStr: "total 123\ndrwxr-xr-x some long listing that is at least 50 chars",
|
||||
expected: true,
|
||||
},
|
||||
// Valid operations that should NOT be skipped
|
||||
{
|
||||
name: "valid_read",
|
||||
toolName: "Read",
|
||||
inputStr: `{"file_path": "/project/main.go"}`,
|
||||
outputStr: "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "valid_edit",
|
||||
toolName: "Edit",
|
||||
inputStr: `{"file_path": "/project/handler.go", "old_string": "foo", "new_string": "bar"}`,
|
||||
outputStr: "Edit applied successfully. File /project/handler.go has been modified with the requested changes.",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := shouldSkipTrivialOperation(tt.toolName, tt.inputStr, tt.outputStr)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateForLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
input: "hello",
|
||||
maxLen: 10,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "equal_to_max",
|
||||
input: "hello",
|
||||
maxLen: 5,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "longer_than_max",
|
||||
input: "hello world",
|
||||
maxLen: 5,
|
||||
expected: "hello...",
|
||||
},
|
||||
{
|
||||
name: "empty_string",
|
||||
input: "",
|
||||
maxLen: 5,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncateForLog(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToJSONString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil_value",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "string_value",
|
||||
input: "hello",
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "int_value",
|
||||
input: 42,
|
||||
expected: "42",
|
||||
},
|
||||
{
|
||||
name: "map_value",
|
||||
input: map[string]string{"key": "value"},
|
||||
expected: `{"key":"value"}`,
|
||||
},
|
||||
{
|
||||
name: "slice_value",
|
||||
input: []string{"a", "b", "c"},
|
||||
expected: `["a","b","c"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := toJSONString(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureFileMtimes(t *testing.T) {
|
||||
// Create a temp directory with test files
|
||||
tmpDir, err := os.MkdirTemp("", "mtime-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test files
|
||||
file1 := filepath.Join(tmpDir, "file1.txt")
|
||||
file2 := filepath.Join(tmpDir, "file2.txt")
|
||||
|
||||
err = os.WriteFile(file1, []byte("content1"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(file2, []byte("content2"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("captures_mtimes_for_existing_files", func(t *testing.T) {
|
||||
mtimes := captureFileMtimes([]string{file1}, []string{file2}, "")
|
||||
assert.Len(t, mtimes, 2)
|
||||
assert.Contains(t, mtimes, file1)
|
||||
assert.Contains(t, mtimes, file2)
|
||||
assert.Greater(t, mtimes[file1], int64(0))
|
||||
assert.Greater(t, mtimes[file2], int64(0))
|
||||
})
|
||||
|
||||
t.Run("handles_nonexistent_files", func(t *testing.T) {
|
||||
mtimes := captureFileMtimes([]string{"/nonexistent/file.txt"}, nil, "")
|
||||
assert.Empty(t, mtimes)
|
||||
})
|
||||
|
||||
t.Run("handles_relative_paths_with_cwd", func(t *testing.T) {
|
||||
mtimes := captureFileMtimes([]string{"file1.txt"}, nil, tmpDir)
|
||||
assert.Len(t, mtimes, 1)
|
||||
assert.Contains(t, mtimes, "file1.txt")
|
||||
})
|
||||
|
||||
t.Run("empty_inputs", func(t *testing.T) {
|
||||
mtimes := captureFileMtimes(nil, nil, "")
|
||||
assert.Empty(t, mtimes)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFileMtimes(t *testing.T) {
|
||||
// Create a temp file
|
||||
tmpDir, err := os.MkdirTemp("", "getmtime-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
err = os.WriteFile(testFile, []byte("content"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mtimes := GetFileMtimes([]string{testFile}, "")
|
||||
assert.Len(t, mtimes, 1)
|
||||
assert.Contains(t, mtimes, testFile)
|
||||
assert.Greater(t, mtimes[testFile], int64(0))
|
||||
}
|
||||
|
||||
func TestGetFileContent(t *testing.T) {
|
||||
// Create a temp directory with test files
|
||||
tmpDir, err := os.MkdirTemp("", "content-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
t.Run("reads_existing_file", func(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := "test content"
|
||||
err := os.WriteFile(testFile, []byte(content), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, ok := GetFileContent(testFile, "")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, content, result)
|
||||
})
|
||||
|
||||
t.Run("returns_false_for_nonexistent_file", func(t *testing.T) {
|
||||
result, ok := GetFileContent("/nonexistent/file.txt", "")
|
||||
assert.False(t, ok)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("truncates_long_content", func(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "long.txt")
|
||||
longContent := ""
|
||||
for i := 0; i < 3000; i++ {
|
||||
longContent += "x"
|
||||
}
|
||||
err := os.WriteFile(testFile, []byte(longContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, ok := GetFileContent(testFile, "")
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, result, "[truncated]")
|
||||
assert.LessOrEqual(t, len(result), 2100)
|
||||
})
|
||||
|
||||
t.Run("resolves_relative_path_with_cwd", func(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "relative.txt")
|
||||
content := "relative content"
|
||||
err := os.WriteFile(testFile, []byte(content), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, ok := GetFileContent("relative.txt", tmpDir)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, content, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaxConcurrentCLICalls(t *testing.T) {
|
||||
assert.Equal(t, 4, MaxConcurrentCLICalls)
|
||||
}
|
||||
|
||||
func TestObservationTypes(t *testing.T) {
|
||||
expected := []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"}
|
||||
assert.Equal(t, expected, ObservationTypes)
|
||||
}
|
||||
|
||||
func TestObservationConcepts(t *testing.T) {
|
||||
expectedConcepts := []string{
|
||||
"how-it-works",
|
||||
"why-it-exists",
|
||||
"what-changed",
|
||||
"problem-solution",
|
||||
"gotcha",
|
||||
"pattern",
|
||||
"trade-off",
|
||||
}
|
||||
assert.Equal(t, expectedConcepts, ObservationConcepts)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "shorter_than_max",
|
||||
input: "hello",
|
||||
maxLen: 10,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "equal_to_max",
|
||||
input: "hello",
|
||||
maxLen: 5,
|
||||
expected: "hello",
|
||||
},
|
||||
{
|
||||
name: "longer_than_max",
|
||||
input: "hello world",
|
||||
maxLen: 5,
|
||||
expected: "hello... (truncated)",
|
||||
},
|
||||
{
|
||||
name: "empty_string",
|
||||
input: "",
|
||||
maxLen: 5,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "zero_max_length",
|
||||
input: "hello",
|
||||
maxLen: 0,
|
||||
expected: "... (truncated)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncate(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildObservationPrompt(t *testing.T) {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
exec ToolExecution
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "basic_read_tool",
|
||||
exec: ToolExecution{
|
||||
ID: 1,
|
||||
ToolName: "Read",
|
||||
ToolInput: `{"file_path": "/path/to/file.go"}`,
|
||||
ToolOutput: `package main\nfunc main() {}`,
|
||||
CreatedAtEpoch: now,
|
||||
CWD: "/project",
|
||||
},
|
||||
contains: []string{
|
||||
"<observed_from_primary_session>",
|
||||
"<what_happened>Read</what_happened>",
|
||||
"<working_directory>/project</working_directory>",
|
||||
"<parameters>",
|
||||
"file_path",
|
||||
"<outcome>",
|
||||
"</observed_from_primary_session>",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "edit_tool_with_json_input",
|
||||
exec: ToolExecution{
|
||||
ID: 2,
|
||||
ToolName: "Edit",
|
||||
ToolInput: `{"file_path": "/file.go", "old_string": "foo", "new_string": "bar"}`,
|
||||
ToolOutput: "Edit applied successfully",
|
||||
CreatedAtEpoch: now,
|
||||
CWD: "",
|
||||
},
|
||||
contains: []string{
|
||||
"<what_happened>Edit</what_happened>",
|
||||
"file_path",
|
||||
"old_string",
|
||||
"new_string",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no_cwd",
|
||||
exec: ToolExecution{
|
||||
ID: 3,
|
||||
ToolName: "Bash",
|
||||
ToolInput: `{"command": "go test"}`,
|
||||
ToolOutput: "ok",
|
||||
CreatedAtEpoch: now,
|
||||
CWD: "",
|
||||
},
|
||||
contains: []string{
|
||||
"<what_happened>Bash</what_happened>",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := BuildObservationPrompt(tt.exec)
|
||||
|
||||
for _, s := range tt.contains {
|
||||
assert.Contains(t, result, s, "Expected result to contain: %s", s)
|
||||
}
|
||||
|
||||
// Check CWD only appears when set
|
||||
if tt.exec.CWD == "" {
|
||||
assert.NotContains(t, result, "<working_directory>")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildObservationPrompt_TruncatesLongContent(t *testing.T) {
|
||||
longInput := strings.Repeat("x", 5000)
|
||||
longOutput := strings.Repeat("y", 7000)
|
||||
|
||||
exec := ToolExecution{
|
||||
ID: 1,
|
||||
ToolName: "Read",
|
||||
ToolInput: longInput,
|
||||
ToolOutput: longOutput,
|
||||
CreatedAtEpoch: time.Now().UnixMilli(),
|
||||
CWD: "/project",
|
||||
}
|
||||
|
||||
result := BuildObservationPrompt(exec)
|
||||
|
||||
// Input should be truncated to ~3000
|
||||
assert.Contains(t, result, "truncated")
|
||||
// The result should not be excessively long
|
||||
assert.Less(t, len(result), 10000)
|
||||
}
|
||||
|
||||
func TestBuildSummaryPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req SummaryRequest
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "basic_request",
|
||||
req: SummaryRequest{
|
||||
SessionDBID: 1,
|
||||
SDKSessionID: "sdk-123",
|
||||
Project: "test-project",
|
||||
},
|
||||
contains: []string{
|
||||
"PROGRESS SUMMARY CHECKPOINT",
|
||||
"<summary>",
|
||||
"<request>",
|
||||
"<investigated>",
|
||||
"<learned>",
|
||||
"<completed>",
|
||||
"<next_steps>",
|
||||
"<notes>",
|
||||
"</summary>",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with_assistant_message",
|
||||
req: SummaryRequest{
|
||||
SessionDBID: 2,
|
||||
SDKSessionID: "sdk-456",
|
||||
Project: "project-b",
|
||||
LastAssistantMessage: "I fixed the authentication bug by updating the JWT validation.",
|
||||
},
|
||||
contains: []string{
|
||||
"Claude's Full Response to User:",
|
||||
"fixed the authentication",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_assistant_message",
|
||||
req: SummaryRequest{
|
||||
SessionDBID: 3,
|
||||
SDKSessionID: "sdk-789",
|
||||
Project: "project-c",
|
||||
LastAssistantMessage: "",
|
||||
},
|
||||
contains: []string{
|
||||
"PROGRESS SUMMARY CHECKPOINT",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := BuildSummaryPrompt(tt.req)
|
||||
|
||||
for _, s := range tt.contains {
|
||||
assert.Contains(t, result, s, "Expected result to contain: %s", s)
|
||||
}
|
||||
|
||||
// Check assistant message only appears when set
|
||||
if tt.req.LastAssistantMessage == "" {
|
||||
assert.NotContains(t, result, "Claude's Full Response")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSummaryPrompt_TruncatesLongAssistantMessage(t *testing.T) {
|
||||
longMessage := strings.Repeat("a", 5000)
|
||||
|
||||
req := SummaryRequest{
|
||||
SessionDBID: 1,
|
||||
SDKSessionID: "sdk-123",
|
||||
Project: "test",
|
||||
LastAssistantMessage: longMessage,
|
||||
}
|
||||
|
||||
result := BuildSummaryPrompt(req)
|
||||
|
||||
// Should contain truncation indicator
|
||||
assert.Contains(t, result, "truncated")
|
||||
// Result should be reasonable length (less than full 5000 + overhead)
|
||||
assert.Less(t, len(result), 6000)
|
||||
}
|
||||
|
||||
func TestToolExecution_Struct(t *testing.T) {
|
||||
exec := ToolExecution{
|
||||
ID: 42,
|
||||
ToolName: "Write",
|
||||
ToolInput: `{"file_path": "/test.go"}`,
|
||||
ToolOutput: "File written",
|
||||
CreatedAtEpoch: 1234567890000,
|
||||
CWD: "/workspace",
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(42), exec.ID)
|
||||
assert.Equal(t, "Write", exec.ToolName)
|
||||
assert.Equal(t, `{"file_path": "/test.go"}`, exec.ToolInput)
|
||||
assert.Equal(t, "File written", exec.ToolOutput)
|
||||
assert.Equal(t, int64(1234567890000), exec.CreatedAtEpoch)
|
||||
assert.Equal(t, "/workspace", exec.CWD)
|
||||
}
|
||||
|
||||
func TestSummaryRequest_Struct(t *testing.T) {
|
||||
req := SummaryRequest{
|
||||
SessionDBID: 100,
|
||||
SDKSessionID: "sdk-abc",
|
||||
Project: "my-project",
|
||||
UserPrompt: "Fix the bug",
|
||||
LastUserMessage: "Please fix the auth bug",
|
||||
LastAssistantMessage: "I've fixed the authentication issue",
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(100), req.SessionDBID)
|
||||
assert.Equal(t, "sdk-abc", req.SDKSessionID)
|
||||
assert.Equal(t, "my-project", req.Project)
|
||||
assert.Equal(t, "Fix the bug", req.UserPrompt)
|
||||
assert.Equal(t, "Please fix the auth bug", req.LastUserMessage)
|
||||
assert.Equal(t, "I've fixed the authentication issue", req.LastAssistantMessage)
|
||||
}
|
||||
@@ -710,3 +710,495 @@ func TestGetWorkerVersion_WithServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetWorkerPort_EdgeCases tests GetWorkerPort with various edge cases.
|
||||
func TestGetWorkerPort_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expectedPort int
|
||||
shouldSetEnv bool
|
||||
}{
|
||||
{
|
||||
name: "zero port uses default",
|
||||
envValue: "0",
|
||||
expectedPort: DefaultWorkerPort,
|
||||
shouldSetEnv: true,
|
||||
},
|
||||
{
|
||||
name: "negative port uses default",
|
||||
envValue: "-1",
|
||||
expectedPort: DefaultWorkerPort,
|
||||
shouldSetEnv: true,
|
||||
},
|
||||
{
|
||||
name: "empty string uses default",
|
||||
envValue: "",
|
||||
expectedPort: DefaultWorkerPort,
|
||||
shouldSetEnv: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace uses default",
|
||||
envValue: " ",
|
||||
expectedPort: DefaultWorkerPort,
|
||||
shouldSetEnv: true,
|
||||
},
|
||||
{
|
||||
name: "large valid port",
|
||||
envValue: "65535",
|
||||
expectedPort: 65535,
|
||||
shouldSetEnv: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSetEnv {
|
||||
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", tt.envValue)
|
||||
}
|
||||
port := GetWorkerPort()
|
||||
assert.Equal(t, tt.expectedPort, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionVariable tests the Version variable.
|
||||
func TestVersionVariable(t *testing.T) {
|
||||
// Version is set at build time, but defaults to "dev"
|
||||
assert.NotEmpty(t, Version)
|
||||
}
|
||||
|
||||
// TestProjectIDWithName_RootPath tests ProjectIDWithName with root path.
|
||||
func TestProjectIDWithName_RootPath(t *testing.T) {
|
||||
result := ProjectIDWithName("/")
|
||||
// Should handle root path gracefully
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "_") // Should still have underscore separator
|
||||
}
|
||||
|
||||
// TestProjectIDWithName_SameDirname tests that same dirname with different paths get different IDs.
|
||||
func TestProjectIDWithName_SameDirname(t *testing.T) {
|
||||
id1 := ProjectIDWithName("/home/user1/project")
|
||||
id2 := ProjectIDWithName("/home/user2/project")
|
||||
|
||||
// Both have same dirname "project" but different full paths
|
||||
assert.Contains(t, id1, "project_")
|
||||
assert.Contains(t, id2, "project_")
|
||||
|
||||
// But different hashes due to different full paths
|
||||
assert.NotEqual(t, id1, id2)
|
||||
}
|
||||
|
||||
// TestBaseInput_PartialFields tests BaseInput with partial fields.
|
||||
func TestBaseInput_PartialFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected BaseInput
|
||||
}{
|
||||
{
|
||||
name: "only session_id",
|
||||
input: `{"session_id":"test-123"}`,
|
||||
expected: BaseInput{SessionID: "test-123"},
|
||||
},
|
||||
{
|
||||
name: "only cwd",
|
||||
input: `{"cwd":"/tmp/test"}`,
|
||||
expected: BaseInput{CWD: "/tmp/test"},
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
input: `{}`,
|
||||
expected: BaseInput{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var base BaseInput
|
||||
err := json.Unmarshal([]byte(tt.input), &base)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected.SessionID, base.SessionID)
|
||||
assert.Equal(t, tt.expected.CWD, base.CWD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookResponse_Marshal tests HookResponse JSON marshaling.
|
||||
func TestHookResponse_Marshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response HookResponse
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "continue true",
|
||||
response: HookResponse{Continue: true},
|
||||
contains: []string{`"continue":true`},
|
||||
},
|
||||
{
|
||||
name: "continue false",
|
||||
response: HookResponse{Continue: false},
|
||||
contains: []string{`"continue":false`},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
require.NoError(t, err)
|
||||
for _, s := range tt.contains {
|
||||
assert.Contains(t, string(data), s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookResponse_Unmarshal tests HookResponse JSON unmarshaling.
|
||||
func TestHookResponse_Unmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected HookResponse
|
||||
}{
|
||||
{
|
||||
name: "continue true",
|
||||
input: `{"continue":true}`,
|
||||
expected: HookResponse{Continue: true},
|
||||
},
|
||||
{
|
||||
name: "continue false",
|
||||
input: `{"continue":false}`,
|
||||
expected: HookResponse{Continue: false},
|
||||
},
|
||||
{
|
||||
name: "missing continue defaults to false",
|
||||
input: `{}`,
|
||||
expected: HookResponse{Continue: false},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var resp HookResponse
|
||||
err := json.Unmarshal([]byte(tt.input), &resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected.Continue, resp.Continue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookContext_Initialization tests HookContext struct initialization.
|
||||
func TestHookContext_Initialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx HookContext
|
||||
}{
|
||||
{
|
||||
name: "full context",
|
||||
ctx: HookContext{
|
||||
HookName: "session-start",
|
||||
Port: 37777,
|
||||
Project: "my-project_abc123",
|
||||
SessionID: "session-123",
|
||||
CWD: "/home/user/project",
|
||||
RawInput: []byte(`{"key":"value"}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal context",
|
||||
ctx: HookContext{
|
||||
HookName: "stop",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty context",
|
||||
ctx: HookContext{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Just verify the struct can be created and accessed
|
||||
assert.Equal(t, tt.ctx.HookName, tt.ctx.HookName)
|
||||
assert.Equal(t, tt.ctx.Port, tt.ctx.Port)
|
||||
assert.Equal(t, tt.ctx.Project, tt.ctx.Project)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPOST_MarshalError tests POST with unmarshalable body.
|
||||
func TestPOST_MarshalError(t *testing.T) {
|
||||
// Create a value that can't be marshaled
|
||||
badValue := make(chan int)
|
||||
|
||||
_, err := POST(99999, "/test", badValue)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestPOST_Timeout tests POST with timeout.
|
||||
func TestPOST_Timeout(t *testing.T) {
|
||||
// Try to connect to a port that's not listening
|
||||
_, err := POST(99998, "/test", map[string]string{"key": "value"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestGET_Timeout tests GET with timeout.
|
||||
func TestGET_Timeout(t *testing.T) {
|
||||
// Try to connect to a port that's not listening
|
||||
_, err := GET(99998, "/test")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestIsWorkerRunning_Timeout tests IsWorkerRunning with timeout.
|
||||
func TestIsWorkerRunning_Timeout(t *testing.T) {
|
||||
// Non-existent port should quickly return false
|
||||
start := time.Now()
|
||||
result := IsWorkerRunning(99997)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.False(t, result)
|
||||
assert.Less(t, elapsed, 5*time.Second) // Should not hang
|
||||
}
|
||||
|
||||
// TestIsPortInUse_Timeout tests IsPortInUse with timeout.
|
||||
func TestIsPortInUse_Timeout(t *testing.T) {
|
||||
// Non-existent port should quickly return false
|
||||
start := time.Now()
|
||||
result := IsPortInUse(99996)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.False(t, result)
|
||||
assert.Less(t, elapsed, 2*time.Second) // Should not hang
|
||||
}
|
||||
|
||||
// TestGetWorkerVersion_Timeout tests GetWorkerVersion with timeout.
|
||||
func TestGetWorkerVersion_Timeout(t *testing.T) {
|
||||
// Non-existent port should quickly return empty
|
||||
start := time.Now()
|
||||
result := GetWorkerVersion(99995)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Empty(t, result)
|
||||
assert.Less(t, elapsed, 5*time.Second) // Should not hang
|
||||
}
|
||||
|
||||
// TestVersionsCompatible_EdgeCases tests versionsCompatible edge cases.
|
||||
func TestVersionsCompatible_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v1 string
|
||||
v2 string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty versions",
|
||||
v1: "",
|
||||
v2: "",
|
||||
expected: true, // Same base (empty)
|
||||
},
|
||||
{
|
||||
name: "one empty one dev",
|
||||
v1: "",
|
||||
v2: "dev",
|
||||
expected: true, // dev is compatible with anything
|
||||
},
|
||||
{
|
||||
name: "prerelease versions same base",
|
||||
v1: "v1.0.0-alpha",
|
||||
v2: "v1.0.0-beta",
|
||||
expected: true, // Same base 1.0.0
|
||||
},
|
||||
{
|
||||
name: "version with rc suffix",
|
||||
v1: "v2.0.0-rc1",
|
||||
v2: "v2.0.0",
|
||||
expected: true, // Same base 2.0.0
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := versionsCompatible(tt.v1, tt.v2)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractBaseVersion_EdgeCases tests extractBaseVersion edge cases.
|
||||
func TestExtractBaseVersion_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "version starting with hyphen",
|
||||
version: "-dirty",
|
||||
expected: "-dirty", // hyphen at index 0 is not > 0, so no truncation
|
||||
},
|
||||
{
|
||||
name: "just v",
|
||||
version: "v",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple hyphens",
|
||||
version: "v1.0.0-alpha-beta-gamma",
|
||||
expected: "1.0.0",
|
||||
},
|
||||
{
|
||||
name: "no hyphen at all",
|
||||
version: "v2.0.0",
|
||||
expected: "2.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractBaseVersion(tt.version)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectIDWithName_RelativePath tests ProjectIDWithName with relative paths.
|
||||
func TestProjectIDWithName_RelativePath(t *testing.T) {
|
||||
// Relative paths should be converted to absolute
|
||||
result := ProjectIDWithName(".")
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "_")
|
||||
}
|
||||
|
||||
// TestProjectIDWithName_DeepPath tests ProjectIDWithName with deep paths.
|
||||
func TestProjectIDWithName_DeepPath(t *testing.T) {
|
||||
result := ProjectIDWithName("/a/very/deep/nested/path/to/project")
|
||||
assert.Contains(t, result, "project_")
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
|
||||
// TestPOST_EmptyBody tests POST with empty body.
|
||||
func TestPOST_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := POST(port, "/test", map[string]string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
// TestGET_WithQueryParams tests GET with query parameters.
|
||||
func TestGET_WithQueryParams(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
assert.Equal(t, "/test?foo=bar", r.URL.String())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var port int
|
||||
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := GET(port, "/test?foo=bar")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
// TestHookResponse_RoundTrip tests JSON marshal/unmarshal round-trip.
|
||||
func TestHookResponse_RoundTrip(t *testing.T) {
|
||||
original := HookResponse{Continue: true}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded HookResponse
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original.Continue, decoded.Continue)
|
||||
}
|
||||
|
||||
// TestBaseInput_RoundTrip tests BaseInput JSON round-trip.
|
||||
func TestBaseInput_RoundTrip(t *testing.T) {
|
||||
original := BaseInput{
|
||||
SessionID: "test-session",
|
||||
CWD: "/home/user/project",
|
||||
PermissionMode: "standard",
|
||||
HookEventName: "session-start",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded BaseInput
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original.SessionID, decoded.SessionID)
|
||||
assert.Equal(t, original.CWD, decoded.CWD)
|
||||
assert.Equal(t, original.PermissionMode, decoded.PermissionMode)
|
||||
assert.Equal(t, original.HookEventName, decoded.HookEventName)
|
||||
}
|
||||
|
||||
// TestHookContext_RawInput tests HookContext with different raw input types.
|
||||
func TestHookContext_RawInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawInput []byte
|
||||
}{
|
||||
{
|
||||
name: "json object",
|
||||
rawInput: []byte(`{"key":"value"}`),
|
||||
},
|
||||
{
|
||||
name: "json array",
|
||||
rawInput: []byte(`[1,2,3]`),
|
||||
},
|
||||
{
|
||||
name: "empty object",
|
||||
rawInput: []byte(`{}`),
|
||||
},
|
||||
{
|
||||
name: "nil input",
|
||||
rawInput: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := HookContext{
|
||||
HookName: "test",
|
||||
RawInput: tt.rawInput,
|
||||
}
|
||||
assert.Equal(t, tt.rawInput, ctx.RawInput)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultWorkerPort tests that the default port constant is valid.
|
||||
func TestDefaultWorkerPort(t *testing.T) {
|
||||
assert.Greater(t, DefaultWorkerPort, 1024, "Default port should be above privileged port range")
|
||||
assert.Less(t, DefaultWorkerPort, 65535, "Default port should be valid TCP port")
|
||||
}
|
||||
|
||||
// TestHealthCheckTimeout tests the health check timeout is reasonable.
|
||||
func TestHealthCheckTimeout(t *testing.T) {
|
||||
assert.Greater(t, HealthCheckTimeout, 100*time.Millisecond)
|
||||
assert.Less(t, HealthCheckTimeout, 10*time.Second)
|
||||
}
|
||||
|
||||
// TestStartupTimeout tests the startup timeout is reasonable.
|
||||
func TestStartupTimeout(t *testing.T) {
|
||||
assert.Greater(t, StartupTimeout, 5*time.Second)
|
||||
assert.LessOrEqual(t, StartupTimeout, time.Minute)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user