Increase tests coverage.

This commit is contained in:
2025-12-17 11:40:08 +00:00
parent 3b042263ca
commit 95a1dff901
15 changed files with 6421 additions and 6 deletions
+520
View File
@@ -3,9 +3,11 @@ package hooks
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -190,3 +192,521 @@ func TestFindWorkerBinary(t *testing.T) {
// Result depends on whether worker is installed, so we just check it doesn't panic
t.Logf("findWorkerBinary returned: %s", result)
}
// TestVersionsCompatible tests the versionsCompatible function.
func TestVersionsCompatible(t *testing.T) {
tests := []struct {
name string
v1 string
v2 string
expected bool
}{
{
name: "identical versions",
v1: "v1.0.0",
v2: "v1.0.0",
expected: true,
},
{
name: "same base different suffix",
v1: "v1.0.0",
v2: "v1.0.0-dirty",
expected: true,
},
{
name: "same base with commit hash",
v1: "v1.0.0-2-gca711a8",
v2: "v1.0.0-5-gabcdef1-dirty",
expected: true,
},
{
name: "different base versions",
v1: "v1.0.0",
v2: "v2.0.0",
expected: false,
},
{
name: "dev version compatible with anything",
v1: "dev",
v2: "v1.0.0",
expected: true,
},
{
name: "anything compatible with dev",
v1: "v2.0.0-dirty",
v2: "dev",
expected: true,
},
{
name: "both dev versions",
v1: "dev",
v2: "dev",
expected: true,
},
{
name: "minor version difference",
v1: "v1.2.0",
v2: "v1.3.0",
expected: false,
},
{
name: "patch version difference",
v1: "v1.0.1",
v2: "v1.0.2",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := versionsCompatible(tt.v1, tt.v2)
assert.Equal(t, tt.expected, result)
})
}
}
// TestExtractBaseVersion tests the extractBaseVersion function.
func TestExtractBaseVersion(t *testing.T) {
tests := []struct {
name string
version string
expected string
}{
{
name: "simple version with v prefix",
version: "v1.0.0",
expected: "1.0.0",
},
{
name: "version without v prefix",
version: "1.0.0",
expected: "1.0.0",
},
{
name: "version with commit suffix",
version: "v0.3.5-2-gca711a8",
expected: "0.3.5",
},
{
name: "version with dirty suffix",
version: "v0.3.5-dirty",
expected: "0.3.5",
},
{
name: "version with full suffix",
version: "v0.3.5-2-gca711a8-dirty",
expected: "0.3.5",
},
{
name: "dev version",
version: "dev",
expected: "dev",
},
{
name: "empty version",
version: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractBaseVersion(tt.version)
assert.Equal(t, tt.expected, result)
})
}
}
// TestPOST tests the POST function with a mock server.
func TestPOST(t *testing.T) {
tests := []struct {
name string
serverHandler func(w http.ResponseWriter, r *http.Request)
body interface{}
expectError bool
expectedResult map[string]interface{}
}{
{
name: "successful POST with JSON response",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{"status": "ok"})
},
body: map[string]string{"key": "value"},
expectError: false,
expectedResult: map[string]interface{}{"status": "ok"},
},
{
name: "POST with 400 error",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
},
body: map[string]string{"key": "value"},
expectError: true,
},
{
name: "POST with 500 error",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
},
body: map[string]string{"key": "value"},
expectError: true,
},
{
name: "POST with non-JSON response",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("not json"))
},
body: map[string]string{"key": "value"},
expectError: false,
expectedResult: nil, // Non-JSON returns nil
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
defer server.Close()
// Extract port from test server
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
result, err := POST(port, "/test", tt.body)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.expectedResult != nil {
assert.Equal(t, tt.expectedResult["status"], result["status"])
}
}
})
}
}
// TestGET tests the GET function with a mock server.
func TestGET(t *testing.T) {
tests := []struct {
name string
serverHandler func(w http.ResponseWriter, r *http.Request)
expectError bool
expectedResult map[string]interface{}
}{
{
name: "successful GET with JSON response",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{"data": "test"})
},
expectError: false,
expectedResult: map[string]interface{}{"data": "test"},
},
{
name: "GET with 404 error",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
},
expectError: true,
},
{
name: "GET with invalid JSON",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("not valid json"))
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
defer server.Close()
// Extract port from test server
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
result, err := GET(port, "/test")
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.expectedResult != nil {
assert.Equal(t, tt.expectedResult["data"], result["data"])
}
}
})
}
}
// TestProjectIDWithName_Comprehensive tests ProjectIDWithName more thoroughly.
func TestProjectIDWithName_Comprehensive(t *testing.T) {
tests := []struct {
name string
cwd string
expectedPrefix string
expectedLen int // Expected minimum length (prefix + _ + 6 char hash)
}{
{
name: "standard project path",
cwd: "/Users/test/projects/my-project",
expectedPrefix: "my-project_",
expectedLen: 17, // "my-project_" + 6 char hash
},
{
name: "short directory name",
cwd: "/tmp",
expectedPrefix: "tmp_",
expectedLen: 10, // "tmp_" + 6 char hash
},
{
name: "nested path",
cwd: "/home/user/code/org/repo",
expectedPrefix: "repo_",
expectedLen: 11, // "repo_" + 6 char hash
},
{
name: "path with special characters",
cwd: "/Users/test/my-special.project",
expectedPrefix: "my-special.project_",
expectedLen: 25,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ProjectIDWithName(tt.cwd)
assert.True(t, len(result) >= tt.expectedLen, "result %s should be at least %d chars", result, tt.expectedLen)
assert.Contains(t, result, tt.expectedPrefix[:len(tt.expectedPrefix)-1]) // Check without trailing underscore
assert.Contains(t, result, "_")
// Verify hash uniqueness - same path should give same result
result2 := ProjectIDWithName(tt.cwd)
assert.Equal(t, result, result2)
})
}
}
// TestProjectIDWithName_Uniqueness tests that different paths produce different IDs.
func TestProjectIDWithName_Uniqueness(t *testing.T) {
paths := []string{
"/Users/test/project-a",
"/Users/test/project-b",
"/Users/other/project-a",
"/tmp/project-a",
}
ids := make(map[string]bool)
for _, path := range paths {
id := ProjectIDWithName(path)
assert.False(t, ids[id], "duplicate ID generated for path %s", path)
ids[id] = true
}
}
// TestHookConstants tests hook-related constants.
func TestHookConstants(t *testing.T) {
assert.Equal(t, 37777, DefaultWorkerPort)
assert.Equal(t, 1*time.Second, HealthCheckTimeout)
assert.Equal(t, 30*time.Second, StartupTimeout)
}
// TestExitCodes tests exit code constants.
func TestExitCodes(t *testing.T) {
assert.Equal(t, 0, ExitSuccess)
assert.Equal(t, 1, ExitFailure)
assert.Equal(t, 3, ExitUserMessageOnly)
}
// TestHookResponse tests HookResponse struct.
func TestHookResponse(t *testing.T) {
tests := []struct {
name string
response HookResponse
expected string
}{
{
name: "continue true",
response: HookResponse{Continue: true},
expected: `{"continue":true}`,
},
{
name: "continue false",
response: HookResponse{Continue: false},
expected: `{"continue":false}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
require.NoError(t, err)
assert.JSONEq(t, tt.expected, string(data))
})
}
}
// TestBaseInput tests BaseInput struct parsing.
func TestBaseInput(t *testing.T) {
input := `{
"session_id": "test-session-123",
"cwd": "/Users/test/project",
"permission_mode": "standard",
"hook_event_name": "session-start"
}`
var base BaseInput
err := json.Unmarshal([]byte(input), &base)
require.NoError(t, err)
assert.Equal(t, "test-session-123", base.SessionID)
assert.Equal(t, "/Users/test/project", base.CWD)
assert.Equal(t, "standard", base.PermissionMode)
assert.Equal(t, "session-start", base.HookEventName)
}
// TestHookContext tests HookContext struct.
func TestHookContext(t *testing.T) {
ctx := &HookContext{
HookName: "session-start",
Port: 37777,
Project: "my-project_abc123",
SessionID: "test-session",
CWD: "/Users/test/project",
RawInput: []byte(`{"key":"value"}`),
}
assert.Equal(t, "session-start", ctx.HookName)
assert.Equal(t, 37777, ctx.Port)
assert.Equal(t, "my-project_abc123", ctx.Project)
assert.Equal(t, "test-session", ctx.SessionID)
assert.Equal(t, "/Users/test/project", ctx.CWD)
assert.Equal(t, []byte(`{"key":"value"}`), ctx.RawInput)
}
// TestIsWorkerRunning_WithServer tests IsWorkerRunning with actual server.
func TestIsWorkerRunning_WithServer(t *testing.T) {
tests := []struct {
name string
serverHandler func(w http.ResponseWriter, r *http.Request)
expectedResult bool
}{
{
name: "healthy worker returns true",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/health" {
w.WriteHeader(http.StatusOK)
}
},
expectedResult: true,
},
{
name: "unhealthy worker returns false",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/health" {
w.WriteHeader(http.StatusServiceUnavailable)
}
},
expectedResult: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
defer server.Close()
// Extract port - note: test server binds to 127.0.0.1
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
// The function uses hardcoded 127.0.0.1, which matches httptest
result := IsWorkerRunning(port)
assert.Equal(t, tt.expectedResult, result)
})
}
}
// TestIsPortInUse_WithServer tests IsPortInUse with actual server.
func TestIsPortInUse_WithServer(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Extract port
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
// Port should be in use
assert.True(t, IsPortInUse(port))
}
// TestGetWorkerVersion_WithServer tests GetWorkerVersion with actual server.
func TestGetWorkerVersion_WithServer(t *testing.T) {
tests := []struct {
name string
serverHandler func(w http.ResponseWriter, r *http.Request)
expectedResult string
}{
{
name: "returns version from server",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/version" {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"version": "v1.2.3"})
}
},
expectedResult: "v1.2.3",
},
{
name: "returns empty on non-200",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
},
expectedResult: "",
},
{
name: "returns empty on invalid JSON",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("not json"))
},
expectedResult: "",
},
{
name: "returns empty on missing version field",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"other": "field"})
},
expectedResult: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(tt.serverHandler))
defer server.Close()
var port int
_, err := fmt.Sscanf(server.URL, "http://127.0.0.1:%d", &port)
require.NoError(t, err)
result := GetWorkerVersion(port)
assert.Equal(t, tt.expectedResult, result)
})
}
}
+424
View File
@@ -0,0 +1,424 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"database/sql"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ObservationSuite is a test suite for Observation operations.
type ObservationSuite struct {
suite.Suite
}
func TestObservationSuite(t *testing.T) {
suite.Run(t, new(ObservationSuite))
}
// TestObservationTypeConstants tests observation type constants.
func (s *ObservationSuite) TestObservationTypeConstants() {
s.Equal(ObservationType("discovery"), ObsTypeDiscovery)
s.Equal(ObservationType("decision"), ObsTypeDecision)
s.Equal(ObservationType("bugfix"), ObsTypeBugfix)
s.Equal(ObservationType("feature"), ObsTypeFeature)
s.Equal(ObservationType("refactor"), ObsTypeRefactor)
s.Equal(ObservationType("change"), ObsTypeChange)
}
// TestScopeConstants tests scope constants.
func (s *ObservationSuite) TestScopeConstants() {
s.Equal(ObservationScope("project"), ScopeProject)
s.Equal(ObservationScope("global"), ScopeGlobal)
}
// TestGlobalizableConcepts tests that globalizable concepts are defined.
func (s *ObservationSuite) TestGlobalizableConcepts() {
expected := []string{
"best-practice", "pattern", "anti-pattern", "architecture",
"security", "performance", "testing",
"debugging", "workflow", "tooling",
}
s.Equal(expected, GlobalizableConcepts)
}
// TestDetermineScope_TableDriven tests scope determination with various concepts.
func (s *ObservationSuite) TestDetermineScope_TableDriven() {
tests := []struct {
name string
concepts []string
expected ObservationScope
}{
{
name: "empty concepts - project scope",
concepts: []string{},
expected: ScopeProject,
},
{
name: "no globalizable concepts - project scope",
concepts: []string{"how-it-works", "custom-tag"},
expected: ScopeProject,
},
{
name: "security concept - global scope",
concepts: []string{"security"},
expected: ScopeGlobal,
},
{
name: "best-practice concept - global scope",
concepts: []string{"best-practice"},
expected: ScopeGlobal,
},
{
name: "mixed concepts with globalizable - global scope",
concepts: []string{"how-it-works", "security"},
expected: ScopeGlobal,
},
{
name: "performance concept - global scope",
concepts: []string{"performance"},
expected: ScopeGlobal,
},
{
name: "testing concept - global scope",
concepts: []string{"testing"},
expected: ScopeGlobal,
},
{
name: "pattern concept - global scope",
concepts: []string{"pattern"},
expected: ScopeGlobal,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
result := DetermineScope(tt.concepts)
s.Equal(tt.expected, result)
})
}
}
// TestParsedObservation_FileMtimesJSON tests FileMtimes JSON serialization.
func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() {
obs := &ParsedObservation{
Type: ObsTypeDiscovery,
Title: "Test",
FileMtimes: map[string]int64{"file1.go": 1234567890, "file2.go": 1234567891},
}
// Verify mtimes can be marshaled
data, err := json.Marshal(obs.FileMtimes)
s.NoError(err)
s.Contains(string(data), "file1.go")
s.Contains(string(data), "1234567890")
}
// TestObservation_CheckStaleness_TableDriven tests staleness checking.
func (s *ObservationSuite) TestObservation_CheckStaleness_TableDriven() {
tests := []struct {
name string
storedMtimes map[string]int64
currentMtimes map[string]int64
expectedStale bool
}{
{
name: "empty stored mtimes - not stale",
storedMtimes: map[string]int64{},
currentMtimes: map[string]int64{"file.go": 1000},
expectedStale: false,
},
{
name: "matching mtimes - not stale",
storedMtimes: map[string]int64{"file.go": 1000},
currentMtimes: map[string]int64{"file.go": 1000},
expectedStale: false,
},
{
name: "file modified - stale",
storedMtimes: map[string]int64{"file.go": 1000},
currentMtimes: map[string]int64{"file.go": 2000},
expectedStale: true,
},
{
name: "file missing from current - not stale (files might not be checked)",
storedMtimes: map[string]int64{"file.go": 1000},
currentMtimes: map[string]int64{},
expectedStale: false, // Missing files don't mark as stale per the implementation
},
{
name: "multiple files, one modified - stale",
storedMtimes: map[string]int64{"file1.go": 1000, "file2.go": 2000},
currentMtimes: map[string]int64{"file1.go": 1000, "file2.go": 3000},
expectedStale: true,
},
{
name: "nil current mtimes - not stale",
storedMtimes: map[string]int64{"file.go": 1000},
currentMtimes: nil,
expectedStale: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
obs := &Observation{
FileMtimes: tt.storedMtimes,
}
result := obs.CheckStaleness(tt.currentMtimes)
s.Equal(tt.expectedStale, result)
})
}
}
// TestObservation_MarshalJSON tests JSON marshaling of Observation.
func (s *ObservationSuite) TestObservation_MarshalJSON() {
obs := &Observation{
ID: 1,
Project: "test-project",
Type: ObsTypeDiscovery,
Title: sql.NullString{String: "Test Title", Valid: true},
Scope: ScopeProject,
}
data, err := json.Marshal(obs)
s.NoError(err)
s.Contains(string(data), `"id":1`)
s.Contains(string(data), `"project":"test-project"`)
s.Contains(string(data), `"type":"discovery"`)
}
// TestParsedObservation_Fields tests ParsedObservation field access.
func (s *ObservationSuite) TestParsedObservation_Fields() {
obs := &ParsedObservation{
Type: ObsTypeFeature,
Title: "Add authentication",
Subtitle: "JWT-based auth",
Narrative: "Implemented JWT authentication for API endpoints",
Facts: []string{"Uses RS256 algorithm", "Tokens expire in 24h"},
Concepts: []string{"security", "auth"},
FilesRead: []string{"config.go"},
FilesModified: []string{"handler.go", "middleware.go"},
FileMtimes: map[string]int64{"handler.go": 1234567890},
}
s.Equal(ObsTypeFeature, obs.Type)
s.Equal("Add authentication", obs.Title)
s.Equal("JWT-based auth", obs.Subtitle)
s.Contains(obs.Narrative, "JWT")
s.Len(obs.Facts, 2)
s.Len(obs.Concepts, 2)
s.Len(obs.FilesRead, 1)
s.Len(obs.FilesModified, 2)
s.Len(obs.FileMtimes, 1)
}
// TestObservation_NullFields tests handling of nullable fields.
func (s *ObservationSuite) TestObservation_NullFields() {
// Test with null fields
obs := &Observation{
ID: 1,
Project: "test",
Type: ObsTypeDiscovery,
Title: sql.NullString{Valid: false},
Subtitle: sql.NullString{Valid: false},
Narrative: sql.NullString{Valid: false},
}
s.False(obs.Title.Valid)
s.False(obs.Subtitle.Valid)
s.False(obs.Narrative.Valid)
// Test with valid fields
obs2 := &Observation{
ID: 2,
Project: "test",
Type: ObsTypeBugfix,
Title: sql.NullString{String: "Fix bug", Valid: true},
Subtitle: sql.NullString{String: "Memory leak", Valid: true},
Narrative: sql.NullString{String: "Fixed memory leak in handler", Valid: true},
}
s.True(obs2.Title.Valid)
s.Equal("Fix bug", obs2.Title.String)
s.True(obs2.Subtitle.Valid)
s.Equal("Memory leak", obs2.Subtitle.String)
}
// TestNewObservation tests observation creation from parsed data.
func TestNewObservation(t *testing.T) {
parsed := &ParsedObservation{
Type: ObsTypeFeature,
Title: "Add authentication",
Subtitle: "JWT-based",
Narrative: "Implemented JWT auth",
Facts: []string{"Uses RS256"},
Concepts: []string{"security"},
FilesRead: []string{"config.go"},
FilesModified: []string{"handler.go"},
FileMtimes: map[string]int64{"handler.go": 1234567890},
}
obs := NewObservation("sdk-123", "test-project", parsed, 5, 1000)
assert.Equal(t, "sdk-123", obs.SDKSessionID)
assert.Equal(t, "test-project", obs.Project)
assert.Equal(t, ScopeGlobal, obs.Scope) // security triggers global
assert.Equal(t, ObsTypeFeature, obs.Type)
assert.Equal(t, "Add authentication", obs.Title.String)
assert.True(t, obs.Title.Valid)
assert.Equal(t, int64(5), obs.PromptNumber.Int64)
assert.Equal(t, int64(1000), obs.DiscoveryTokens)
assert.NotEmpty(t, obs.CreatedAt)
assert.Greater(t, obs.CreatedAtEpoch, int64(0))
}
// TestParsedObservation_ToStoredObservation tests conversion.
func TestParsedObservation_ToStoredObservation(t *testing.T) {
parsed := &ParsedObservation{
Type: ObsTypeDiscovery,
Title: "Test Title",
Subtitle: "Test Subtitle",
Narrative: "Test narrative",
Facts: []string{"Fact 1"},
Concepts: []string{"testing"},
}
obs := parsed.ToStoredObservation()
assert.Equal(t, ObsTypeDiscovery, obs.Type)
assert.Equal(t, "Test Title", obs.Title.String)
assert.True(t, obs.Title.Valid)
assert.Equal(t, "Test Subtitle", obs.Subtitle.String)
assert.True(t, obs.Subtitle.Valid)
}
// TestJSONStringArray tests JSONStringArray scanning.
func TestJSONStringArray(t *testing.T) {
tests := []struct {
name string
input interface{}
wantErr bool
expected JSONStringArray
}{
{
name: "nil input",
input: nil,
wantErr: false,
expected: nil,
},
{
name: "empty string",
input: "",
wantErr: false,
expected: nil,
},
{
name: "json array string",
input: `["item1", "item2"]`,
wantErr: false,
expected: JSONStringArray{"item1", "item2"},
},
{
name: "json array bytes",
input: []byte(`["a", "b", "c"]`),
wantErr: false,
expected: JSONStringArray{"a", "b", "c"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var arr JSONStringArray
err := arr.Scan(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, arr)
}
})
}
}
// TestJSONInt64Map tests JSONInt64Map scanning.
func TestJSONInt64Map(t *testing.T) {
tests := []struct {
name string
input interface{}
wantErr bool
expected JSONInt64Map
}{
{
name: "nil input",
input: nil,
wantErr: false,
expected: nil,
},
{
name: "empty string",
input: "",
wantErr: false,
expected: nil,
},
{
name: "json map string",
input: `{"file.go": 1234567890}`,
wantErr: false,
expected: JSONInt64Map{"file.go": 1234567890},
},
{
name: "json map bytes",
input: []byte(`{"a.go": 100, "b.go": 200}`),
wantErr: false,
expected: JSONInt64Map{"a.go": 100, "b.go": 200},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var m JSONInt64Map
err := m.Scan(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, m)
}
})
}
}
// TestObservation_JSONRoundTrip tests that observations can be marshaled and unmarshaled.
func TestObservation_JSONRoundTrip(t *testing.T) {
original := &Observation{
ID: 1,
SDKSessionID: "session-123",
Project: "test-project",
Type: ObsTypeDiscovery,
Title: sql.NullString{String: "Test Title", Valid: true},
Subtitle: sql.NullString{String: "Test Subtitle", Valid: true},
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
Scope: ScopeProject,
CreatedAt: "2024-01-01T00:00:00Z",
CreatedAtEpoch: 1704067200000,
}
// Marshal
data, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal into map to check fields
var result map[string]interface{}
err = json.Unmarshal(data, &result)
require.NoError(t, err)
assert.Equal(t, float64(1), result["id"])
assert.Equal(t, "test-project", result["project"])
assert.Equal(t, "discovery", result["type"])
assert.Equal(t, "Test Title", result["title"])
}
+267
View File
@@ -0,0 +1,267 @@
// Package models contains domain models for claude-mnemonic.
package models
import (
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// SummarySuite is a test suite for SessionSummary operations.
type SummarySuite struct {
suite.Suite
}
func TestSummarySuite(t *testing.T) {
suite.Run(t, new(SummarySuite))
}
// TestNewSessionSummary tests summary creation.
func (s *SummarySuite) TestNewSessionSummary() {
parsed := &ParsedSummary{
Request: "Fix the bug in handler.go",
Investigated: "Looked at error logs",
Learned: "The issue was a race condition",
Completed: "Fixed the race condition",
NextSteps: "Add more tests",
Notes: "Consider adding mutex",
}
summary := NewSessionSummary("sdk-123", "test-project", parsed, 5, 1000)
s.NotNil(summary)
s.Equal("sdk-123", summary.SDKSessionID)
s.Equal("test-project", summary.Project)
s.True(summary.Request.Valid)
s.Equal("Fix the bug in handler.go", summary.Request.String)
s.True(summary.Investigated.Valid)
s.True(summary.Learned.Valid)
s.True(summary.Completed.Valid)
s.True(summary.NextSteps.Valid)
s.True(summary.Notes.Valid)
s.True(summary.PromptNumber.Valid)
s.Equal(int64(5), summary.PromptNumber.Int64)
s.Equal(int64(1000), summary.DiscoveryTokens)
s.NotEmpty(summary.CreatedAt)
s.Greater(summary.CreatedAtEpoch, int64(0))
}
// TestNewSessionSummary_EmptyFields tests summary creation with empty fields.
func (s *SummarySuite) TestNewSessionSummary_EmptyFields() {
parsed := &ParsedSummary{
Request: "Test request",
// All other fields empty
}
summary := NewSessionSummary("sdk-123", "project", parsed, 0, 0)
s.True(summary.Request.Valid)
s.False(summary.Investigated.Valid)
s.False(summary.Learned.Valid)
s.False(summary.Completed.Valid)
s.False(summary.NextSteps.Valid)
s.False(summary.Notes.Valid)
s.False(summary.PromptNumber.Valid) // 0 is not valid
s.Equal(int64(0), summary.DiscoveryTokens)
}
// TestSessionSummary_MarshalJSON tests JSON marshaling.
func (s *SummarySuite) TestSessionSummary_MarshalJSON() {
summary := &SessionSummary{
ID: 1,
SDKSessionID: "sdk-123",
Project: "test-project",
Request: sql.NullString{String: "Test request", Valid: true},
Investigated: sql.NullString{String: "Test investigation", Valid: true},
Learned: sql.NullString{Valid: false}, // Invalid - should be omitted
Completed: sql.NullString{String: "Test completion", Valid: true},
NextSteps: sql.NullString{Valid: false},
Notes: sql.NullString{String: "Test notes", Valid: true},
PromptNumber: sql.NullInt64{Int64: 3, Valid: true},
DiscoveryTokens: 500,
CreatedAt: "2024-01-01T00:00:00Z",
CreatedAtEpoch: 1704067200000,
}
data, err := json.Marshal(summary)
s.NoError(err)
// Parse the JSON
var result map[string]interface{}
err = json.Unmarshal(data, &result)
s.NoError(err)
// Check fields
s.Equal(float64(1), result["id"])
s.Equal("sdk-123", result["sdk_session_id"])
s.Equal("test-project", result["project"])
s.Equal("Test request", result["request"])
s.Equal("Test investigation", result["investigated"])
s.Equal("Test completion", result["completed"])
s.Equal("Test notes", result["notes"])
s.Equal(float64(3), result["prompt_number"])
s.Equal(float64(500), result["discovery_tokens"])
// Empty fields should be omitted
_, hasLearned := result["learned"]
s.False(hasLearned, "Empty learned should be omitted")
_, hasNextSteps := result["next_steps"]
s.False(hasNextSteps, "Empty next_steps should be omitted")
}
// TestSessionSummary_MarshalJSON_AllEmpty tests JSON marshaling with all empty optional fields.
func (s *SummarySuite) TestSessionSummary_MarshalJSON_AllEmpty() {
summary := &SessionSummary{
ID: 1,
SDKSessionID: "sdk-123",
Project: "test-project",
Request: sql.NullString{Valid: false},
Investigated: sql.NullString{Valid: false},
Learned: sql.NullString{Valid: false},
Completed: sql.NullString{Valid: false},
NextSteps: sql.NullString{Valid: false},
Notes: sql.NullString{Valid: false},
PromptNumber: sql.NullInt64{Valid: false},
DiscoveryTokens: 0,
CreatedAt: "2024-01-01T00:00:00Z",
CreatedAtEpoch: 1704067200000,
}
data, err := json.Marshal(summary)
s.NoError(err)
var result map[string]interface{}
err = json.Unmarshal(data, &result)
s.NoError(err)
// Required fields should be present
s.Equal(float64(1), result["id"])
s.Equal("sdk-123", result["sdk_session_id"])
s.Equal("test-project", result["project"])
// Optional fields should be empty strings or omitted
request, hasRequest := result["request"]
if hasRequest {
s.Equal("", request)
}
}
// TestParsedSummary tests ParsedSummary structure.
func (s *SummarySuite) TestParsedSummary() {
parsed := &ParsedSummary{
Request: "Request text",
Investigated: "Investigation text",
Learned: "Learned text",
Completed: "Completed text",
NextSteps: "Next steps text",
Notes: "Notes text",
}
s.Equal("Request text", parsed.Request)
s.Equal("Investigation text", parsed.Investigated)
s.Equal("Learned text", parsed.Learned)
s.Equal("Completed text", parsed.Completed)
s.Equal("Next steps text", parsed.NextSteps)
s.Equal("Notes text", parsed.Notes)
}
// TestSessionSummaryJSON tests the JSON-friendly type.
func (s *SummarySuite) TestSessionSummaryJSON() {
j := SessionSummaryJSON{
ID: 1,
SDKSessionID: "sdk-123",
Project: "test-project",
Request: "Request",
Investigated: "Investigation",
Learned: "Learned",
Completed: "Completed",
NextSteps: "Next steps",
Notes: "Notes",
PromptNumber: 5,
DiscoveryTokens: 1000,
CreatedAt: "2024-01-01T00:00:00Z",
CreatedAtEpoch: 1704067200000,
}
s.Equal(int64(1), j.ID)
s.Equal("sdk-123", j.SDKSessionID)
s.Equal("test-project", j.Project)
s.Equal("Request", j.Request)
s.Equal("Investigation", j.Investigated)
s.Equal("Learned", j.Learned)
s.Equal("Completed", j.Completed)
s.Equal("Next steps", j.NextSteps)
s.Equal("Notes", j.Notes)
s.Equal(int64(5), j.PromptNumber)
s.Equal(int64(1000), j.DiscoveryTokens)
}
// TestSessionSummary_TimestampValidity tests that timestamps are set correctly.
func TestSessionSummary_TimestampValidity(t *testing.T) {
before := time.Now().Add(-time.Second) // Give 1 second buffer
parsed := &ParsedSummary{Request: "Test"}
summary := NewSessionSummary("sdk-123", "project", parsed, 1, 100)
after := time.Now().Add(time.Second) // Give 1 second buffer
// Parse the timestamp
createdAt, err := time.Parse(time.RFC3339, summary.CreatedAt)
require.NoError(t, err)
// Timestamp should be between before and after (with buffer)
assert.True(t, createdAt.After(before) || createdAt.Equal(before), "created_at should be >= before")
assert.True(t, createdAt.Before(after) || createdAt.Equal(after), "created_at should be <= after")
// Epoch should also be in range (with buffer)
beforeEpoch := before.UnixMilli()
afterEpoch := after.UnixMilli()
assert.GreaterOrEqual(t, summary.CreatedAtEpoch, beforeEpoch, "epoch should be >= before epoch")
assert.LessOrEqual(t, summary.CreatedAtEpoch, afterEpoch, "epoch should be <= after epoch")
}
// TestSessionSummary_JSONRoundTrip tests that summaries can be marshaled and unmarshaled.
func TestSessionSummary_JSONRoundTrip(t *testing.T) {
original := &SessionSummary{
ID: 1,
SDKSessionID: "sdk-123",
Project: "test-project",
Request: sql.NullString{String: "Test request", Valid: true},
Investigated: sql.NullString{String: "Test investigation", Valid: true},
Learned: sql.NullString{String: "Test learned", Valid: true},
Completed: sql.NullString{String: "Test completed", Valid: true},
NextSteps: sql.NullString{String: "Test next steps", Valid: true},
Notes: sql.NullString{String: "Test notes", Valid: true},
PromptNumber: sql.NullInt64{Int64: 5, Valid: true},
DiscoveryTokens: 1000,
CreatedAt: "2024-01-01T00:00:00Z",
CreatedAtEpoch: 1704067200000,
}
// Marshal
data, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal into JSON type
var result SessionSummaryJSON
err = json.Unmarshal(data, &result)
require.NoError(t, err)
// Verify
assert.Equal(t, original.ID, result.ID)
assert.Equal(t, original.SDKSessionID, result.SDKSessionID)
assert.Equal(t, original.Project, result.Project)
assert.Equal(t, original.Request.String, result.Request)
assert.Equal(t, original.Investigated.String, result.Investigated)
assert.Equal(t, original.Learned.String, result.Learned)
assert.Equal(t, original.Completed.String, result.Completed)
assert.Equal(t, original.NextSteps.String, result.NextSteps)
assert.Equal(t, original.Notes.String, result.Notes)
assert.Equal(t, original.PromptNumber.Int64, result.PromptNumber)
assert.Equal(t, original.DiscoveryTokens, result.DiscoveryTokens)
}