Increase tests coverage.

This commit is contained in:
2025-12-17 11:40:08 +00:00
parent 587cdab9a5
commit 4add030bed
15 changed files with 6421 additions and 6 deletions
+428
View File
@@ -9,8 +9,22 @@ import (
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// testObservationStoreBasic creates an ObservationStore with base tables (no FTS5).
func testObservationStoreBasic(t *testing.T) (*ObservationStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createBaseTables(t, db)
store := newStoreFromDB(db)
obsStore := NewObservationStore(store)
return obsStore, store, cleanup
}
// testObservationStore creates an ObservationStore with a test database including FTS5.
func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
t.Helper()
@@ -24,6 +38,420 @@ func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
return obsStore, store, cleanup
}
// ObservationStoreSuite is a test suite for ObservationStore operations.
type ObservationStoreSuite struct {
suite.Suite
obsStore *ObservationStore
store *Store
cleanup func()
}
func (s *ObservationStoreSuite) SetupTest() {
s.obsStore, s.store, s.cleanup = testObservationStoreBasic(s.T())
}
func (s *ObservationStoreSuite) TearDownTest() {
if s.cleanup != nil {
s.cleanup()
}
}
func TestObservationStoreSuite(t *testing.T) {
suite.Run(t, new(ObservationStoreSuite))
}
// TestStoreObservation_TableDriven tests observation storage with various scenarios.
func (s *ObservationStoreSuite) TestStoreObservation_TableDriven() {
ctx := context.Background()
tests := []struct {
name string
sdkSessionID string
project string
obs *models.ParsedObservation
promptNum int
tokens int64
wantErr bool
}{
{
name: "basic discovery observation",
sdkSessionID: "session-basic",
project: "project-a",
obs: &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test Title",
Subtitle: "Test Subtitle",
Narrative: "Test narrative content",
Facts: []string{"Fact 1", "Fact 2"},
Concepts: []string{"testing", "golang"},
},
promptNum: 1,
tokens: 100,
wantErr: false,
},
{
name: "bugfix observation",
sdkSessionID: "session-bugfix",
project: "project-b",
obs: &models.ParsedObservation{
Type: models.ObsTypeBugfix,
Title: "Fixed null pointer",
Narrative: "Fixed null pointer exception in handler",
FilesModified: []string{"handler.go"},
},
promptNum: 2,
tokens: 50,
wantErr: false,
},
{
name: "global scope observation",
sdkSessionID: "session-global",
project: "project-c",
obs: &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Security best practice",
Narrative: "Always validate user input",
Concepts: []string{"security", "best-practice"},
},
promptNum: 1,
tokens: 75,
wantErr: false,
},
{
name: "observation with files",
sdkSessionID: "session-files",
project: "project-d",
obs: &models.ParsedObservation{
Type: models.ObsTypeFeature,
Title: "Added authentication",
Narrative: "Implemented JWT authentication",
FilesRead: []string{"config.go", "auth.go"},
FilesModified: []string{"handler.go", "middleware.go"},
FileMtimes: map[string]int64{"handler.go": 1234567890, "middleware.go": 1234567891},
},
promptNum: 3,
tokens: 200,
wantErr: false,
},
{
name: "minimal observation",
sdkSessionID: "session-minimal",
project: "project-e",
obs: &models.ParsedObservation{
Type: models.ObsTypeChange,
},
promptNum: 0,
tokens: 0,
wantErr: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
id, epoch, err := s.obsStore.StoreObservation(ctx, tt.sdkSessionID, tt.project, tt.obs, tt.promptNum, tt.tokens)
if tt.wantErr {
s.Error(err)
return
}
s.NoError(err)
s.Greater(id, int64(0))
s.Greater(epoch, int64(0))
// Retrieve and verify
retrieved, err := s.obsStore.GetObservationByID(ctx, id)
s.NoError(err)
s.NotNil(retrieved)
s.Equal(id, retrieved.ID)
s.Equal(tt.project, retrieved.Project)
s.Equal(tt.obs.Type, retrieved.Type)
})
}
}
// TestGetObservationByID_NotFound tests retrieval of non-existent observation.
func (s *ObservationStoreSuite) TestGetObservationByID_NotFound() {
ctx := context.Background()
obs, err := s.obsStore.GetObservationByID(ctx, 99999)
s.NoError(err)
s.Nil(obs)
}
// TestGetRecentObservations_TableDriven tests recent observations retrieval.
func (s *ObservationStoreSuite) TestGetRecentObservations_TableDriven() {
ctx := context.Background()
// Create 15 observations
for i := 0; i < 15; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Observation " + string(rune('A'+i)),
}
_, _, err := s.obsStore.StoreObservation(ctx, "session-"+string(rune('0'+i)), "project-a", obs, i, 10)
s.NoError(err)
time.Sleep(time.Millisecond) // Ensure different timestamps
}
tests := []struct {
name string
project string
limit int
wantCount int
}{
{
name: "limit 5",
project: "project-a",
limit: 5,
wantCount: 5,
},
{
name: "limit 10",
project: "project-a",
limit: 10,
wantCount: 10,
},
{
name: "limit higher than count",
project: "project-a",
limit: 50,
wantCount: 15,
},
{
name: "different project (no results)",
project: "project-b",
limit: 10,
wantCount: 0,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
observations, err := s.obsStore.GetRecentObservations(ctx, tt.project, tt.limit)
s.NoError(err)
s.Len(observations, tt.wantCount)
})
}
}
// TestDeleteObservations_TableDriven tests observation deletion.
func (s *ObservationStoreSuite) TestDeleteObservations_TableDriven() {
ctx := context.Background()
// Create observations
var ids []int64
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "To delete " + string(rune('A'+i)),
}
id, _, err := s.obsStore.StoreObservation(ctx, "session-del", "project-del", obs, i, 10)
s.NoError(err)
ids = append(ids, id)
}
tests := []struct {
name string
toDelete []int64
wantDeleted int64
wantRemain int
}{
{
name: "delete none",
toDelete: []int64{},
wantDeleted: 0,
wantRemain: 5,
},
{
name: "delete one",
toDelete: ids[0:1],
wantDeleted: 1,
wantRemain: 4,
},
{
name: "delete remaining",
toDelete: ids[1:],
wantDeleted: 4,
wantRemain: 0,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
deleted, err := s.obsStore.DeleteObservations(ctx, tt.toDelete)
s.NoError(err)
s.Equal(tt.wantDeleted, deleted)
remaining, err := s.obsStore.GetAllRecentObservations(ctx, 100)
s.NoError(err)
s.Len(remaining, tt.wantRemain)
})
}
}
// TestGetObservationsByIDs tests retrieval by multiple IDs.
func (s *ObservationStoreSuite) TestGetObservationsByIDs() {
ctx := context.Background()
// Create observations
var ids []int64
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "By ID " + string(rune('A'+i)),
}
id, _, err := s.obsStore.StoreObservation(ctx, "session-byid", "project-byid", obs, i, 10)
s.NoError(err)
ids = append(ids, id)
time.Sleep(time.Millisecond)
}
tests := []struct {
name string
queryIDs []int64
orderBy string
limit int
wantCount int
}{
{
name: "empty IDs",
queryIDs: []int64{},
orderBy: "date_desc",
limit: 10,
wantCount: 0,
},
{
name: "single ID",
queryIDs: ids[0:1],
orderBy: "date_desc",
limit: 10,
wantCount: 1,
},
{
name: "all IDs",
queryIDs: ids,
orderBy: "date_desc",
limit: 10,
wantCount: 5,
},
{
name: "with limit less than IDs",
queryIDs: ids,
orderBy: "date_desc",
limit: 3,
wantCount: 3,
},
{
name: "ascending order",
queryIDs: ids,
orderBy: "date_asc",
limit: 10,
wantCount: 5,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
observations, err := s.obsStore.GetObservationsByIDs(ctx, tt.queryIDs, tt.orderBy, tt.limit)
if tt.wantCount == 0 {
s.NoError(err)
s.Nil(observations)
} else {
s.NoError(err)
s.Len(observations, tt.wantCount)
}
})
}
}
// TestGlobalScope tests global vs project scope.
func (s *ObservationStoreSuite) TestGlobalScope() {
ctx := context.Background()
// Create project-scoped observation
projectObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project specific",
Concepts: []string{"project-specific"},
}
_, _, err := s.obsStore.StoreObservation(ctx, "session-scope", "project-a", projectObs, 1, 10)
s.NoError(err)
// Create global-scoped observation (security concept triggers global)
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Global security",
Concepts: []string{"security"},
}
_, _, err = s.obsStore.StoreObservation(ctx, "session-scope", "project-a", globalObs, 2, 10)
s.NoError(err)
// Project-a should see both
resultsA, err := s.obsStore.GetRecentObservations(ctx, "project-a", 10)
s.NoError(err)
s.Len(resultsA, 2)
// Project-b should only see global
resultsB, err := s.obsStore.GetRecentObservations(ctx, "project-b", 10)
s.NoError(err)
s.Len(resultsB, 1)
s.Equal("Global security", resultsB[0].Title.String)
s.Equal(models.ScopeGlobal, resultsB[0].Scope)
}
// TestSetCleanupFunc tests the cleanup function callback.
func (s *ObservationStoreSuite) TestSetCleanupFunc() {
ctx := context.Background()
var calledWith []int64
s.obsStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
calledWith = deletedIDs
})
// Store an observation
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test cleanup",
}
_, _, err := s.obsStore.StoreObservation(ctx, "session-cleanup", "project-cleanup", obs, 1, 10)
s.NoError(err)
// Cleanup should not have been called since nothing was deleted
s.Empty(calledWith)
}
// TestGetObservationCount tests observation counting.
func (s *ObservationStoreSuite) TestGetObservationCount() {
ctx := context.Background()
// Create observations for project-a
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
}
_, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", obs, i, 10)
s.NoError(err)
}
// Create global observation
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Concepts: []string{"security"},
}
_, _, err := s.obsStore.StoreObservation(ctx, "session-count", "project-a", globalObs, 6, 10)
s.NoError(err)
// Project-a should count 6 (5 project + 1 global)
count, err := s.obsStore.GetObservationCount(ctx, "project-a")
s.NoError(err)
s.Equal(6, count)
// Project-b should count 1 (only global)
count, err = s.obsStore.GetObservationCount(ctx, "project-b")
s.NoError(err)
s.Equal(1, count)
}
func TestObservationStore_StoreAndRetrieve(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
+93
View File
@@ -194,3 +194,96 @@ func TestPromptStore_GetPromptsByIDs_EmptyInput(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, prompts)
}
func TestPromptStore_FindRecentPromptByText(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save a prompt
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Help me fix this bug in the code", 0)
require.NoError(t, err)
// Find the prompt by text - returns (id, promptNumber, found)
id, promptNum, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Help me fix this bug in the code", 60)
assert.True(t, found, "should find the exact prompt text")
assert.Greater(t, id, int64(0))
assert.Equal(t, 1, promptNum)
// Try to find non-existent prompt
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "This prompt does not exist", 60)
assert.False(t, found, "should not find non-existent prompt")
// Try with different session
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-2", "Help me fix this bug in the code", 60)
assert.False(t, found, "should not find prompt for different session")
}
func TestPromptStore_FindRecentPromptByText_WindowSeconds(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save a prompt with an old timestamp
oldEpoch := time.Now().Add(-2 * time.Hour).UnixMilli()
_, err := storeDB(store).Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, created_at, created_at_epoch)
VALUES (?, ?, ?, datetime('now'), ?)
`, "claude-1", 1, "Old prompt text", oldEpoch)
require.NoError(t, err)
// Search within last hour - should not find old prompt
_, _, found := promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3600)
assert.False(t, found, "should not find prompt outside window")
// Search within last 3 hours - should find old prompt
_, _, found = promptStore.FindRecentPromptByText(ctx, "claude-1", "Old prompt text", 3*3600)
assert.True(t, found, "should find prompt within extended window")
}
func TestPromptStore_SaveMultiplePrompts(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create sessions
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-x")
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-y")
tests := []struct {
claudeSessionID string
promptNum int
text string
matches int
}{
{"claude-1", 1, "First prompt", 5},
{"claude-1", 2, "Second prompt", 3},
{"claude-2", 1, "Third prompt", 0},
{"claude-1", 3, "Fourth prompt", 10},
}
for _, tt := range tests {
id, err := promptStore.SaveUserPromptWithMatches(ctx, tt.claudeSessionID, tt.promptNum, tt.text, tt.matches)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
}
// Verify counts
var count int
err := storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-1'").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 3, count)
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE claude_session_id = 'claude-2'").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
+232 -1
View File
@@ -8,13 +8,14 @@ import (
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createAllTables(t, db)
createBaseTables(t, db) // Use base tables without FTS5 for session tests
store := newStoreFromDB(db)
sessionStore := NewSessionStore(store)
@@ -22,6 +23,236 @@ func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) {
return sessionStore, store, cleanup
}
// SessionStoreSuite is a test suite for SessionStore operations.
type SessionStoreSuite struct {
suite.Suite
sessionStore *SessionStore
store *Store
cleanup func()
}
func (s *SessionStoreSuite) SetupTest() {
s.sessionStore, s.store, s.cleanup = testSessionStore(s.T())
}
func (s *SessionStoreSuite) TearDownTest() {
if s.cleanup != nil {
s.cleanup()
}
}
func TestSessionStoreSuite(t *testing.T) {
suite.Run(t, new(SessionStoreSuite))
}
// TestCreateSDKSession_TableDriven tests session creation with various scenarios.
func (s *SessionStoreSuite) TestCreateSDKSession_TableDriven() {
ctx := context.Background()
tests := []struct {
name string
claudeSessionID string
project string
userPrompt string
wantErr bool
wantID bool
}{
{
name: "basic session creation",
claudeSessionID: "claude-basic",
project: "project-a",
userPrompt: "hello world",
wantErr: false,
wantID: true,
},
{
name: "empty user prompt",
claudeSessionID: "claude-noprompt",
project: "project-b",
userPrompt: "",
wantErr: false,
wantID: true,
},
{
name: "long project name",
claudeSessionID: "claude-longproj",
project: "/Users/test/Documents/very/long/path/to/some/project/directory",
userPrompt: "test",
wantErr: false,
wantID: true,
},
{
name: "unicode project name",
claudeSessionID: "claude-unicode",
project: "项目名称-プロジェクト",
userPrompt: "测试 テスト",
wantErr: false,
wantID: true,
},
{
name: "special characters in prompt",
claudeSessionID: "claude-special",
project: "project-special",
userPrompt: "Fix the bug in file.go:123 with \"quotes\" and 'apostrophes'",
wantErr: false,
wantID: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
id, err := s.sessionStore.CreateSDKSession(ctx, tt.claudeSessionID, tt.project, tt.userPrompt)
if tt.wantErr {
s.Error(err)
} else {
s.NoError(err)
if tt.wantID {
s.Greater(id, int64(0))
}
// Verify created session
sess, err := s.sessionStore.GetSessionByID(ctx, id)
s.NoError(err)
s.NotNil(sess)
s.Equal(tt.claudeSessionID, sess.ClaudeSessionID)
s.Equal(tt.project, sess.Project)
s.Equal(models.SessionStatusActive, sess.Status)
}
})
}
}
// TestIdempotentSession tests that session creation is idempotent.
func (s *SessionStoreSuite) TestIdempotentSession() {
ctx := context.Background()
// Create initial session
id1, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-1", "prompt-1")
s.NoError(err)
s.Greater(id1, int64(0))
// Create with same claude_session_id - should return same ID
id2, err := s.sessionStore.CreateSDKSession(ctx, "claude-idem", "project-2", "prompt-2")
s.NoError(err)
s.Equal(id1, id2)
// Verify project was updated
sess, err := s.sessionStore.GetSessionByID(ctx, id1)
s.NoError(err)
s.Equal("project-2", sess.Project)
}
// TestPromptCounterOperations tests prompt counter increment and retrieval.
func (s *SessionStoreSuite) TestPromptCounterOperations() {
ctx := context.Background()
tests := []struct {
name string
increments int
expectedCount int
}{
{
name: "no increments",
increments: 0,
expectedCount: 0,
},
{
name: "single increment",
increments: 1,
expectedCount: 1,
},
{
name: "multiple increments",
increments: 5,
expectedCount: 5,
},
{
name: "many increments",
increments: 100,
expectedCount: 100,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// Create fresh session for each test
id, err := s.sessionStore.CreateSDKSession(ctx, "claude-counter-"+tt.name, "project", "")
s.NoError(err)
// Increment specified number of times
var lastCount int
for i := 0; i < tt.increments; i++ {
lastCount, err = s.sessionStore.IncrementPromptCounter(ctx, id)
s.NoError(err)
}
// Get final count
finalCount, err := s.sessionStore.GetPromptCounter(ctx, id)
s.NoError(err)
s.Equal(tt.expectedCount, finalCount)
if tt.increments > 0 {
s.Equal(tt.expectedCount, lastCount)
}
})
}
}
// TestFindAnySDKSession tests session lookup scenarios.
func (s *SessionStoreSuite) TestFindAnySDKSession_Scenarios() {
ctx := context.Background()
// Create test sessions
_, err := s.sessionStore.CreateSDKSession(ctx, "session-find-1", "project-a", "")
s.NoError(err)
_, err = s.sessionStore.CreateSDKSession(ctx, "session-find-2", "project-b", "")
s.NoError(err)
tests := []struct {
name string
claudeSessionID string
wantFound bool
wantProject string
}{
{
name: "find existing session 1",
claudeSessionID: "session-find-1",
wantFound: true,
wantProject: "project-a",
},
{
name: "find existing session 2",
claudeSessionID: "session-find-2",
wantFound: true,
wantProject: "project-b",
},
{
name: "find non-existent session",
claudeSessionID: "session-nonexistent",
wantFound: false,
},
{
name: "find with empty string",
claudeSessionID: "",
wantFound: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
sess, err := s.sessionStore.FindAnySDKSession(ctx, tt.claudeSessionID)
s.NoError(err) // FindAnySDKSession returns nil,nil for not found
if tt.wantFound {
s.NotNil(sess)
s.Equal(tt.wantProject, sess.Project)
} else {
s.Nil(sess)
}
})
}
}
func TestSessionStore_CreateSDKSession(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
+529
View File
@@ -0,0 +1,529 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// StoreSuite is a test suite for Store operations.
type StoreSuite struct {
suite.Suite
db *sql.DB
store *Store
cleanup func()
}
// SetupTest creates a fresh database before each test.
func (s *StoreSuite) SetupTest() {
s.db, _, s.cleanup = testDB(s.T())
createBaseTables(s.T(), s.db)
s.store = newStoreFromDB(s.db)
}
// TearDownTest cleans up after each test.
func (s *StoreSuite) TearDownTest() {
if s.cleanup != nil {
s.cleanup()
}
}
func TestStoreSuite(t *testing.T) {
suite.Run(t, new(StoreSuite))
}
// TestGetStmt tests prepared statement caching.
func (s *StoreSuite) TestGetStmt() {
tests := []struct {
name string
query string
wantErr bool
}{
{
name: "valid simple query",
query: "SELECT 1",
wantErr: false,
},
{
name: "valid query with parameter",
query: "SELECT * FROM sdk_sessions WHERE id = ?",
wantErr: false,
},
{
name: "invalid query syntax",
query: "SELECT * FROM nonexistent_table WHERE",
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
stmt, err := s.store.GetStmt(tt.query)
if tt.wantErr {
s.Error(err)
s.Nil(stmt)
} else {
s.NoError(err)
s.NotNil(stmt)
// Second call should return cached statement
stmt2, err := s.store.GetStmt(tt.query)
s.NoError(err)
s.Same(stmt, stmt2)
}
})
}
}
// TestExecContext tests query execution.
func (s *StoreSuite) TestExecContext() {
ctx := context.Background()
tests := []struct {
name string
query string
args []interface{}
wantErr bool
wantAffected int64
}{
{
name: "insert session",
query: `INSERT INTO sdk_sessions (claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
VALUES (?, ?, ?, datetime('now'), strftime('%s', 'now') * 1000, 'active')`,
args: []interface{}{"claude-1", "sdk-1", "test-project"},
wantErr: false,
wantAffected: 1,
},
{
name: "invalid query",
query: "INSERT INTO nonexistent_table VALUES (?)",
args: []interface{}{"test"},
wantErr: true,
wantAffected: 0,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result, err := s.store.ExecContext(ctx, tt.query, tt.args...)
if tt.wantErr {
s.Error(err)
} else {
s.NoError(err)
affected, _ := result.RowsAffected()
s.Equal(tt.wantAffected, affected)
}
})
}
}
// TestQueryContext tests query execution that returns rows.
func (s *StoreSuite) TestQueryContext() {
ctx := context.Background()
// Seed data
seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a")
tests := []struct {
name string
query string
args []interface{}
wantErr bool
wantRows int
setupFunc func()
assertFunc func(rows *sql.Rows)
}{
{
name: "query existing session",
query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?",
args: []interface{}{"claude-1"},
wantErr: false,
wantRows: 1,
},
{
name: "query non-existent session",
query: "SELECT id, project FROM sdk_sessions WHERE claude_session_id = ?",
args: []interface{}{"nonexistent"},
wantErr: false,
wantRows: 0,
},
{
name: "query all sessions",
query: "SELECT id, project FROM sdk_sessions",
args: nil,
wantErr: false,
wantRows: 1,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rows, err := s.store.QueryContext(ctx, tt.query, tt.args...)
if tt.wantErr {
s.Error(err)
return
}
s.NoError(err)
defer rows.Close()
count := 0
for rows.Next() {
count++
}
s.Equal(tt.wantRows, count)
})
}
}
// TestQueryRowContext tests single row query execution.
func (s *StoreSuite) TestQueryRowContext() {
ctx := context.Background()
// Seed data
seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a")
tests := []struct {
name string
query string
args []interface{}
wantErr bool
}{
{
name: "query existing session",
query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?",
args: []interface{}{"claude-1"},
wantErr: false,
},
{
name: "query non-existent session",
query: "SELECT id FROM sdk_sessions WHERE claude_session_id = ?",
args: []interface{}{"nonexistent"},
wantErr: true, // sql.ErrNoRows
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
row := s.store.QueryRowContext(ctx, tt.query, tt.args...)
var id int64
err := row.Scan(&id)
if tt.wantErr {
s.Error(err)
} else {
s.NoError(err)
s.Greater(id, int64(0))
}
})
}
}
// TestPing tests database connection health check.
func (s *StoreSuite) TestPing() {
err := s.store.Ping()
s.NoError(err)
}
// TestDB tests getting the underlying database connection.
func (s *StoreSuite) TestDB() {
db := s.store.DB()
s.NotNil(db)
s.Same(s.db, db)
}
// TestClose tests closing the store.
func (s *StoreSuite) TestClose() {
// Create a separate store for close test
db, _, cleanup := testDB(s.T())
defer cleanup()
store := newStoreFromDB(db)
// Cache a statement first
_, err := store.GetStmt("SELECT 1")
s.NoError(err)
// Close should not error
err = store.Close()
s.NoError(err)
// Operations after close should fail
err = store.Ping()
s.Error(err)
}
// TestConcurrentStmtCache tests concurrent access to statement cache.
func (s *StoreSuite) TestConcurrentStmtCache() {
ctx := context.Background()
queries := []string{
"SELECT 1",
"SELECT 2",
"SELECT id FROM sdk_sessions",
"SELECT project FROM sdk_sessions",
}
done := make(chan struct{})
for i := 0; i < 10; i++ {
go func(i int) {
query := queries[i%len(queries)]
_, _ = s.store.GetStmt(query)
_, _ = s.store.ExecContext(ctx, "SELECT 1")
done <- struct{}{}
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
}
// HelpersSuite tests helper functions.
type HelpersSuite struct {
suite.Suite
}
func TestHelpersSuite(t *testing.T) {
suite.Run(t, new(HelpersSuite))
}
func (s *HelpersSuite) TestNullString() {
tests := []struct {
name string
input string
wantStr string
wantBool bool
}{
{
name: "empty string",
input: "",
wantStr: "",
wantBool: false,
},
{
name: "non-empty string",
input: "test",
wantStr: "test",
wantBool: true,
},
{
name: "whitespace string",
input: " ",
wantStr: " ",
wantBool: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := nullString(tt.input)
s.Equal(tt.wantStr, result.String)
s.Equal(tt.wantBool, result.Valid)
})
}
}
func (s *HelpersSuite) TestNullInt() {
tests := []struct {
name string
input int
wantInt int64
wantBool bool
}{
{
name: "zero",
input: 0,
wantInt: 0,
wantBool: false,
},
{
name: "negative",
input: -1,
wantInt: -1,
wantBool: false,
},
{
name: "positive",
input: 42,
wantInt: 42,
wantBool: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := nullInt(tt.input)
s.Equal(tt.wantInt, result.Int64)
s.Equal(tt.wantBool, result.Valid)
})
}
}
func (s *HelpersSuite) TestRepeatPlaceholders() {
tests := []struct {
name string
input int
expected string
}{
{
name: "zero",
input: 0,
expected: "",
},
{
name: "negative",
input: -1,
expected: "",
},
{
name: "one",
input: 1,
expected: ", ?",
},
{
name: "three",
input: 3,
expected: ", ?, ?, ?",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := repeatPlaceholders(tt.input)
s.Equal(tt.expected, result)
})
}
}
func (s *HelpersSuite) TestInt64SliceToInterface() {
tests := []struct {
name string
input []int64
expected int
}{
{
name: "empty slice",
input: []int64{},
expected: 0,
},
{
name: "single element",
input: []int64{42},
expected: 1,
},
{
name: "multiple elements",
input: []int64{1, 2, 3, 4, 5},
expected: 5,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := int64SliceToInterface(tt.input)
s.Len(result, tt.expected)
for i, v := range result {
s.Equal(tt.input[i], v)
}
})
}
}
// TestBuildGetByIDsQuery tests the shared query builder.
func TestBuildGetByIDsQuery(t *testing.T) {
tests := []struct {
name string
baseQuery string
ids []int64
orderBy string
limit int
wantQuery string
wantArgs int
}{
{
name: "single id, no limit, desc order",
baseQuery: "SELECT * FROM test",
ids: []int64{1},
orderBy: "date_desc",
limit: 0,
wantQuery: "SELECT * FROM test WHERE id IN (?)\n\t\tORDER BY created_at_epoch DESC",
wantArgs: 1,
},
{
name: "multiple ids with limit and asc order",
baseQuery: "SELECT * FROM test",
ids: []int64{1, 2, 3},
orderBy: "date_asc",
limit: 10,
wantQuery: "SELECT * FROM test WHERE id IN (?, ?, ?)\n\t\tORDER BY created_at_epoch ASC LIMIT ?",
wantArgs: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, args := BuildGetByIDsQuery(tt.baseQuery, tt.ids, tt.orderBy, tt.limit)
assert.Contains(t, query, "WHERE id IN")
assert.Len(t, args, tt.wantArgs)
})
}
}
// TestEnsureSessionExists tests session auto-creation.
func TestEnsureSessionExists(t *testing.T) {
db, _, cleanup := testDB(t)
defer cleanup()
createBaseTables(t, db)
store := newStoreFromDB(db)
ctx := context.Background()
tests := []struct {
name string
sdkSessionID string
project string
setup func()
wantErr bool
}{
{
name: "create new session",
sdkSessionID: "sdk-new",
project: "project-a",
wantErr: false,
},
{
name: "session already exists",
sdkSessionID: "sdk-existing",
project: "project-b",
setup: func() {
seedSession(t, db, "sdk-existing", "sdk-existing", "project-b")
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
err := EnsureSessionExists(ctx, store, tt.sdkSessionID, tt.project)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
// Verify session exists
var id int64
err := db.QueryRow("SELECT id FROM sdk_sessions WHERE sdk_session_id = ?", tt.sdkSessionID).Scan(&id)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
}
})
}
}