Increase test coverage to 45.6%

This commit is contained in:
2025-12-17 12:39:47 +00:00
parent 4add030bed
commit c259bb1d18
13 changed files with 5484 additions and 0 deletions
+254
View File
@@ -0,0 +1,254 @@
package sqlite
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNullString(t *testing.T) {
tests := []struct {
name string
input string
expected string
valid bool
}{
{"empty_string", "", "", false},
{"non_empty_string", "hello", "hello", true},
{"whitespace", " ", " ", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := nullString(tt.input)
assert.Equal(t, tt.expected, result.String)
assert.Equal(t, tt.valid, result.Valid)
})
}
}
func TestNullInt(t *testing.T) {
tests := []struct {
name string
input int
expected int64
valid bool
}{
{"zero", 0, 0, false},
{"positive", 42, 42, true},
{"negative", -1, -1, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := nullInt(tt.input)
assert.Equal(t, tt.expected, result.Int64)
assert.Equal(t, tt.valid, result.Valid)
})
}
}
func TestRepeatPlaceholders(t *testing.T) {
tests := []struct {
name string
n int
expected string
}{
{"zero", 0, ""},
{"negative", -1, ""},
{"one", 1, ", ?"},
{"two", 2, ", ?, ?"},
{"three", 3, ", ?, ?, ?"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := repeatPlaceholders(tt.n)
assert.Equal(t, tt.expected, result)
})
}
}
func TestInt64SliceToInterface(t *testing.T) {
tests := []struct {
name string
input []int64
expected []interface{}
}{
{"empty", []int64{}, []interface{}{}},
{"single", []int64{42}, []interface{}{int64(42)}},
{"multiple", []int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := int64SliceToInterface(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseLimitParam(t *testing.T) {
tests := []struct {
name string
query string
defaultLimit int
expected int
}{
{"no_param_uses_default", "", 10, 10},
{"valid_limit", "limit=20", 10, 20},
{"invalid_limit_uses_default", "limit=abc", 10, 10},
{"zero_limit_uses_default", "limit=0", 10, 10},
{"negative_limit_uses_default", "limit=-5", 10, 10},
{"large_limit", "limit=1000", 10, 1000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/?"+tt.query, nil)
result := ParseLimitParam(req, tt.defaultLimit)
assert.Equal(t, tt.expected, result)
})
}
}
func TestScanSummary(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert a test summary
_, err := db.Exec(`
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES ('sdk-123', 'test-project', 'test request', 'test investigated', 'test learned', 'test completed', 'test next steps', 'test notes', 1, 100, '2025-01-01T00:00:00Z', 1704067200000)
`)
require.NoError(t, err)
// Query and scan
row := db.QueryRow(`
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries WHERE sdk_session_id = ?
`, "sdk-123")
summary, err := scanSummary(row)
require.NoError(t, err)
assert.NotNil(t, summary)
assert.Equal(t, "sdk-123", summary.SDKSessionID)
assert.Equal(t, "test-project", summary.Project)
assert.Equal(t, "test request", summary.Request.String)
assert.Equal(t, "test investigated", summary.Investigated.String)
assert.Equal(t, "test learned", summary.Learned.String)
assert.Equal(t, "test completed", summary.Completed.String)
assert.Equal(t, "test next steps", summary.NextSteps.String)
assert.Equal(t, "test notes", summary.Notes.String)
}
func TestScanSummaryRows(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert multiple summaries
_, err := db.Exec(`
INSERT INTO session_summaries (sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES
('sdk-123', 'test-project', 'request 1', '', '', '', '', '', 1, 0, '2025-01-01T00:00:00Z', 1704067200000),
('sdk-123', 'test-project', 'request 2', '', '', '', '', '', 2, 0, '2025-01-02T00:00:00Z', 1704153600000)
`)
require.NoError(t, err)
rows, err := db.Query(`
SELECT id, sdk_session_id, project, request, investigated, learned, completed, next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries WHERE sdk_session_id = ? ORDER BY id
`, "sdk-123")
require.NoError(t, err)
defer rows.Close()
summaries, err := scanSummaryRows(rows)
require.NoError(t, err)
assert.Len(t, summaries, 2)
assert.Equal(t, "request 1", summaries[0].Request.String)
assert.Equal(t, "request 2", summaries[1].Request.String)
}
func TestScanPromptWithSession(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert a test prompt
_, err := db.Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
VALUES ('claude-123', 1, 'test prompt', 5, '2025-01-01T00:00:00Z', 1704067200000)
`)
require.NoError(t, err)
// Query with session join
row := db.QueryRow(`
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
FROM user_prompts p
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
WHERE p.claude_session_id = ?
`, "claude-123")
prompt, err := scanPromptWithSession(row)
require.NoError(t, err)
assert.NotNil(t, prompt)
assert.Equal(t, "claude-123", prompt.ClaudeSessionID)
assert.Equal(t, 1, prompt.PromptNumber)
assert.Equal(t, "test prompt", prompt.PromptText)
assert.Equal(t, 5, prompt.MatchedObservations)
assert.Equal(t, "test-project", prompt.Project)
assert.Equal(t, "sdk-123", prompt.SDKSessionID)
}
func TestScanPromptWithSessionRows(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
seedSession(t, db, "claude-123", "sdk-123", "test-project")
// Insert multiple prompts
_, err := db.Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
VALUES
('claude-123', 1, 'prompt one', 3, '2025-01-01T00:00:00Z', 1704067200000),
('claude-123', 2, 'prompt two', 5, '2025-01-02T00:00:00Z', 1704153600000)
`)
require.NoError(t, err)
rows, err := db.Query(`
SELECT p.id, p.claude_session_id, p.prompt_number, p.prompt_text, p.matched_observations, p.created_at, p.created_at_epoch, s.project, s.sdk_session_id
FROM user_prompts p
JOIN sdk_sessions s ON p.claude_session_id = s.claude_session_id
WHERE p.claude_session_id = ? ORDER BY p.id
`, "claude-123")
require.NoError(t, err)
defer rows.Close()
prompts, err := scanPromptWithSessionRows(rows)
require.NoError(t, err)
assert.Len(t, prompts, 2)
assert.Equal(t, "prompt one", prompts[0].PromptText)
assert.Equal(t, "prompt two", prompts[1].PromptText)
}
func TestParseLimitParam_HTTPRequest(t *testing.T) {
// Test with an actual HTTP request
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limit := ParseLimitParam(r, 25)
if limit != 50 {
t.Errorf("Expected limit 50, got %d", limit)
}
})
req := httptest.NewRequest("GET", "http://example.com/api?limit=50", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
+196
View File
@@ -0,0 +1,196 @@
package sqlite
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMigrationManager(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
require.NotNil(t, manager)
assert.Equal(t, db, manager.db)
}
func TestMigrationManager_EnsureSchemaVersionsTable(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
// Should create table without error
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Table should exist
var count int
err = db.QueryRow("SELECT COUNT(*) FROM schema_versions").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 0, count) // Empty table
// Calling again should not error (IF NOT EXISTS)
err = manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
}
func TestMigrationManager_GetAppliedVersions_Empty(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.Empty(t, versions)
}
func TestMigrationManager_GetAppliedVersions_WithVersions(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Insert some versions
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (1, '2025-01-01T00:00:00Z')")
require.NoError(t, err)
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (2, '2025-01-02T00:00:00Z')")
require.NoError(t, err)
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.Len(t, versions, 2)
assert.True(t, versions[1])
assert.True(t, versions[2])
assert.False(t, versions[3]) // Not applied
}
func TestMigrationManager_ApplyMigration(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Apply a simple migration
migration := Migration{
Version: 100,
Name: "test_migration",
SQL: "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)",
}
err = manager.ApplyMigration(migration)
require.NoError(t, err)
// Verify table was created
var count int
err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test_table'").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
// Verify migration was recorded
var version int
err = db.QueryRow("SELECT version FROM schema_versions WHERE version = 100").Scan(&version)
require.NoError(t, err)
assert.Equal(t, 100, version)
}
func TestMigrationManager_ApplyMigration_InvalidSQL(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Try to apply invalid migration
migration := Migration{
Version: 100,
Name: "invalid_migration",
SQL: "INVALID SQL SYNTAX",
}
err = manager.ApplyMigration(migration)
assert.Error(t, err)
assert.Contains(t, err.Error(), "execute migration 100")
}
func TestMigrationManager_RunMigrations_SingleMigration(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
// Create a test migration manager with a subset of migrations
manager := NewMigrationManager(db)
// First ensure schema versions table exists
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Apply first migration manually
err = manager.ApplyMigration(Migrations[0])
require.NoError(t, err)
// Verify the first migration version was recorded
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.True(t, versions[Migrations[0].Version])
}
func TestMigrationManager_RunMigrations_SkipsApplied(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
manager := NewMigrationManager(db)
err := manager.EnsureSchemaVersionsTable()
require.NoError(t, err)
// Mark some migrations as already applied
_, err = db.Exec("INSERT INTO schema_versions (version, applied_at) VALUES (4, '2025-01-01T00:00:00Z')")
require.NoError(t, err)
// Get applied versions
versions, err := manager.GetAppliedVersions()
require.NoError(t, err)
assert.True(t, versions[4])
}
func TestMigration_Struct(t *testing.T) {
m := Migration{
Version: 1,
Name: "test",
SQL: "SELECT 1",
}
assert.Equal(t, 1, m.Version)
assert.Equal(t, "test", m.Name)
assert.Equal(t, "SELECT 1", m.SQL)
}
func TestMigrations_List(t *testing.T) {
// Verify migrations are ordered correctly
assert.NotEmpty(t, Migrations)
// Verify all migrations have required fields
for i, m := range Migrations {
assert.Greater(t, m.Version, 0, "Migration %d has invalid version", i)
assert.NotEmpty(t, m.Name, "Migration %d has empty name", i)
assert.NotEmpty(t, m.SQL, "Migration %d has empty SQL", i)
}
// Verify key migrations exist
versionSet := make(map[int]bool)
for _, m := range Migrations {
versionSet[m.Version] = true
}
assert.True(t, versionSet[4], "Should have sdk_agent_architecture migration")
assert.True(t, versionSet[17], "Should have sqlite_vec_vectors migration")
}
+145
View File
@@ -800,3 +800,148 @@ func TestExtractKeywords(t *testing.T) {
})
}
}
func TestObservationStore_GetObservationsByProjectStrict(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create project-scoped observation for project-a
projectObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project A specific",
Narrative: "Only for project-a",
Concepts: []string{"local-concept"},
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
require.NoError(t, err)
// Create global observation from project-a
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Global security practice",
Narrative: "Best practice for all",
Concepts: []string{"security", "best-practice"},
}
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 2, 100)
require.NoError(t, err)
// Create observation for project-b
projectBObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project B specific",
Narrative: "Only for project-b",
}
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-b", projectBObs, 1, 100)
require.NoError(t, err)
// GetObservationsByProjectStrict for project-a should only return project-a observations
// This is different from GetRecentObservations which includes globals from other projects
results, err := obsStore.GetObservationsByProjectStrict(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, results, 2) // Only observations created in project-a
// Verify both are from project-a
for _, obs := range results {
assert.Equal(t, "project-a", obs.Project)
}
// GetObservationsByProjectStrict for project-b should only return project-b observations
results, err = obsStore.GetObservationsByProjectStrict(ctx, "project-b", 10)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "Project B specific", results[0].Title.String)
}
func TestObservationStore_SearchObservationsFTS_EmptyQuery(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create an observation
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test observation",
Narrative: "Some content here",
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
// Search with only stop words (should return nil)
results, err := obsStore.SearchObservationsFTS(ctx, "the a an is are", "project-a", 10)
require.NoError(t, err)
assert.Nil(t, results)
// Search with empty query
results, err = obsStore.SearchObservationsFTS(ctx, "", "project-a", 10)
require.NoError(t, err)
assert.Nil(t, results)
}
func TestObservationStore_SearchObservationsFTS_DefaultLimit(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations
for i := 0; i < 15; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Authentication test " + string(rune('A'+i)),
Narrative: "Auth related content",
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Search with limit 0 (should default to 10)
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 0)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 10)
// Search with negative limit (should default to 10)
results, err = obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", -5)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 10)
}
func TestObservationStore_GetAllRecentObservations(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations across different projects
projects := []string{"project-a", "project-b", "project-c"}
for _, proj := range projects {
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: proj + " observation " + string(rune('A'+i)),
Narrative: "Content for " + proj,
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", proj, obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
}
// Get all recent observations
results, err := obsStore.GetAllRecentObservations(ctx, 100)
require.NoError(t, err)
assert.Len(t, results, 9) // 3 projects * 3 observations
// Verify they are in descending order by epoch
for i := 1; i < len(results); i++ {
assert.GreaterOrEqual(t, results[i-1].CreatedAtEpoch, results[i].CreatedAtEpoch)
}
// Test with limit
results, err = obsStore.GetAllRecentObservations(ctx, 5)
require.NoError(t, err)
assert.Len(t, results, 5)
}