fix: plugin no longer vanishes after Claude Code updates

Root cause: plugin registered as directory source in known_marketplaces.json,
which gets wiped on CLI updates. Now registers in extraKnownMarketplaces
(settings.json) as a GitHub source — same mechanism caveman/context-mode use.

Binaries install to ~/.claude-mnemonic/bin/ instead of the Claude-managed
plugins directory. Thin wrapper scripts in the repo let the marketplace
clone find them. Nothing gets cleaned up when Claude refreshes its cache.

Also fixed along the way:
- ONNX Runtime 1.24.3 → 1.26.0 (API v25 mismatch broke all embedding tests)
- Vector client leaked on DB reinit, processQueue had a race on sessionManager
- reloadConfig called os.Exit(0) bypassing graceful shutdown
- Removed dead QueryRowWithTimeout that leaked contexts
- Added tests for graph/watcher/maintenance/update (all were at 0%)
This commit is contained in:
2026-05-24 01:15:23 +01:00
parent cfc95c9ce4
commit f07875ee82
32 changed files with 3217 additions and 127 deletions
-7
View File
@@ -513,13 +513,6 @@ func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, quer
return nil
}
// QueryRowWithTimeout executes a row query with timeout.
func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row {
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row")
// Note: cancel will be called when row.Scan() completes or errors
_ = cancel // Caller must ensure proper cleanup
return s.sqlDB.QueryRowContext(timeoutCtx, query, args...)
}
// TransactionWithTimeout wraps a transaction function with timeout handling.
// The transaction is automatically rolled back if the context times out.
@@ -1 +1 @@
1.24.3
1.26.0
@@ -1 +1 @@
1.24.3
1.26.0
@@ -1 +1 @@
1.24.3
1.26.0
@@ -1 +1 @@
1.24.3
1.26.0
+674
View File
@@ -0,0 +1,674 @@
//go:build fts5
package graph
import (
"context"
"database/sql"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---- helpers ----------------------------------------------------------------
func makeObs(id int64, sessionID string, concepts, filesRead, filesModified []string) *models.Observation {
return &models.Observation{
ID: id,
SDKSessionID: sessionID,
Title: sql.NullString{String: "title", Valid: true},
Project: "test-project",
Type: models.ObsTypeDecision,
Concepts: concepts,
FilesRead: filesRead,
FilesModified: filesModified,
CreatedAtEpoch: time.Now().UnixMilli(),
}
}
// ---- ObservationGraph -------------------------------------------------------
func TestNewObservationGraph_Empty(t *testing.T) {
g := NewObservationGraph()
require.NotNil(t, g)
stats := g.Stats()
assert.Equal(t, 0, stats.NodeCount)
assert.Equal(t, 0, stats.EdgeCount)
}
func TestAddNode_StoresAndRetrieves(t *testing.T) {
g := NewObservationGraph()
node := &Node{
ID: 42,
Degree: 0,
Metadata: NodeMetadata{
Project: "proj",
Type: "decision",
Title: "test node",
},
}
g.AddNode(node)
got, err := g.GetNode(42)
require.NoError(t, err)
assert.Equal(t, int64(42), got.ID)
assert.Equal(t, "test node", got.Metadata.Title)
}
func TestAddNode_OverwritesExisting(t *testing.T) {
g := NewObservationGraph()
g.AddNode(&Node{ID: 1, Metadata: NodeMetadata{Title: "old"}})
g.AddNode(&Node{ID: 1, Metadata: NodeMetadata{Title: "new"}})
got, err := g.GetNode(1)
require.NoError(t, err)
assert.Equal(t, "new", got.Metadata.Title)
}
func TestGetNode_NotFound(t *testing.T) {
g := NewObservationGraph()
_, err := g.GetNode(999)
assert.Error(t, err)
}
func TestAddEdge_UpdatesDegree(t *testing.T) {
g := NewObservationGraph()
g.AddNode(&Node{ID: 1})
g.AddNode(&Node{ID: 2})
g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8})
n1, _ := g.GetNode(1)
n2, _ := g.GetNode(2)
assert.Equal(t, 1, n1.Degree)
assert.Equal(t, 1, n2.Degree)
}
func TestAddEdge_MissingNodesDontPanic(t *testing.T) {
g := NewObservationGraph()
// Adding edge referencing non-existent nodes must not panic
assert.NotPanics(t, func() {
g.AddEdge(Edge{FromID: 100, ToID: 200, Relation: RelationConcept, Weight: 0.5})
})
}
// ---- BuildCSR / GetNeighbors ------------------------------------------------
func TestBuildCSR_NoNodes_ReturnsError(t *testing.T) {
g := NewObservationGraph()
err := g.BuildCSR()
assert.Error(t, err)
}
func TestGetNeighbors_AfterBuildCSR(t *testing.T) {
g := NewObservationGraph()
for _, id := range []int64{1, 2, 3} {
g.AddNode(&Node{ID: id})
}
g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8})
g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationConcept, Weight: 0.6})
require.NoError(t, g.BuildCSR())
neighbors, weights, err := g.GetNeighbors(1)
require.NoError(t, err)
assert.Len(t, neighbors, 2)
assert.Len(t, weights, 2)
}
func TestGetNeighbors_NodeWithNoOutgoingEdges(t *testing.T) {
g := NewObservationGraph()
g.AddNode(&Node{ID: 1})
g.AddNode(&Node{ID: 2})
// Edge only from 2 → 1; node 2 is a leaf from 1's perspective
g.AddEdge(Edge{FromID: 2, ToID: 1, Relation: RelationTemporal, Weight: 0.8})
require.NoError(t, g.BuildCSR())
neighbors, weights, err := g.GetNeighbors(1)
require.NoError(t, err)
assert.Empty(t, neighbors)
assert.Empty(t, weights)
}
func TestGetNeighbors_NodeNotInGraph(t *testing.T) {
g := NewObservationGraph()
g.AddNode(&Node{ID: 1})
require.NoError(t, g.BuildCSR())
_, _, err := g.GetNeighbors(999)
assert.Error(t, err)
}
// ---- FindHubs ---------------------------------------------------------------
func TestFindHubs_EmptyGraph(t *testing.T) {
g := NewObservationGraph()
hubs := g.FindHubs(0.1)
assert.Nil(t, hubs)
}
func TestFindHubs_IdentifiesHighDegreeNodes(t *testing.T) {
g := NewObservationGraph()
// Node 1 connected to everyone else → hub
for id := int64(1); id <= 5; id++ {
g.AddNode(&Node{ID: id})
}
for id := int64(2); id <= 5; id++ {
g.AddEdge(Edge{FromID: 1, ToID: id, Relation: RelationConcept, Weight: 0.5})
}
hubs := g.FindHubs(0.2) // top 20%
assert.Contains(t, hubs, int64(1))
}
func TestFindHubs_Percentile100_ReturnsEmpty(t *testing.T) {
// percentile=1.0 → cutoff = ceil(N * (1 - 1.0)) = ceil(0) = 0 → no hubs
g := NewObservationGraph()
for id := int64(1); id <= 4; id++ {
g.AddNode(&Node{ID: id})
}
hubs := g.FindHubs(1.0)
assert.Empty(t, hubs)
}
func TestFindHubs_Percentile0_ReturnsAllNodes(t *testing.T) {
// percentile=0.0 → cutoff = ceil(N * 1.0) = N → all nodes returned
g := NewObservationGraph()
for id := int64(1); id <= 4; id++ {
g.AddNode(&Node{ID: id})
}
hubs := g.FindHubs(0.0)
assert.Len(t, hubs, 4)
}
// ---- Stats ------------------------------------------------------------------
func TestStats_EdgeTypesCounted(t *testing.T) {
g := NewObservationGraph()
for _, id := range []int64{1, 2, 3} {
g.AddNode(&Node{ID: id})
}
g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8})
g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationConcept, Weight: 0.6})
g.AddEdge(Edge{FromID: 2, ToID: 3, Relation: RelationConcept, Weight: 0.6})
stats := g.Stats()
assert.Equal(t, 3, stats.NodeCount)
assert.Equal(t, 3, stats.EdgeCount)
assert.Equal(t, 1, stats.EdgeTypes[RelationTemporal])
assert.Equal(t, 2, stats.EdgeTypes[RelationConcept])
}
func TestStats_DegreeMetrics(t *testing.T) {
g := NewObservationGraph()
// Node 1: degree 2, nodes 2,3: degree 1 each
for _, id := range []int64{1, 2, 3} {
g.AddNode(&Node{ID: id})
}
g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8})
g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationTemporal, Weight: 0.8})
stats := g.Stats()
assert.Equal(t, 2, stats.MaxDegree)
assert.Equal(t, 1, stats.MinDegree)
assert.InDelta(t, 4.0/3.0, stats.AvgDegree, 0.001)
}
// ---- BuildFromObservations --------------------------------------------------
func TestBuildFromObservations_SingleObservation_ReturnsError(t *testing.T) {
obs := []*models.Observation{makeObs(1, "s1", nil, nil, nil)}
// Single observation: DetectEdges returns nil, BuildCSR errors (no nodes never happens
// since node was added — but CSR build will succeed with 1 node and 0 edges).
g, err := BuildFromObservations(context.Background(), obs)
// With 1 node, BuildCSR succeeds (nodes exist); no edges → valid graph.
require.NoError(t, err)
require.NotNil(t, g)
stats := g.Stats()
assert.Equal(t, 1, stats.NodeCount)
assert.Equal(t, 0, stats.EdgeCount)
}
func TestBuildFromObservations_SetsNodeMetadata(t *testing.T) {
obs := []*models.Observation{
{
ID: 7,
SDKSessionID: "sess",
Project: "myproject",
Type: models.ObsTypeFeature,
Title: sql.NullString{String: "feature title", Valid: true},
IsSuperseded: true,
CreatedAtEpoch: time.Now().UnixMilli(),
},
}
g, err := BuildFromObservations(context.Background(), obs)
require.NoError(t, err)
node, err := g.GetNode(7)
require.NoError(t, err)
assert.Equal(t, "myproject", node.Metadata.Project)
assert.Equal(t, "feature title", node.Metadata.Title)
assert.Equal(t, string(models.ObsTypeFeature), node.Metadata.Type)
assert.True(t, node.Metadata.IsSuperseded)
}
func TestBuildFromObservations_TitleMissing_EmptyString(t *testing.T) {
obs := []*models.Observation{
{
ID: 3,
SDKSessionID: "s",
Title: sql.NullString{Valid: false},
CreatedAtEpoch: time.Now().UnixMilli(),
},
}
g, err := BuildFromObservations(context.Background(), obs)
require.NoError(t, err)
node, err := g.GetNode(3)
require.NoError(t, err)
assert.Equal(t, "", node.Metadata.Title)
}
func TestBuildFromObservations_WithTemporalEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "sess-a", nil, nil, nil),
makeObs(2, "sess-a", nil, nil, nil),
makeObs(3, "sess-b", nil, nil, nil),
}
g, err := BuildFromObservations(context.Background(), obs)
require.NoError(t, err)
stats := g.Stats()
assert.Equal(t, 3, stats.NodeCount)
// obs 1 and 2 share session → 1 temporal edge
assert.Equal(t, 1, stats.EdgeTypes[RelationTemporal])
}
func TestBuildFromObservations_WithConceptEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "s1", []string{"security", "auth"}, nil, nil),
makeObs(2, "s2", []string{"security"}, nil, nil),
makeObs(3, "s3", []string{"unrelated"}, nil, nil),
}
g, err := BuildFromObservations(context.Background(), obs)
require.NoError(t, err)
stats := g.Stats()
// obs 1 and 2 share "security"
assert.GreaterOrEqual(t, stats.EdgeTypes[RelationConcept], 1)
}
func TestBuildFromObservations_WithFileOverlapEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "s1", nil, []string{"pkg/foo.go", "pkg/bar.go"}, nil),
makeObs(2, "s2", nil, []string{"pkg/foo.go", "pkg/baz.go"}, nil),
}
g, err := BuildFromObservations(context.Background(), obs)
require.NoError(t, err)
stats := g.Stats()
// Jaccard({foo,bar},{foo,baz}) = 1/3 ≈ 0.333 > MinFileOverlapForEdge(0.3)
assert.GreaterOrEqual(t, stats.EdgeTypes[RelationFileOverlap], 1)
}
// ---- RelationType.String() --------------------------------------------------
func TestRelationType_String(t *testing.T) {
cases := []struct {
rt RelationType
want string
}{
{RelationFileOverlap, "file_overlap"},
{RelationSemantic, "semantic"},
{RelationTemporal, "temporal"},
{RelationConcept, "concept"},
{RelationType(99), "unknown"},
}
for _, tc := range cases {
t.Run(tc.want, func(t *testing.T) {
assert.Equal(t, tc.want, tc.rt.String())
})
}
}
// ---- DetectEdges (edge_detector.go) ----------------------------------------
func TestDetectEdges_LessThanTwo_ReturnsNil(t *testing.T) {
edges, err := DetectEdges(context.Background(), nil)
assert.NoError(t, err)
assert.Nil(t, edges)
edges, err = DetectEdges(context.Background(), []*models.Observation{makeObs(1, "s", nil, nil, nil)})
assert.NoError(t, err)
assert.Nil(t, edges)
}
func TestDetectEdges_SameSession_CreatesTemporalEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "session-x", nil, nil, nil),
makeObs(2, "session-x", nil, nil, nil),
makeObs(3, "session-x", nil, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
var temporal []Edge
for _, e := range edges {
if e.Relation == RelationTemporal {
temporal = append(temporal, e)
}
}
// Consecutive pairs: (1,2) and (2,3)
assert.Len(t, temporal, 2)
assert.InDelta(t, 0.8, temporal[0].Weight, 0.001)
}
func TestDetectEdges_DifferentSessions_NoTemporalEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "sess-a", nil, nil, nil),
makeObs(2, "sess-b", nil, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
for _, e := range edges {
assert.NotEqual(t, RelationTemporal, e.Relation)
}
}
func TestDetectEdges_EmptySessionID_NoTemporalEdge(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "", nil, nil, nil),
makeObs(2, "", nil, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
for _, e := range edges {
assert.NotEqual(t, RelationTemporal, e.Relation)
}
}
func TestDetectEdges_SharedConcepts_CreatesConceptEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "s1", []string{"performance", "caching"}, nil, nil),
makeObs(2, "s2", []string{"performance"}, nil, nil),
makeObs(3, "s3", []string{"caching"}, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
conceptEdges := filterByRelation(edges, RelationConcept)
// obs1↔obs2 (performance), obs1↔obs3 (caching) → 2 concept edges
assert.Len(t, conceptEdges, 2)
}
func TestDetectEdges_NoConcepts_NoConceptEdges(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "s1", nil, nil, nil),
makeObs(2, "s2", nil, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
assert.Empty(t, filterByRelation(edges, RelationConcept))
}
func TestDetectEdges_FileOverlap_AboveThreshold_CreatesEdge(t *testing.T) {
// Jaccard 2/3 ≈ 0.667 > 0.3 threshold
obs := []*models.Observation{
makeObs(1, "s1", nil, []string{"a.go", "b.go", "c.go"}, nil),
makeObs(2, "s2", nil, []string{"a.go", "b.go", "d.go"}, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
fileEdges := filterByRelation(edges, RelationFileOverlap)
require.Len(t, fileEdges, 1)
assert.InDelta(t, 2.0/4.0, float64(fileEdges[0].Weight), 0.01)
}
func TestDetectEdges_FileOverlap_BelowThreshold_NoEdge(t *testing.T) {
// Jaccard 1/9 ≈ 0.11 < 0.3 threshold
obs := []*models.Observation{
makeObs(1, "s1", nil, []string{"a.go", "b.go", "c.go", "d.go", "e.go"}, nil),
makeObs(2, "s2", nil, []string{"a.go", "f.go", "g.go", "h.go", "i.go"}, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
assert.Empty(t, filterByRelation(edges, RelationFileOverlap))
}
func TestDetectEdges_FilesModified_CountsForOverlap(t *testing.T) {
obs := []*models.Observation{
makeObs(1, "s1", nil, nil, []string{"pkg/core.go", "pkg/util.go"}),
makeObs(2, "s2", nil, []string{"pkg/core.go"}, []string{"pkg/util.go"}),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
fileEdges := filterByRelation(edges, RelationFileOverlap)
assert.NotEmpty(t, fileEdges)
}
func TestDetectEdges_NoEdgeDuplicates(t *testing.T) {
// Same pair via two concepts → only one concept edge
obs := []*models.Observation{
makeObs(1, "s1", []string{"security", "auth"}, nil, nil),
makeObs(2, "s2", []string{"security", "auth"}, nil, nil),
}
edges, err := DetectEdges(context.Background(), obs)
require.NoError(t, err)
conceptEdges := filterByRelation(edges, RelationConcept)
// Both share security and auth, but deduplication should keep only 1 edge per pair per call
// The seen map deduplicates: only first concept that creates the pair wins
assert.Len(t, conceptEdges, 1)
}
// ---- calculateFileOverlap ---------------------------------------------------
func TestCalculateFileOverlap_DisjointSets_Zero(t *testing.T) {
result := calculateFileOverlap([]string{"a.go", "b.go"}, []string{"c.go", "d.go"})
assert.Equal(t, float32(0.0), result)
}
func TestCalculateFileOverlap_IdenticalSets_One(t *testing.T) {
files := []string{"a.go", "b.go", "c.go"}
result := calculateFileOverlap(files, files)
assert.InDelta(t, 1.0, float64(result), 0.001)
}
func TestCalculateFileOverlap_EmptySlices_Zero(t *testing.T) {
assert.Equal(t, float32(0.0), calculateFileOverlap(nil, []string{"a.go"}))
assert.Equal(t, float32(0.0), calculateFileOverlap([]string{"a.go"}, nil))
assert.Equal(t, float32(0.0), calculateFileOverlap(nil, nil))
}
func TestCalculateFileOverlap_Jaccard_Correct(t *testing.T) {
// {a,b,c} ∩ {b,c,d} = {b,c} → 2/4 = 0.5
result := calculateFileOverlap([]string{"a", "b", "c"}, []string{"b", "c", "d"})
assert.InDelta(t, 0.5, float64(result), 0.001)
}
func TestCalculateFileOverlap_Duplicates_TreatedAsSet(t *testing.T) {
// Duplicates collapse: {a,a,b} → {a,b}; {a,b,b} → {a,b}; Jaccard = 1.0
result := calculateFileOverlap([]string{"a", "a", "b"}, []string{"a", "b", "b"})
assert.InDelta(t, 1.0, float64(result), 0.001)
}
// ---- DetectSemanticEdges ----------------------------------------------------
func TestDetectSemanticEdges_AboveThreshold_CreatesEdge(t *testing.T) {
// Identical vectors → similarity = 1.0 > 0.85
emb := []float32{1.0, 0.0, 0.0}
obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)}
embeddings := map[int64][]float32{1: emb, 2: emb}
edges := DetectSemanticEdges(context.Background(), obs, embeddings)
require.Len(t, edges, 1)
assert.Equal(t, RelationSemantic, edges[0].Relation)
assert.InDelta(t, 1.0, float64(edges[0].Weight), 0.001)
}
func TestDetectSemanticEdges_BelowThreshold_NoEdge(t *testing.T) {
// Orthogonal vectors → similarity = 0.0
obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)}
embeddings := map[int64][]float32{
1: {1.0, 0.0, 0.0},
2: {0.0, 1.0, 0.0},
}
edges := DetectSemanticEdges(context.Background(), obs, embeddings)
assert.Empty(t, edges)
}
func TestDetectSemanticEdges_MissingEmbedding_Skipped(t *testing.T) {
obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)}
// Only obs 1 has embedding
embeddings := map[int64][]float32{1: {1.0, 0.0, 0.0}}
edges := DetectSemanticEdges(context.Background(), obs, embeddings)
assert.Empty(t, edges)
}
func TestDetectSemanticEdges_NoDuplicates(t *testing.T) {
emb := []float32{0.9, 0.1, 0.0}
obs := []*models.Observation{
makeObs(1, "s", nil, nil, nil),
makeObs(2, "s", nil, nil, nil),
makeObs(3, "s", nil, nil, nil),
}
embeddings := map[int64][]float32{1: emb, 2: emb, 3: emb}
edges := DetectSemanticEdges(context.Background(), obs, embeddings)
// 3 pairs: (1,2),(1,3),(2,3)
assert.Len(t, edges, 3)
}
// ---- cosineSimilarity -------------------------------------------------------
func TestCosineSimilarity_IdenticalVectors(t *testing.T) {
v := []float32{1.0, 2.0, 3.0}
result := cosineSimilarity(v, v)
assert.InDelta(t, 1.0, float64(result), 0.0001)
}
func TestCosineSimilarity_OppositeVectors(t *testing.T) {
a := []float32{1.0, 0.0}
b := []float32{-1.0, 0.0}
result := cosineSimilarity(a, b)
assert.InDelta(t, -1.0, float64(result), 0.0001)
}
func TestCosineSimilarity_OrthogonalVectors(t *testing.T) {
a := []float32{1.0, 0.0}
b := []float32{0.0, 1.0}
result := cosineSimilarity(a, b)
assert.InDelta(t, 0.0, float64(result), 0.0001)
}
func TestCosineSimilarity_ZeroVector_ReturnsZero(t *testing.T) {
a := []float32{0.0, 0.0}
b := []float32{1.0, 0.0}
assert.Equal(t, float32(0.0), cosineSimilarity(a, b))
assert.Equal(t, float32(0.0), cosineSimilarity(b, a))
}
func TestCosineSimilarity_MismatchedLength_ReturnsZero(t *testing.T) {
a := []float32{1.0, 2.0}
b := []float32{1.0, 2.0, 3.0}
assert.Equal(t, float32(0.0), cosineSimilarity(a, b))
}
// ---- edgeKey ----------------------------------------------------------------
func TestEdgeKey_Symmetric(t *testing.T) {
// Must produce the same key regardless of order
assert.Equal(t, edgeKey(1, 2), edgeKey(2, 1))
assert.Equal(t, edgeKey(100, 5), edgeKey(5, 100))
}
func TestEdgeKey_DifferentPairs_DifferentKeys(t *testing.T) {
assert.NotEqual(t, edgeKey(1, 2), edgeKey(1, 3))
assert.NotEqual(t, edgeKey(1, 2), edgeKey(2, 3))
}
// ---- pruneEdges -------------------------------------------------------------
func TestPruneEdges_BelowLimit_NoChange(t *testing.T) {
edges := []Edge{
{FromID: 1, ToID: 2, Weight: 0.9},
{FromID: 1, ToID: 3, Weight: 0.7},
}
pruned := pruneEdges(edges, 5)
assert.Len(t, pruned, 2)
}
func TestPruneEdges_ZeroLimit_ReturnsAll(t *testing.T) {
edges := []Edge{{FromID: 1, ToID: 2, Weight: 0.5}}
pruned := pruneEdges(edges, 0)
assert.Len(t, pruned, 1)
}
func TestPruneEdges_KeepsHighWeightEdges(t *testing.T) {
// Node 1 gets 4 edges, limit is 2 → only the 2 heaviest should survive
edges := []Edge{
{FromID: 1, ToID: 2, Weight: 0.9},
{FromID: 1, ToID: 3, Weight: 0.8},
{FromID: 1, ToID: 4, Weight: 0.3},
{FromID: 1, ToID: 5, Weight: 0.1},
}
pruned := pruneEdges(edges, 2)
weights := make([]float32, len(pruned))
for i, e := range pruned {
weights[i] = e.Weight
}
assert.Contains(t, weights, float32(0.9))
assert.Contains(t, weights, float32(0.8))
}
// ---- sortEdgesByWeight ------------------------------------------------------
func TestSortEdgesByWeight_DescendingOrder(t *testing.T) {
edges := []Edge{
{Weight: 0.3},
{Weight: 0.9},
{Weight: 0.1},
{Weight: 0.7},
}
sortEdgesByWeight(edges)
for i := 1; i < len(edges); i++ {
assert.GreaterOrEqual(t, edges[i-1].Weight, edges[i].Weight)
}
}
func TestSortEdgesByWeight_EmptySlice_NoPanic(t *testing.T) {
assert.NotPanics(t, func() {
sortEdgesByWeight([]Edge{})
})
}
func TestSortEdgesByWeight_SingleElement_Unchanged(t *testing.T) {
edges := []Edge{{Weight: 0.5}}
sortEdgesByWeight(edges)
assert.Equal(t, float32(0.5), edges[0].Weight)
}
// ---- helpers ----------------------------------------------------------------
func filterByRelation(edges []Edge, rel RelationType) []Edge {
var out []Edge
for _, e := range edges {
if e.Relation == rel {
out = append(out, e)
}
}
return out
}
+727
View File
@@ -0,0 +1,727 @@
//go:build fts5
// Package maintenance provides scheduled maintenance tasks for claude-mnemonic.
package maintenance
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm/logger"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
gormdb "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// testSetup creates a full maintenance service with a real temporary database.
func testSetup(t *testing.T, cfg *config.Config) (*Service, *gormdb.Store, *gormdb.ObservationStore, *gormdb.PromptStore, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "maintenance_test_*")
require.NoError(t, err, "create temp dir")
dbPath := filepath.Join(tmpDir, "test.db")
storeCfg := gormdb.Config{
Path: dbPath,
MaxConns: 4,
LogLevel: logger.Silent,
}
store, err := gormdb.NewStore(storeCfg)
if err != nil {
os.RemoveAll(tmpDir)
t.Fatalf("NewStore failed: %v", err)
}
observationStore := gormdb.NewObservationStore(store, nil, nil, nil)
summaryStore := gormdb.NewSummaryStore(store)
promptStore := gormdb.NewPromptStore(store, nil)
svc := NewService(store, observationStore, summaryStore, promptStore, nil, cfg, zerolog.Nop())
cleanup := func() {
store.Close()
os.RemoveAll(tmpDir)
}
return svc, store, observationStore, promptStore, cleanup
}
// defaultCfg returns a maintenance-enabled config for tests.
func defaultCfg() *config.Config {
cfg := config.Default()
cfg.MaintenanceEnabled = true
cfg.MaintenanceIntervalHours = 1
cfg.ObservationRetentionDays = 0
cfg.CleanupStaleObservations = false
return cfg
}
// insertObservation is a helper that inserts an observation and returns its ID.
func insertObservation(t *testing.T, obsStore *gormdb.ObservationStore, session, project string, seq int) int64 {
t.Helper()
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "test observation",
}
id, _, err := obsStore.StoreObservation(context.Background(), session, project, obs, seq, 10)
require.NoError(t, err)
return id
}
// ---- NewService ----
func TestNewService_ReturnsNonNilService(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
assert.NotNil(t, svc)
}
func TestNewService_InitializesChannels(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
// stopCh and doneCh must be non-nil so Stop/Wait don't panic.
assert.NotNil(t, svc.stopCh)
assert.NotNil(t, svc.doneCh)
}
// ---- Stats ----
func TestStats_DefaultValues(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
stats := svc.Stats()
assert.Equal(t, true, stats["enabled"])
assert.Equal(t, 1, stats["interval_hours"])
assert.Equal(t, 0, stats["retention_days"])
assert.Equal(t, false, stats["cleanup_stale"])
assert.Equal(t, int64(0), stats["total_cleaned_obs"])
assert.Equal(t, int64(0), stats["total_optimizes"])
assert.Equal(t, false, stats["running"])
}
func TestStats_ReflectsConfigFields(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
wantEnabled bool
wantHours int
wantDays int
wantStale bool
}{
{
name: "maintenance disabled",
cfg: func() *config.Config {
c := defaultCfg()
c.MaintenanceEnabled = false
return c
}(),
wantEnabled: false,
wantHours: 1,
wantDays: 0,
wantStale: false,
},
{
name: "retention and stale cleanup enabled",
cfg: func() *config.Config {
c := defaultCfg()
c.ObservationRetentionDays = 30
c.CleanupStaleObservations = true
c.MaintenanceIntervalHours = 12
return c
}(),
wantEnabled: true,
wantHours: 12,
wantDays: 30,
wantStale: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, tt.cfg)
defer cleanup()
stats := svc.Stats()
assert.Equal(t, tt.wantEnabled, stats["enabled"])
assert.Equal(t, tt.wantHours, stats["interval_hours"])
assert.Equal(t, tt.wantDays, stats["retention_days"])
assert.Equal(t, tt.wantStale, stats["cleanup_stale"])
})
}
}
// ---- Stop (idempotency) ----
func TestStop_WhenNotRunning_DoesNotPanic(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
// Service was never started — Stop must be a no-op.
assert.NotPanics(t, func() { svc.Stop() })
}
func TestStop_CalledTwice_DoesNotPanic(t *testing.T) {
// Start with maintenance disabled so Start() returns immediately.
cfg := defaultCfg()
cfg.MaintenanceEnabled = false
svc, _, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
go svc.Start(ctx)
svc.Wait() // drains doneCh after early return
// Stop after Wait — must not panic or double-close.
assert.NotPanics(t, func() { svc.Stop() })
}
// ---- Start / running flag ----
func TestStart_MaintenanceDisabled_ExitsImmediately(t *testing.T) {
cfg := defaultCfg()
cfg.MaintenanceEnabled = false
svc, _, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
go svc.Start(ctx)
done := make(chan struct{})
go func() {
svc.Wait()
close(done)
}()
select {
case <-done:
// Good — returned without blocking.
case <-time.After(2 * time.Second):
t.Fatal("Start() did not return promptly when maintenance is disabled")
}
stats := svc.Stats()
assert.Equal(t, false, stats["running"])
}
func TestStart_StopSignal_ExitsCleanly(t *testing.T) {
// Start() with maintenance disabled exits immediately — verified in
// TestStart_MaintenanceDisabled_ExitsImmediately.
//
// The ticker/stop path is hard to test because Start() always sleeps
// 5 minutes before entering the loop. We verify instead that Stop()
// on an already-stopped service is safe and that the doneCh is closed
// after exit (i.e., Wait() returns).
cfg := defaultCfg()
cfg.MaintenanceEnabled = false
svc, _, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
go svc.Start(context.Background())
done := make(chan struct{})
go func() {
svc.Wait()
close(done)
}()
select {
case <-done:
// doneCh was closed — Start exited and Wait returned.
case <-time.After(2 * time.Second):
t.Fatal("Wait() did not return after Start exited")
}
// Stop after Wait must be a no-op and must not panic.
assert.NotPanics(t, func() { svc.Stop() })
}
func TestStart_DoubleStart_SecondCallIsNoOp(t *testing.T) {
cfg := defaultCfg()
cfg.MaintenanceEnabled = false // exits immediately
svc, _, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// First call.
go svc.Start(ctx)
svc.Wait()
// Second call on the same (exhausted) svc should be a no-op and not panic.
assert.NotPanics(t, func() {
// svc.running is now false again — but doneCh is already closed.
// A second Start would attempt to close doneCh again which would panic
// if the running guard is missing. Verify the guard works.
svc.mu.Lock()
running := svc.running
svc.mu.Unlock()
assert.False(t, running)
})
}
// ---- RunNow ----
func TestRunNow_UpdatesLastRunTime(t *testing.T) {
cfg := defaultCfg()
cfg.ObservationRetentionDays = 0
cfg.CleanupStaleObservations = false
svc, _, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
before := time.Now()
svc.RunNow(context.Background())
// Allow async goroutine to finish.
time.Sleep(200 * time.Millisecond)
svc.mu.Lock()
lastRun := svc.lastRunTime
svc.mu.Unlock()
assert.True(t, lastRun.After(before) || lastRun.Equal(before),
"lastRunTime should be updated after RunNow")
}
func TestRunNow_IncrementsOptimizeCounter(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
svc.RunNow(context.Background())
time.Sleep(300 * time.Millisecond)
svc.mu.Lock()
optimizes := svc.totalOptimizeRun
svc.mu.Unlock()
assert.Equal(t, int64(1), optimizes)
}
func TestRunNow_StatsTotalOptimizesReflected(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
svc.RunNow(context.Background())
time.Sleep(300 * time.Millisecond)
stats := svc.Stats()
assert.Equal(t, int64(1), stats["total_optimizes"])
}
// ---- cleanupOldObservations (via RunNow) ----
func TestRunNow_RetentionDaysZero_NothingDeleted(t *testing.T) {
cfg := defaultCfg()
cfg.ObservationRetentionDays = 0
svc, _, obsStore, _, cleanup := testSetup(t, cfg)
defer cleanup()
// Insert observations.
for i := 0; i < 5; i++ {
insertObservation(t, obsStore, "session-1", "proj", i)
}
svc.RunNow(context.Background())
time.Sleep(300 * time.Millisecond)
remaining, err := obsStore.GetRecentObservations(context.Background(), "proj", 20)
require.NoError(t, err)
assert.Equal(t, 5, len(remaining), "nothing should be deleted when retention_days = 0")
svc.mu.Lock()
cleaned := svc.totalCleanedObs
svc.mu.Unlock()
assert.Equal(t, int64(0), cleaned)
}
func TestRunNow_RetentionDays_DeletesExpiredObservations(t *testing.T) {
cfg := defaultCfg()
cfg.ObservationRetentionDays = 1 // keep only last 1 day
svc, store, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// Insert an observation and back-date it to 2 days ago.
obs := &gormdb.Observation{
SDKSessionID: "old-session",
Project: "proj",
Type: models.ObsTypeDiscovery,
CreatedAt: "2000-01-01T00:00:00Z",
CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(),
Scope: models.ScopeProject,
ImportanceScore: 1.0,
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error)
// Insert a recent observation (should survive).
recentObs := &gormdb.Observation{
SDKSessionID: "new-session",
Project: "proj",
Type: models.ObsTypeDiscovery,
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().Unix(),
Scope: models.ScopeProject,
ImportanceScore: 1.0,
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(recentObs).Error)
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
// Only the recent observation should remain.
var count int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Count(&count)
assert.Equal(t, int64(1), count, "expired observation should have been deleted")
svc.mu.Lock()
cleaned := svc.totalCleanedObs
svc.mu.Unlock()
assert.Equal(t, int64(1), cleaned)
}
func TestRunNow_RetentionDays_VectorCleanupCalled(t *testing.T) {
cfg := defaultCfg()
cfg.ObservationRetentionDays = 1
tmpDir, err := os.MkdirTemp("", "maintenance_vec_test_*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
store, err := gormdb.NewStore(gormdb.Config{
Path: filepath.Join(tmpDir, "test.db"),
MaxConns: 4,
LogLevel: logger.Silent,
})
require.NoError(t, err)
defer store.Close()
observationStore := gormdb.NewObservationStore(store, nil, nil, nil)
summaryStore := gormdb.NewSummaryStore(store)
promptStore := gormdb.NewPromptStore(store, nil)
var mu sync.Mutex
var capturedIDs []int64
vectorCleanupFn := func(_ context.Context, ids []int64) {
mu.Lock()
defer mu.Unlock()
capturedIDs = append(capturedIDs, ids...)
}
svc := NewService(store, observationStore, summaryStore, promptStore, vectorCleanupFn, cfg, zerolog.Nop())
ctx := context.Background()
// Insert an expired observation directly.
obs := &gormdb.Observation{
SDKSessionID: "session-x",
Project: "proj",
Type: models.ObsTypeDiscovery,
CreatedAt: "2000-01-01T00:00:00Z",
CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(),
Scope: models.ScopeProject,
ImportanceScore: 1.0,
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error)
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
mu.Lock()
ids := capturedIDs
mu.Unlock()
assert.NotEmpty(t, ids, "vector cleanup callback must be called with deleted IDs")
assert.Contains(t, ids, obs.ID)
}
// ---- cleanupStaleObservations (via RunNow) ----
func TestRunNow_CleanupStale_DeletesSupersededObservations(t *testing.T) {
cfg := defaultCfg()
cfg.CleanupStaleObservations = true
cfg.ObservationRetentionDays = 0
svc, store, obsStore, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// Insert an active observation.
activeID := insertObservation(t, obsStore, "session-1", "proj", 1)
// Insert and mark a stale observation.
staleID := insertObservation(t, obsStore, "session-1", "proj", 2)
require.NoError(t, store.GetDB().WithContext(ctx).
Model(&gormdb.Observation{}).
Where("id = ?", staleID).
Update("is_superseded", 1).Error)
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
// Active observation must survive.
var activeCount int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", activeID).Count(&activeCount)
assert.Equal(t, int64(1), activeCount, "active observation must not be deleted")
// Stale observation must be gone.
var staleCount int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", staleID).Count(&staleCount)
assert.Equal(t, int64(0), staleCount, "stale observation must be deleted")
}
func TestRunNow_CleanupStale_DisabledLeavesStaleObservations(t *testing.T) {
cfg := defaultCfg()
cfg.CleanupStaleObservations = false
svc, store, obsStore, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
staleID := insertObservation(t, obsStore, "session-1", "proj", 1)
require.NoError(t, store.GetDB().WithContext(ctx).
Model(&gormdb.Observation{}).
Where("id = ?", staleID).
Update("is_superseded", 1).Error)
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
var count int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", staleID).Count(&count)
assert.Equal(t, int64(1), count, "stale observation must survive when cleanup_stale is false")
}
func TestRunNow_CleanupStale_NoStaleRows_NothingChanged(t *testing.T) {
cfg := defaultCfg()
cfg.CleanupStaleObservations = true
svc, _, obsStore, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// Only active observations.
for i := 0; i < 3; i++ {
insertObservation(t, obsStore, "session-1", "proj", i)
}
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
remaining, err := obsStore.GetRecentObservations(ctx, "proj", 20)
require.NoError(t, err)
assert.Equal(t, 3, len(remaining))
}
// ---- cleanupOldPrompts (via RunNow) ----
func TestRunNow_CleanupOldPrompts_DeletesExpiredPrompts(t *testing.T) {
svc, store, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
ctx := context.Background()
// Insert a prompt with an old epoch (31 days ago).
oldPrompt := &gormdb.UserPrompt{
ClaudeSessionID: "session-old",
PromptText: "old prompt",
PromptNumber: 1,
CreatedAt: "2000-01-01T00:00:00Z",
CreatedAtEpoch: time.Now().AddDate(0, 0, -31).Unix(),
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(oldPrompt).Error)
// Insert a recent prompt (should survive).
recentPrompt := &gormdb.UserPrompt{
ClaudeSessionID: "session-new",
PromptText: "recent prompt",
PromptNumber: 1,
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().Unix(),
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(recentPrompt).Error)
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
var count int64
store.GetDB().WithContext(ctx).Model(&gormdb.UserPrompt{}).Count(&count)
assert.Equal(t, int64(1), count, "only the recent prompt should survive")
}
func TestRunNow_CleanupOldPrompts_NothingExpired_AllSurvive(t *testing.T) {
svc, store, _, promptStore, cleanup := testSetup(t, defaultCfg())
defer cleanup()
ctx := context.Background()
for i := 1; i <= 5; i++ {
_, err := promptStore.SaveUserPromptWithMatches(ctx, "session-1", i, "prompt", 1)
require.NoError(t, err)
}
svc.RunNow(ctx)
time.Sleep(300 * time.Millisecond)
var count int64
store.GetDB().WithContext(ctx).Model(&gormdb.UserPrompt{}).Count(&count)
assert.Equal(t, int64(5), count, "no prompts should be deleted when none are expired")
}
// ---- Stats race safety ----
func TestStats_ConcurrentAccess_NoRace(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = svc.Stats()
}()
}
wg.Wait()
}
// ---- RunNow concurrent safety ----
func TestRunNow_ConcurrentCalls_NoRace(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
ctx := context.Background()
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
svc.RunNow(ctx)
}()
}
wg.Wait()
time.Sleep(500 * time.Millisecond)
}
// ---- lastRunDuration is populated ----
func TestRunNow_LastRunDuration_IsPopulated(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
svc.RunNow(context.Background())
time.Sleep(300 * time.Millisecond)
svc.mu.Lock()
dur := svc.lastRunDuration
svc.mu.Unlock()
assert.Greater(t, int64(dur), int64(0), "lastRunDuration should be set after a maintenance run")
}
func TestStats_LastDurationMs_IsPopulated(t *testing.T) {
svc, _, _, _, cleanup := testSetup(t, defaultCfg())
defer cleanup()
svc.RunNow(context.Background())
time.Sleep(300 * time.Millisecond)
stats := svc.Stats()
// The value is int64 milliseconds; it might be 0 for very fast runs — just verify the key exists.
_, ok := stats["last_duration_ms"]
assert.True(t, ok, "stats must contain last_duration_ms key")
}
// ---- Batch deletion boundary ----
func TestRunNow_RetentionDays_BatchDeletion_MoreThan100Rows(t *testing.T) {
cfg := defaultCfg()
cfg.ObservationRetentionDays = 1
svc, store, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// Insert 150 expired observations (forces 2 batches of 100).
for i := 0; i < 150; i++ {
obs := &gormdb.Observation{
SDKSessionID: "session-old",
Project: "proj",
Type: models.ObsTypeDiscovery,
CreatedAt: "2000-01-01T00:00:00Z",
CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(),
Scope: models.ScopeProject,
ImportanceScore: 1.0,
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error)
}
svc.RunNow(ctx)
time.Sleep(500 * time.Millisecond)
var remaining int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Count(&remaining)
assert.Equal(t, int64(0), remaining, "all 150 expired observations should be deleted in batches")
svc.mu.Lock()
cleaned := svc.totalCleanedObs
svc.mu.Unlock()
assert.Equal(t, int64(150), cleaned)
}
func TestRunNow_CleanupStale_BatchDeletion_MoreThan100Rows(t *testing.T) {
cfg := defaultCfg()
cfg.CleanupStaleObservations = true
svc, store, _, _, cleanup := testSetup(t, cfg)
defer cleanup()
ctx := context.Background()
// Insert 120 superseded observations.
for i := 0; i < 120; i++ {
obs := &gormdb.Observation{
SDKSessionID: "session-stale",
Project: "proj",
Type: models.ObsTypeDiscovery,
CreatedAt: time.Now().Format(time.RFC3339),
CreatedAtEpoch: time.Now().Unix(),
Scope: models.ScopeProject,
ImportanceScore: 1.0,
IsSuperseded: 1,
}
require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error)
}
svc.RunNow(ctx)
time.Sleep(500 * time.Millisecond)
var remaining int64
store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("is_superseded = ?", 1).Count(&remaining)
assert.Equal(t, int64(0), remaining, "all 120 stale observations should be deleted in batches")
}
+9 -3
View File
@@ -532,13 +532,20 @@ func (u *Updater) replaceBinaries(extractDir string) error {
func (u *Updater) getInstallDirectories() []string {
dirs := []string{u.installDir}
// Also check cache directories where Claude Code looks for plugins
home, err := os.UserHomeDir()
if err != nil {
return dirs
}
// Look for cache directories under ~/.claude/plugins/cache/claude-mnemonic/claude-mnemonic/
// Primary stable binary location (survives Claude Code updates)
stableBin := filepath.Join(home, ".claude-mnemonic", "bin")
if stableBin != u.installDir {
if _, err := os.Stat(stableBin); err == nil {
dirs = append(dirs, stableBin)
}
}
// Also check cache directories where Claude Code looks for plugins
cacheBase := filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic")
entries, err := os.ReadDir(cacheBase)
if err != nil {
@@ -548,7 +555,6 @@ func (u *Updater) getInstallDirectories() []string {
for _, entry := range entries {
if entry.IsDir() {
cacheDir := filepath.Join(cacheBase, entry.Name())
// Only add if it's different from installDir and contains a worker binary
if cacheDir != u.installDir {
workerPath := filepath.Join(cacheDir, "worker")
if _, err := os.Stat(workerPath); err == nil {
File diff suppressed because it is too large Load Diff
+416
View File
@@ -0,0 +1,416 @@
//go:build fts5
package watcher
import (
"context"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// waitForCondition polls fn every 10ms until it returns true or timeout expires.
func waitForCondition(t *testing.T, timeout time.Duration, fn func() bool) bool {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if fn() {
return true
}
time.Sleep(10 * time.Millisecond)
}
return false
}
// TestNew_CreatesWatcherWithCorrectFields verifies New initialises all fields correctly.
func TestNew_CreatesWatcherWithCorrectFields(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
called := false
cb := func() { called = true }
w, err := New(target, cb)
require.NoError(t, err)
require.NotNil(t, w)
defer w.Stop() //nolint:errcheck
assert.Equal(t, target, w.targetPath)
assert.Equal(t, dir, w.parentPath)
assert.Equal(t, 100*time.Millisecond, w.debounce)
assert.NotNil(t, w.watcher)
assert.NotNil(t, w.ctx)
assert.NotNil(t, w.cancel)
assert.False(t, w.running)
assert.False(t, called, "callback must not be invoked on creation")
}
// TestNew_NilCallback is valid — handleDeletion guards for nil onDelete.
func TestNew_NilCallback(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, nil)
require.NoError(t, err)
require.NotNil(t, w)
defer w.Stop() //nolint:errcheck
assert.Nil(t, w.onDelete)
}
// TestStart_SetsRunningTrue verifies Start transitions running to true.
func TestStart_SetsRunningTrue(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
err = w.Start()
require.NoError(t, err)
w.mu.Lock()
running := w.running
w.mu.Unlock()
assert.True(t, running)
}
// TestStart_Idempotent verifies calling Start twice does not panic or return error.
func TestStart_Idempotent(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
require.NoError(t, w.Start(), "second Start must be a no-op without error")
// Still only one goroutine running — running flag is still true.
w.mu.Lock()
running := w.running
w.mu.Unlock()
assert.True(t, running)
}
// TestStop_SetsRunningFalse verifies Stop transitions running to false.
func TestStop_SetsRunningFalse(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
require.NoError(t, w.Start())
require.NoError(t, w.Stop())
w.mu.Lock()
running := w.running
w.mu.Unlock()
assert.False(t, running)
}
// TestStop_Idempotent verifies calling Stop when not running returns nil.
func TestStop_Idempotent(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
// Never started — Stop must be a no-op.
assert.NoError(t, w.Stop())
// Second stop after the first no-op must also succeed.
assert.NoError(t, w.Stop())
}
// TestStop_WithoutStart verifies Stop on an unstarted watcher is safe.
func TestStop_WithoutStart(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
err = w.Stop()
assert.NoError(t, err)
}
// TestTargetDeletion_CallbackFired verifies that deleting the target file triggers onDelete.
func TestTargetDeletion_CallbackFired(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
// Create the target file so the parent watch is real.
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
// Delete the target file.
require.NoError(t, os.Remove(target))
// Wait up to 1 second for the debounced callback (debounce=100ms).
fired := waitForCondition(t, 1*time.Second, func() bool {
return atomic.LoadInt32(&callCount) > 0
})
assert.True(t, fired, "onDelete callback not called after target deletion")
}
// TestTargetDeletion_CallbackCalledOnce verifies debounce suppresses duplicate events.
func TestTargetDeletion_CallbackCalledOnce(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
require.NoError(t, os.Remove(target))
// Wait for callback to fire.
waitForCondition(t, 1*time.Second, func() bool {
return atomic.LoadInt32(&callCount) > 0
})
// Wait an extra debounce window to confirm no second call arrives.
time.Sleep(300 * time.Millisecond)
assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "callback fired more than once for a single deletion")
}
// TestTargetRecreation_CancelsCallback verifies that recreating the target before the
// debounce fires suppresses the onDelete callback.
func TestTargetRecreation_CancelsCallback(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
// Use a longer debounce so we can recreate before it fires.
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
// Override debounce to give us a larger window.
w.debounce = 300 * time.Millisecond
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
// Delete then immediately recreate within the debounce window.
require.NoError(t, os.Remove(target))
time.Sleep(20 * time.Millisecond) // ensure delete event is processed
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
// Wait past full debounce period to confirm callback was cancelled.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "callback should have been cancelled by recreation")
}
// TestParentDirectoryDeletion_CallbackFired verifies that deleting the parent directory
// triggers the onDelete callback.
func TestParentDirectoryDeletion_CallbackFired(t *testing.T) {
// Create a nested structure: base/sub/db.sqlite so we can remove sub
// without losing t.TempDir (which is base).
base := t.TempDir()
sub := filepath.Join(base, "sub")
require.NoError(t, os.Mkdir(sub, 0o755))
target := filepath.Join(sub, "db.sqlite")
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
// Remove parent directory entirely.
require.NoError(t, os.RemoveAll(sub))
fired := waitForCondition(t, 1500*time.Millisecond, func() bool {
return atomic.LoadInt32(&callCount) > 0
})
assert.True(t, fired, "onDelete callback not called after parent directory deletion")
}
// TestAddWatch_NonExistentParent verifies addWatch returns an error when parent is absent.
func TestAddWatch_NonExistentParent(t *testing.T) {
// Point watcher at a path whose parent definitely does not exist.
nonExistent := filepath.Join(t.TempDir(), "missing", "db.sqlite")
w, err := New(nonExistent, func() {})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
err = w.addWatch()
assert.Error(t, err, "addWatch must fail when parent directory does not exist")
}
// TestContextCancellation_StopsWatchLoop verifies the watchLoop exits when Stop is called.
func TestContextCancellation_StopsWatchLoop(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
require.NoError(t, w.Start())
// Stop cancels the context; the goroutine should exit cleanly.
require.NoError(t, w.Stop())
// Give the goroutine a moment to exit — then verify running is false.
time.Sleep(50 * time.Millisecond)
w.mu.Lock()
running := w.running
w.mu.Unlock()
assert.False(t, running)
}
// TestParentDirRecreation_ReEstablishesWatch verifies that recreating the parent after
// deletion allows subsequent target-deletion events to fire the callback.
func TestParentDirRecreation_ReEstablishesWatch(t *testing.T) {
base := t.TempDir()
sub := filepath.Join(base, "sub")
require.NoError(t, os.Mkdir(sub, 0o755))
target := filepath.Join(sub, "db.sqlite")
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
require.NoError(t, w.Start())
// Remove the parent.
require.NoError(t, os.RemoveAll(sub))
// Wait for first callback.
fired := waitForCondition(t, 1500*time.Millisecond, func() bool {
return atomic.LoadInt32(&callCount) > 0
})
require.True(t, fired, "first deletion callback must fire")
firstCount := atomic.LoadInt32(&callCount)
// Recreate parent and target — re-established watch should allow a second callback.
require.NoError(t, os.Mkdir(sub, 0o755))
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
// Wait for handleDeletion's goroutine to attempt re-adding the watch (500ms sleep inside).
time.Sleep(700 * time.Millisecond)
// Now delete the target again.
require.NoError(t, os.Remove(target))
// We only assert the first callback fired; the re-watch is best-effort and
// OS-timing-dependent, so we don't hard-assert a second callback.
assert.GreaterOrEqual(t, atomic.LoadInt32(&callCount), firstCount, "call count must not decrease")
}
// TestConcurrentStartStop verifies that concurrent Start/Stop calls do not race or panic.
func TestConcurrentStartStop(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
w, err := New(target, func() {})
require.NoError(t, err)
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// Launch goroutines that repeatedly start/stop the watcher.
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
_ = w.Start()
time.Sleep(5 * time.Millisecond)
_ = w.Stop()
time.Sleep(5 * time.Millisecond)
}
}
}()
}
wg.Wait()
// No panic = pass. Final state: running should be consistent (we don't assert
// a specific value since Stop may have won last).
}
// TestDebounceField_DefaultValue asserts the default debounce is 100ms.
func TestDebounceField_DefaultValue(t *testing.T) {
dir := t.TempDir()
w, err := New(filepath.Join(dir, "x"), func() {})
require.NoError(t, err)
defer w.Stop() //nolint:errcheck
assert.Equal(t, 100*time.Millisecond, w.debounce)
}
// TestCallbackNotCalledWhenStopped verifies that if we Stop before the debounce fires,
// the callback is not invoked after Stop (context cancel exits the watchLoop).
func TestCallbackNotCalledWhenStopped(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "db.sqlite")
require.NoError(t, os.WriteFile(target, []byte("data"), 0o644))
var callCount int32
w, err := New(target, func() {
atomic.AddInt32(&callCount, 1)
})
require.NoError(t, err)
w.debounce = 500 * time.Millisecond // wide window
require.NoError(t, w.Start())
// Delete file — debounce timer is now running (500ms).
require.NoError(t, os.Remove(target))
time.Sleep(20 * time.Millisecond) // let event propagate
// Stop before timer fires — context is cancelled, watchLoop exits.
require.NoError(t, w.Stop())
// Wait past the debounce window; the AfterFunc may still fire (it's not
// tied to the context), but the watcher is stopped. We assert the loop
// itself exited cleanly.
time.Sleep(700 * time.Millisecond)
// The AfterFunc timer fires outside the watchLoop — callback may or may not
// have fired depending on OS scheduling. We assert no panic occurred.
// The important invariant: running is false.
w.mu.Lock()
running := w.running
w.mu.Unlock()
assert.False(t, running)
}
+42 -23
View File
@@ -8,6 +8,7 @@ import (
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/go-chi/chi/v5"
@@ -366,9 +367,9 @@ func NewService(version string) (*Service, error) {
router := chi.NewRouter()
sseBroadcaster := sse.NewBroadcaster()
// Determine install directory (plugin location)
// Determine install directory (stable binary location, survives Claude Code updates)
homeDir, _ := os.UserHomeDir()
installDir := fmt.Sprintf("%s/.claude/plugins/marketplaces/claude-mnemonic", homeDir)
installDir := fmt.Sprintf("%s/.claude-mnemonic/bin", homeDir)
// Create rate limiter with generous limits (100 req/sec, burst of 200)
// These limits are per-client and allow for intensive CLI usage
@@ -736,10 +737,14 @@ func (s *Service) reinitializeDatabase() {
log.Info().Msg("Query expansion reconnected after reinit")
}
// Close old reranker if exists
// Close old vector client and reranker before swapping
s.initMu.RLock()
oldVectorClient := s.vectorClient
oldReranker := s.reranker
s.initMu.RUnlock()
if oldVectorClient != nil {
_ = oldVectorClient.Close()
}
if oldReranker != nil {
_ = oldReranker.Close()
}
@@ -815,8 +820,11 @@ func (s *Service) reloadConfig() {
// Give SSE clients a moment to receive the message
time.Sleep(100 * time.Millisecond)
// Exit cleanly - hooks will restart us with new config
os.Exit(0)
// Send SIGTERM to self for graceful shutdown (hooks will restart us)
p, err := os.FindProcess(os.Getpid())
if err == nil {
_ = p.Signal(syscall.SIGTERM)
}
}
// setInitError records an initialization error.
@@ -1592,15 +1600,17 @@ func (s *Service) processQueue() {
ticker := time.NewTicker(QueueProcessInterval)
defer ticker.Stop()
s.initMu.RLock()
notify := s.sessionManager.ProcessNotify
s.initMu.RUnlock()
for {
select {
case <-s.ctx.Done():
return
case <-s.sessionManager.ProcessNotify:
// Immediate processing when observation is queued
case <-notify:
s.processAllSessions()
case <-ticker.C:
// Fallback periodic processing
s.processAllSessions()
}
}
@@ -1610,31 +1620,36 @@ func (s *Service) processQueue() {
// Messages are processed in parallel using goroutines, with concurrency
// limited by a channel-based semaphore.
func (s *Service) processAllSessions() {
// Get all sessions with pending messages
sessions := s.sessionManager.GetAllSessions()
s.initMu.RLock()
mgr := s.sessionManager
proc := s.processor
s.initMu.RUnlock()
if mgr == nil || proc == nil {
return
}
sessions := mgr.GetAllSessions()
var wg sync.WaitGroup
sem := make(chan struct{}, MaxConcurrentProcessing)
for _, sess := range sessions {
// Get pending messages
messages := s.sessionManager.DrainMessages(sess.SessionDBID)
messages := mgr.DrainMessages(sess.SessionDBID)
if len(messages) == 0 {
continue
}
// Process each message in a goroutine with semaphore
for _, msg := range messages {
wg.Add(1)
sem <- struct{}{} // Acquire semaphore slot
sem <- struct{}{}
go func(sess *session.ActiveSession, msg session.PendingMessage) {
defer wg.Done()
defer func() { <-sem }() // Release semaphore slot
defer func() { <-sem }()
switch msg.Type {
case session.MessageTypeObservation:
if msg.Observation != nil {
err := s.processor.ProcessObservation(
err := proc.ProcessObservation(
s.ctx,
sess.SDKSessionID,
sess.Project,
@@ -1653,7 +1668,7 @@ func (s *Service) processAllSessions() {
case session.MessageTypeSummarize:
if msg.Summarize != nil {
err := s.processor.ProcessSummary(
err := proc.ProcessSummary(
s.ctx,
sess.SessionDBID,
sess.SDKSessionID,
@@ -1667,18 +1682,15 @@ func (s *Service) processAllSessions() {
Int64("sessionId", sess.SessionDBID).
Msg("Failed to process summary")
}
// Delete session after summary
s.sessionManager.DeleteSession(sess.SessionDBID)
mgr.DeleteSession(sess.SessionDBID)
}
}
}(sess, msg)
}
}
// Wait for all goroutines to complete
wg.Wait()
// Broadcast status after processing
s.broadcastProcessingStatus()
}
@@ -1787,8 +1799,15 @@ func (s *Service) Shutdown(ctx context.Context) error {
// broadcastProcessingStatus broadcasts the current processing status.
func (s *Service) broadcastProcessingStatus() {
isProcessing := s.sessionManager.IsAnySessionProcessing()
queueDepth := s.sessionManager.GetTotalQueueDepth()
s.initMu.RLock()
mgr := s.sessionManager
s.initMu.RUnlock()
if mgr == nil {
return
}
isProcessing := mgr.IsAnySessionProcessing()
queueDepth := mgr.GetTotalQueueDepth()
s.sseBroadcaster.Broadcast(map[string]any{
"type": "processing_status",