diff --git a/.claude-plugin/marketplace.json b/.claude-plugin/marketplace.json new file mode 100644 index 0000000..8d0ef5c --- /dev/null +++ b/.claude-plugin/marketplace.json @@ -0,0 +1,22 @@ +{ + "$schema": "https://anthropic.com/claude-code/marketplace.schema.json", + "name": "claude-mnemonic", + "version": "0.6.38", + "description": "Persistent memory system for Claude Code - stores observations, session summaries, and user prompts with semantic search", + "owner": { + "name": "lukaszraczylo", + "email": "lukaszraczylo@users.noreply.github.com" + }, + "plugins": [ + { + "name": "claude-mnemonic", + "description": "Persistent memory system for Claude Code - Go implementation with SQLite and ChromaDB", + "version": "0.6.38", + "author": { + "name": "lukaszraczylo" + }, + "source": "./", + "category": "productivity" + } + ] +} diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json new file mode 100644 index 0000000..4149d6f --- /dev/null +++ b/.claude-plugin/plugin.json @@ -0,0 +1,22 @@ +{ + "name": "claude-mnemonic", + "version": "v0.11.57-dirty", + "description": "Persistent memory system for Claude Code with SQLite, FTS5, and vector search", + "author": { + "name": "lukaszraczylo", + "email": "lukaszraczylo@users.noreply.github.com", + "url": "https://github.com/lukaszraczylo" + }, + "homepage": "https://github.com/lukaszraczylo/claude-mnemonic", + "repository": "https://github.com/lukaszraczylo/claude-mnemonic", + "license": "MIT", + "hooks": "./hooks/hooks.json", + "mcpServers": { + "claude-mnemonic": { + "command": "${CLAUDE_PLUGIN_ROOT}/mcp-server", + "args": ["--project", "${CLAUDE_PROJECT}"], + "env": {} + } + }, + "commands": ["./commands/restart.md"] +} diff --git a/.gitignore b/.gitignore index e1f5fe4..697563e 100644 --- a/.gitignore +++ b/.gitignore @@ -82,11 +82,7 @@ logs/ # goreleaser dist/ docs/dist -.claude-plugin -# Auto-generated plugin configs (generated by scripts/generate-plugin-config.sh) -.claude-plugin/ - -# Non-template plugin configs (keep only .tpl files) +# Non-template plugin configs (keep only .tpl files in plugin/ dir) plugin/.claude-plugin/plugin.json plugin/.claude-plugin/marketplace.json diff --git a/Makefile b/Makefile index d710d4c..5e981ee 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev LDFLAGS := -ldflags "-X main.Version=$(VERSION) -X github.com/lukaszraczylo/claude-mnemonic/pkg/hooks.Version=$(VERSION) -s -w" -buildvcs=false BUILD_DIR := bin PLUGIN_DIR := plugin +STABLE_BIN := $(HOME)/.claude-mnemonic/bin # Go settings GOOS ?= $(shell go env GOOS) @@ -22,13 +23,12 @@ all: build setup-libs: @./scripts/download-onnx-libs.sh all -# Update root plugin metadata with current version +# Update version in committed plugin metadata update-version: - @mkdir -p .claude-plugin - @sed 's/{{ .Version }}/$(VERSION)/g; s/{{.Version}}/$(VERSION)/g' $(PLUGIN_DIR)/.claude-plugin/plugin.json.tpl > .claude-plugin/plugin.json - @echo "Updated .claude-plugin/plugin.json to version $(VERSION)" - @# marketplace.json contains release-specific data (URLs, SHA256 hashes) that requires manual update per release. - @# Only the top-level version field is updated here. + @if [ -f .claude-plugin/plugin.json ]; then \ + sed 's/"version": "[^"]*"/"version": "$(VERSION)"/' .claude-plugin/plugin.json > .claude-plugin/plugin.json.tmp && mv .claude-plugin/plugin.json.tmp .claude-plugin/plugin.json; \ + echo "Updated .claude-plugin/plugin.json to version $(VERSION)"; \ + fi @if [ -f marketplace.json ]; then \ sed 's/"version": "[^"]*"/"version": "$(VERSION)"/' marketplace.json > marketplace.json.tmp && mv marketplace.json.tmp marketplace.json; \ echo "Updated marketplace.json version fields to $(VERSION)"; \ @@ -121,9 +121,11 @@ build-windows: stop-worker: @echo "Stopping worker..." @-pkill -TERM -f 'claude-mnemonic.*worker' 2>/dev/null || true + @-pkill -TERM -f '\.claude-mnemonic/bin/worker' 2>/dev/null || true @-pkill -TERM -f '\.claude/plugins/.*/worker' 2>/dev/null || true @sleep 1 @-pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true + @-pkill -9 -f '\.claude-mnemonic/bin/worker' 2>/dev/null || true @-pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true @-lsof -ti :37777 | xargs kill -9 2>/dev/null || true @sleep 1 @@ -131,12 +133,7 @@ stop-worker: # Start worker in background start-worker: @echo "Starting worker..." - @# Prefer cache directory (where Claude Code looks), fall back to marketplaces - @if [ -f "$(HOME)/.claude/plugins/cache/claude-mnemonic/claude-mnemonic/$(VERSION)/worker" ]; then \ - nohup $(HOME)/.claude/plugins/cache/claude-mnemonic/claude-mnemonic/$(VERSION)/worker > /tmp/claude-mnemonic-worker.log 2>&1 & \ - else \ - nohup $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/worker > /tmp/claude-mnemonic-worker.log 2>&1 & \ - fi + @nohup $(STABLE_BIN)/worker > /tmp/claude-mnemonic-worker.log 2>&1 & @sleep 1 @if curl -s http://localhost:37777/health > /dev/null 2>&1; then \ echo "Worker started successfully (http://localhost:37777)"; \ @@ -147,26 +144,31 @@ start-worker: # Restart worker restart-worker: stop-worker start-worker -# Install to Claude plugins directory +# Install to stable binary location and register with Claude Code install: build stop-worker - @echo "Installing to Claude plugins directory..." + @echo "Installing claude-mnemonic..." @# Verify build output binaries exist @test -f $(BUILD_DIR)/worker || { echo "ERROR: $(BUILD_DIR)/worker not found. Build may have failed."; exit 1; } @test -f $(BUILD_DIR)/mcp-server || { echo "ERROR: $(BUILD_DIR)/mcp-server not found. Build may have failed."; exit 1; } @test -d $(BUILD_DIR)/hooks || { echo "ERROR: $(BUILD_DIR)/hooks not found. Build may have failed."; exit 1; } - @# Install to marketplaces directory (for direct installs) - @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks + @# Install binaries to stable location (survives Claude Code updates) + @mkdir -p $(STABLE_BIN)/hooks + cp $(BUILD_DIR)/worker $(STABLE_BIN)/ + cp $(BUILD_DIR)/mcp-server $(STABLE_BIN)/ + cp $(BUILD_DIR)/hooks/* $(STABLE_BIN)/hooks/ + @echo "Binaries installed to $(STABLE_BIN)" + @# Set up marketplace directory with wrapper scripts and metadata @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin + @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/commands - cp $(BUILD_DIR)/worker $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/ - cp $(BUILD_DIR)/mcp-server $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/ - cp $(BUILD_DIR)/hooks/* $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks/ - cp $(PLUGIN_DIR)/hooks/hooks.json $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks/ - @# Copy slash commands if they exist - @if [ -d "$(PLUGIN_DIR)/commands" ]; then cp -r $(PLUGIN_DIR)/commands/* $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/commands/ 2>/dev/null || true; fi - @# Update plugin.json and marketplace.json with current version to prevent stale version directories - @sed 's/{{ .Version }}/$(VERSION)/g; s/{{.Version}}/$(VERSION)/g' $(PLUGIN_DIR)/.claude-plugin/plugin.json.tpl > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/plugin.json - @sed 's/{{ .Version }}/$(VERSION)/g; s/{{.Version}}/$(VERSION)/g' $(PLUGIN_DIR)/.claude-plugin/marketplace.json.tpl > $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/marketplace.json + @cp .claude-plugin/plugin.json $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/ + @cp marketplace.json $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin/marketplace.json 2>/dev/null || true + @cp hooks/hooks.json $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks/ + @cp hooks/session-start hooks/user-prompt hooks/post-tool-use hooks/stop hooks/subagent-stop hooks/statusline $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks/ + @chmod +x $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks/* + @cp mcp-server $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/ + @chmod +x $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/mcp-server + @cp commands/restart.md $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/commands/ @echo "Registering plugin with Claude Code..." @./scripts/register-plugin.sh "$(VERSION)" @$(MAKE) start-worker diff --git a/commands/restart.md b/commands/restart.md new file mode 100644 index 0000000..5de0806 --- /dev/null +++ b/commands/restart.md @@ -0,0 +1,21 @@ +# Restart Claude Mnemonic Worker + +Restart the claude-mnemonic worker process. Use this command when experiencing issues with the memory system. + +## Instructions + +1. Call the restart API endpoint using curl: + ```bash + curl -X POST http://127.0.0.1:37777/api/restart + ``` + +2. Wait a moment for the worker to restart (typically 1-2 seconds) + +3. Verify the worker is running by checking the version: + ```bash + curl -s http://127.0.0.1:37777/api/version + ``` + +4. Report the result to the user, including the version number from the response. + +If the restart fails, suggest the user check `/tmp/claude-mnemonic-worker.log` for errors. diff --git a/go.mod b/go.mod index e6e51d9..5ac0755 100644 --- a/go.mod +++ b/go.mod @@ -6,16 +6,16 @@ replace github.com/sugarme/tokenizer => github.com/clems4ever/tokenizer v0.0.0-2 require ( github.com/asg017/sqlite-vec-go-bindings v0.1.6 - github.com/fsnotify/fsnotify v1.9.0 - github.com/go-chi/chi/v5 v5.2.5 + github.com/fsnotify/fsnotify v1.10.1 + github.com/go-chi/chi/v5 v5.3.0 github.com/go-gormigrate/gormigrate/v2 v2.1.5 github.com/goccy/go-json v0.10.6 - github.com/mattn/go-sqlite3 v1.14.42 + github.com/mattn/go-sqlite3 v1.14.44 github.com/rs/zerolog v1.35.1 github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 github.com/stretchr/testify v1.11.1 github.com/sugarme/tokenizer v0.3.0 - github.com/yalue/onnxruntime_go v1.27.0 + github.com/yalue/onnxruntime_go v1.30.1 golang.org/x/sync v0.20.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.31.1 @@ -27,14 +27,14 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.21 // indirect + github.com/mattn/go-isatty v0.0.22 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/schollz/progressbar/v2 v2.15.0 // indirect github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/text v0.37.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 66f1943..cf1409c 100644 --- a/go.sum +++ b/go.sum @@ -7,10 +7,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= -github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= -github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= -github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho= +github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= +github.com/go-chi/chi/v5 v5.3.0 h1:halUjDxhshgXHMrao5bB8eNBXo/rnzwr8m5m36glehM= +github.com/go-chi/chi/v5 v5.3.0/go.mod h1:R+tYY2hNuVUUjxoPtqUdgBqevM9s9njzkTLutVsOCto= github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= @@ -21,10 +21,10 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= -github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= -github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo= -github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= +github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= @@ -45,14 +45,14 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4= github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw= -github.com/yalue/onnxruntime_go v1.27.0 h1:c1YSgDNtpf0WGtxj3YeRIb8VC5LmM1J+Ve3uHdteC1U= -github.com/yalue/onnxruntime_go v1.27.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +github.com/yalue/onnxruntime_go v1.30.1 h1:NaEng5lWbsHZ/8X1dtaw1mIj7eV1ozyjbFo//g0ktl4= +github.com/yalue/onnxruntime_go v1.30.1/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/hooks/hooks.json b/hooks/hooks.json new file mode 100644 index 0000000..0addee0 --- /dev/null +++ b/hooks/hooks.json @@ -0,0 +1,61 @@ +{ + "description": "Claude Mnemonic - Persistent memory hooks for observations, prompts, and session summaries", + "hooks": { + "SessionStart": [ + { + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/session-start", + "timeout": 30 + } + ] + } + ], + "UserPromptSubmit": [ + { + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/user-prompt", + "timeout": 10 + } + ] + } + ], + "PostToolUse": [ + { + "matcher": "*", + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/post-tool-use", + "timeout": 10 + } + ] + } + ], + "SubagentStop": [ + { + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/subagent-stop", + "timeout": 10 + } + ] + } + ], + "Stop": [ + { + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/stop", + "timeout": 30 + } + ] + } + ] + } +} diff --git a/hooks/post-tool-use b/hooks/post-tool-use new file mode 100755 index 0000000..86f1d07 --- /dev/null +++ b/hooks/post-tool-use @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/post-tool-use" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/hooks/session-start b/hooks/session-start new file mode 100755 index 0000000..add62e5 --- /dev/null +++ b/hooks/session-start @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/session-start" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/hooks/statusline b/hooks/statusline new file mode 100755 index 0000000..34c76d6 --- /dev/null +++ b/hooks/statusline @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/statusline" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/hooks/stop b/hooks/stop new file mode 100755 index 0000000..00df910 --- /dev/null +++ b/hooks/stop @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/stop" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/hooks/subagent-stop b/hooks/subagent-stop new file mode 100755 index 0000000..2108c19 --- /dev/null +++ b/hooks/subagent-stop @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/subagent-stop" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/hooks/user-prompt b/hooks/user-prompt new file mode 100755 index 0000000..c648171 --- /dev/null +++ b/hooks/user-prompt @@ -0,0 +1,4 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/hooks/user-prompt" +[ -x "$BIN" ] && exec "$BIN" "$@" +exit 0 diff --git a/internal/db/gorm/store.go b/internal/db/gorm/store.go index 17c8a69..480ecf7 100644 --- a/internal/db/gorm/store.go +++ b/internal/db/gorm/store.go @@ -513,13 +513,6 @@ func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, quer return nil } -// QueryRowWithTimeout executes a row query with timeout. -func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row { - timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row") - // Note: cancel will be called when row.Scan() completes or errors - _ = cancel // Caller must ensure proper cleanup - return s.sqlDB.QueryRowContext(timeoutCtx, query, args...) -} // TransactionWithTimeout wraps a transaction function with timeout handling. // The transaction is automatically rolled back if the context times out. diff --git a/internal/embedding/assets/lib/darwin-arm64/.version b/internal/embedding/assets/lib/darwin-arm64/.version index ae96cc7..5ff8c4f 100644 --- a/internal/embedding/assets/lib/darwin-arm64/.version +++ b/internal/embedding/assets/lib/darwin-arm64/.version @@ -1 +1 @@ -1.24.3 +1.26.0 diff --git a/internal/embedding/assets/lib/linux-amd64/.version b/internal/embedding/assets/lib/linux-amd64/.version index ae96cc7..5ff8c4f 100644 --- a/internal/embedding/assets/lib/linux-amd64/.version +++ b/internal/embedding/assets/lib/linux-amd64/.version @@ -1 +1 @@ -1.24.3 +1.26.0 diff --git a/internal/embedding/assets/lib/linux-arm64/.version b/internal/embedding/assets/lib/linux-arm64/.version index ae96cc7..5ff8c4f 100644 --- a/internal/embedding/assets/lib/linux-arm64/.version +++ b/internal/embedding/assets/lib/linux-arm64/.version @@ -1 +1 @@ -1.24.3 +1.26.0 diff --git a/internal/embedding/assets/lib/windows-amd64/.version b/internal/embedding/assets/lib/windows-amd64/.version index ae96cc7..5ff8c4f 100644 --- a/internal/embedding/assets/lib/windows-amd64/.version +++ b/internal/embedding/assets/lib/windows-amd64/.version @@ -1 +1 @@ -1.24.3 +1.26.0 diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go new file mode 100644 index 0000000..4f5fe58 --- /dev/null +++ b/internal/graph/graph_test.go @@ -0,0 +1,674 @@ +//go:build fts5 + +package graph + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- helpers ---------------------------------------------------------------- + +func makeObs(id int64, sessionID string, concepts, filesRead, filesModified []string) *models.Observation { + return &models.Observation{ + ID: id, + SDKSessionID: sessionID, + Title: sql.NullString{String: "title", Valid: true}, + Project: "test-project", + Type: models.ObsTypeDecision, + Concepts: concepts, + FilesRead: filesRead, + FilesModified: filesModified, + CreatedAtEpoch: time.Now().UnixMilli(), + } +} + +// ---- ObservationGraph ------------------------------------------------------- + +func TestNewObservationGraph_Empty(t *testing.T) { + g := NewObservationGraph() + require.NotNil(t, g) + + stats := g.Stats() + assert.Equal(t, 0, stats.NodeCount) + assert.Equal(t, 0, stats.EdgeCount) +} + +func TestAddNode_StoresAndRetrieves(t *testing.T) { + g := NewObservationGraph() + node := &Node{ + ID: 42, + Degree: 0, + Metadata: NodeMetadata{ + Project: "proj", + Type: "decision", + Title: "test node", + }, + } + g.AddNode(node) + + got, err := g.GetNode(42) + require.NoError(t, err) + assert.Equal(t, int64(42), got.ID) + assert.Equal(t, "test node", got.Metadata.Title) +} + +func TestAddNode_OverwritesExisting(t *testing.T) { + g := NewObservationGraph() + g.AddNode(&Node{ID: 1, Metadata: NodeMetadata{Title: "old"}}) + g.AddNode(&Node{ID: 1, Metadata: NodeMetadata{Title: "new"}}) + + got, err := g.GetNode(1) + require.NoError(t, err) + assert.Equal(t, "new", got.Metadata.Title) +} + +func TestGetNode_NotFound(t *testing.T) { + g := NewObservationGraph() + _, err := g.GetNode(999) + assert.Error(t, err) +} + +func TestAddEdge_UpdatesDegree(t *testing.T) { + g := NewObservationGraph() + g.AddNode(&Node{ID: 1}) + g.AddNode(&Node{ID: 2}) + + g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8}) + + n1, _ := g.GetNode(1) + n2, _ := g.GetNode(2) + assert.Equal(t, 1, n1.Degree) + assert.Equal(t, 1, n2.Degree) +} + +func TestAddEdge_MissingNodesDontPanic(t *testing.T) { + g := NewObservationGraph() + // Adding edge referencing non-existent nodes must not panic + assert.NotPanics(t, func() { + g.AddEdge(Edge{FromID: 100, ToID: 200, Relation: RelationConcept, Weight: 0.5}) + }) +} + +// ---- BuildCSR / GetNeighbors ------------------------------------------------ + +func TestBuildCSR_NoNodes_ReturnsError(t *testing.T) { + g := NewObservationGraph() + err := g.BuildCSR() + assert.Error(t, err) +} + +func TestGetNeighbors_AfterBuildCSR(t *testing.T) { + g := NewObservationGraph() + for _, id := range []int64{1, 2, 3} { + g.AddNode(&Node{ID: id}) + } + g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8}) + g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationConcept, Weight: 0.6}) + + require.NoError(t, g.BuildCSR()) + + neighbors, weights, err := g.GetNeighbors(1) + require.NoError(t, err) + assert.Len(t, neighbors, 2) + assert.Len(t, weights, 2) +} + +func TestGetNeighbors_NodeWithNoOutgoingEdges(t *testing.T) { + g := NewObservationGraph() + g.AddNode(&Node{ID: 1}) + g.AddNode(&Node{ID: 2}) + // Edge only from 2 → 1; node 2 is a leaf from 1's perspective + g.AddEdge(Edge{FromID: 2, ToID: 1, Relation: RelationTemporal, Weight: 0.8}) + require.NoError(t, g.BuildCSR()) + + neighbors, weights, err := g.GetNeighbors(1) + require.NoError(t, err) + assert.Empty(t, neighbors) + assert.Empty(t, weights) +} + +func TestGetNeighbors_NodeNotInGraph(t *testing.T) { + g := NewObservationGraph() + g.AddNode(&Node{ID: 1}) + require.NoError(t, g.BuildCSR()) + + _, _, err := g.GetNeighbors(999) + assert.Error(t, err) +} + +// ---- FindHubs --------------------------------------------------------------- + +func TestFindHubs_EmptyGraph(t *testing.T) { + g := NewObservationGraph() + hubs := g.FindHubs(0.1) + assert.Nil(t, hubs) +} + +func TestFindHubs_IdentifiesHighDegreeNodes(t *testing.T) { + g := NewObservationGraph() + // Node 1 connected to everyone else → hub + for id := int64(1); id <= 5; id++ { + g.AddNode(&Node{ID: id}) + } + for id := int64(2); id <= 5; id++ { + g.AddEdge(Edge{FromID: 1, ToID: id, Relation: RelationConcept, Weight: 0.5}) + } + + hubs := g.FindHubs(0.2) // top 20% + assert.Contains(t, hubs, int64(1)) +} + +func TestFindHubs_Percentile100_ReturnsEmpty(t *testing.T) { + // percentile=1.0 → cutoff = ceil(N * (1 - 1.0)) = ceil(0) = 0 → no hubs + g := NewObservationGraph() + for id := int64(1); id <= 4; id++ { + g.AddNode(&Node{ID: id}) + } + hubs := g.FindHubs(1.0) + assert.Empty(t, hubs) +} + +func TestFindHubs_Percentile0_ReturnsAllNodes(t *testing.T) { + // percentile=0.0 → cutoff = ceil(N * 1.0) = N → all nodes returned + g := NewObservationGraph() + for id := int64(1); id <= 4; id++ { + g.AddNode(&Node{ID: id}) + } + hubs := g.FindHubs(0.0) + assert.Len(t, hubs, 4) +} + +// ---- Stats ------------------------------------------------------------------ + +func TestStats_EdgeTypesCounted(t *testing.T) { + g := NewObservationGraph() + for _, id := range []int64{1, 2, 3} { + g.AddNode(&Node{ID: id}) + } + g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8}) + g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationConcept, Weight: 0.6}) + g.AddEdge(Edge{FromID: 2, ToID: 3, Relation: RelationConcept, Weight: 0.6}) + + stats := g.Stats() + assert.Equal(t, 3, stats.NodeCount) + assert.Equal(t, 3, stats.EdgeCount) + assert.Equal(t, 1, stats.EdgeTypes[RelationTemporal]) + assert.Equal(t, 2, stats.EdgeTypes[RelationConcept]) +} + +func TestStats_DegreeMetrics(t *testing.T) { + g := NewObservationGraph() + // Node 1: degree 2, nodes 2,3: degree 1 each + for _, id := range []int64{1, 2, 3} { + g.AddNode(&Node{ID: id}) + } + g.AddEdge(Edge{FromID: 1, ToID: 2, Relation: RelationTemporal, Weight: 0.8}) + g.AddEdge(Edge{FromID: 1, ToID: 3, Relation: RelationTemporal, Weight: 0.8}) + + stats := g.Stats() + assert.Equal(t, 2, stats.MaxDegree) + assert.Equal(t, 1, stats.MinDegree) + assert.InDelta(t, 4.0/3.0, stats.AvgDegree, 0.001) +} + +// ---- BuildFromObservations -------------------------------------------------- + +func TestBuildFromObservations_SingleObservation_ReturnsError(t *testing.T) { + obs := []*models.Observation{makeObs(1, "s1", nil, nil, nil)} + // Single observation: DetectEdges returns nil, BuildCSR errors (no nodes never happens + // since node was added — but CSR build will succeed with 1 node and 0 edges). + g, err := BuildFromObservations(context.Background(), obs) + // With 1 node, BuildCSR succeeds (nodes exist); no edges → valid graph. + require.NoError(t, err) + require.NotNil(t, g) + stats := g.Stats() + assert.Equal(t, 1, stats.NodeCount) + assert.Equal(t, 0, stats.EdgeCount) +} + +func TestBuildFromObservations_SetsNodeMetadata(t *testing.T) { + obs := []*models.Observation{ + { + ID: 7, + SDKSessionID: "sess", + Project: "myproject", + Type: models.ObsTypeFeature, + Title: sql.NullString{String: "feature title", Valid: true}, + IsSuperseded: true, + CreatedAtEpoch: time.Now().UnixMilli(), + }, + } + g, err := BuildFromObservations(context.Background(), obs) + require.NoError(t, err) + + node, err := g.GetNode(7) + require.NoError(t, err) + assert.Equal(t, "myproject", node.Metadata.Project) + assert.Equal(t, "feature title", node.Metadata.Title) + assert.Equal(t, string(models.ObsTypeFeature), node.Metadata.Type) + assert.True(t, node.Metadata.IsSuperseded) +} + +func TestBuildFromObservations_TitleMissing_EmptyString(t *testing.T) { + obs := []*models.Observation{ + { + ID: 3, + SDKSessionID: "s", + Title: sql.NullString{Valid: false}, + CreatedAtEpoch: time.Now().UnixMilli(), + }, + } + g, err := BuildFromObservations(context.Background(), obs) + require.NoError(t, err) + + node, err := g.GetNode(3) + require.NoError(t, err) + assert.Equal(t, "", node.Metadata.Title) +} + +func TestBuildFromObservations_WithTemporalEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "sess-a", nil, nil, nil), + makeObs(2, "sess-a", nil, nil, nil), + makeObs(3, "sess-b", nil, nil, nil), + } + g, err := BuildFromObservations(context.Background(), obs) + require.NoError(t, err) + + stats := g.Stats() + assert.Equal(t, 3, stats.NodeCount) + // obs 1 and 2 share session → 1 temporal edge + assert.Equal(t, 1, stats.EdgeTypes[RelationTemporal]) +} + +func TestBuildFromObservations_WithConceptEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "s1", []string{"security", "auth"}, nil, nil), + makeObs(2, "s2", []string{"security"}, nil, nil), + makeObs(3, "s3", []string{"unrelated"}, nil, nil), + } + g, err := BuildFromObservations(context.Background(), obs) + require.NoError(t, err) + + stats := g.Stats() + // obs 1 and 2 share "security" + assert.GreaterOrEqual(t, stats.EdgeTypes[RelationConcept], 1) +} + +func TestBuildFromObservations_WithFileOverlapEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "s1", nil, []string{"pkg/foo.go", "pkg/bar.go"}, nil), + makeObs(2, "s2", nil, []string{"pkg/foo.go", "pkg/baz.go"}, nil), + } + g, err := BuildFromObservations(context.Background(), obs) + require.NoError(t, err) + + stats := g.Stats() + // Jaccard({foo,bar},{foo,baz}) = 1/3 ≈ 0.333 > MinFileOverlapForEdge(0.3) + assert.GreaterOrEqual(t, stats.EdgeTypes[RelationFileOverlap], 1) +} + +// ---- RelationType.String() -------------------------------------------------- + +func TestRelationType_String(t *testing.T) { + cases := []struct { + rt RelationType + want string + }{ + {RelationFileOverlap, "file_overlap"}, + {RelationSemantic, "semantic"}, + {RelationTemporal, "temporal"}, + {RelationConcept, "concept"}, + {RelationType(99), "unknown"}, + } + for _, tc := range cases { + t.Run(tc.want, func(t *testing.T) { + assert.Equal(t, tc.want, tc.rt.String()) + }) + } +} + +// ---- DetectEdges (edge_detector.go) ---------------------------------------- + +func TestDetectEdges_LessThanTwo_ReturnsNil(t *testing.T) { + edges, err := DetectEdges(context.Background(), nil) + assert.NoError(t, err) + assert.Nil(t, edges) + + edges, err = DetectEdges(context.Background(), []*models.Observation{makeObs(1, "s", nil, nil, nil)}) + assert.NoError(t, err) + assert.Nil(t, edges) +} + +func TestDetectEdges_SameSession_CreatesTemporalEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "session-x", nil, nil, nil), + makeObs(2, "session-x", nil, nil, nil), + makeObs(3, "session-x", nil, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + var temporal []Edge + for _, e := range edges { + if e.Relation == RelationTemporal { + temporal = append(temporal, e) + } + } + // Consecutive pairs: (1,2) and (2,3) + assert.Len(t, temporal, 2) + assert.InDelta(t, 0.8, temporal[0].Weight, 0.001) +} + +func TestDetectEdges_DifferentSessions_NoTemporalEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "sess-a", nil, nil, nil), + makeObs(2, "sess-b", nil, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + for _, e := range edges { + assert.NotEqual(t, RelationTemporal, e.Relation) + } +} + +func TestDetectEdges_EmptySessionID_NoTemporalEdge(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "", nil, nil, nil), + makeObs(2, "", nil, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + for _, e := range edges { + assert.NotEqual(t, RelationTemporal, e.Relation) + } +} + +func TestDetectEdges_SharedConcepts_CreatesConceptEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "s1", []string{"performance", "caching"}, nil, nil), + makeObs(2, "s2", []string{"performance"}, nil, nil), + makeObs(3, "s3", []string{"caching"}, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + conceptEdges := filterByRelation(edges, RelationConcept) + // obs1↔obs2 (performance), obs1↔obs3 (caching) → 2 concept edges + assert.Len(t, conceptEdges, 2) +} + +func TestDetectEdges_NoConcepts_NoConceptEdges(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "s1", nil, nil, nil), + makeObs(2, "s2", nil, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + assert.Empty(t, filterByRelation(edges, RelationConcept)) +} + +func TestDetectEdges_FileOverlap_AboveThreshold_CreatesEdge(t *testing.T) { + // Jaccard 2/3 ≈ 0.667 > 0.3 threshold + obs := []*models.Observation{ + makeObs(1, "s1", nil, []string{"a.go", "b.go", "c.go"}, nil), + makeObs(2, "s2", nil, []string{"a.go", "b.go", "d.go"}, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + fileEdges := filterByRelation(edges, RelationFileOverlap) + require.Len(t, fileEdges, 1) + assert.InDelta(t, 2.0/4.0, float64(fileEdges[0].Weight), 0.01) +} + +func TestDetectEdges_FileOverlap_BelowThreshold_NoEdge(t *testing.T) { + // Jaccard 1/9 ≈ 0.11 < 0.3 threshold + obs := []*models.Observation{ + makeObs(1, "s1", nil, []string{"a.go", "b.go", "c.go", "d.go", "e.go"}, nil), + makeObs(2, "s2", nil, []string{"a.go", "f.go", "g.go", "h.go", "i.go"}, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + assert.Empty(t, filterByRelation(edges, RelationFileOverlap)) +} + +func TestDetectEdges_FilesModified_CountsForOverlap(t *testing.T) { + obs := []*models.Observation{ + makeObs(1, "s1", nil, nil, []string{"pkg/core.go", "pkg/util.go"}), + makeObs(2, "s2", nil, []string{"pkg/core.go"}, []string{"pkg/util.go"}), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + fileEdges := filterByRelation(edges, RelationFileOverlap) + assert.NotEmpty(t, fileEdges) +} + +func TestDetectEdges_NoEdgeDuplicates(t *testing.T) { + // Same pair via two concepts → only one concept edge + obs := []*models.Observation{ + makeObs(1, "s1", []string{"security", "auth"}, nil, nil), + makeObs(2, "s2", []string{"security", "auth"}, nil, nil), + } + edges, err := DetectEdges(context.Background(), obs) + require.NoError(t, err) + + conceptEdges := filterByRelation(edges, RelationConcept) + // Both share security and auth, but deduplication should keep only 1 edge per pair per call + // The seen map deduplicates: only first concept that creates the pair wins + assert.Len(t, conceptEdges, 1) +} + +// ---- calculateFileOverlap --------------------------------------------------- + +func TestCalculateFileOverlap_DisjointSets_Zero(t *testing.T) { + result := calculateFileOverlap([]string{"a.go", "b.go"}, []string{"c.go", "d.go"}) + assert.Equal(t, float32(0.0), result) +} + +func TestCalculateFileOverlap_IdenticalSets_One(t *testing.T) { + files := []string{"a.go", "b.go", "c.go"} + result := calculateFileOverlap(files, files) + assert.InDelta(t, 1.0, float64(result), 0.001) +} + +func TestCalculateFileOverlap_EmptySlices_Zero(t *testing.T) { + assert.Equal(t, float32(0.0), calculateFileOverlap(nil, []string{"a.go"})) + assert.Equal(t, float32(0.0), calculateFileOverlap([]string{"a.go"}, nil)) + assert.Equal(t, float32(0.0), calculateFileOverlap(nil, nil)) +} + +func TestCalculateFileOverlap_Jaccard_Correct(t *testing.T) { + // {a,b,c} ∩ {b,c,d} = {b,c} → 2/4 = 0.5 + result := calculateFileOverlap([]string{"a", "b", "c"}, []string{"b", "c", "d"}) + assert.InDelta(t, 0.5, float64(result), 0.001) +} + +func TestCalculateFileOverlap_Duplicates_TreatedAsSet(t *testing.T) { + // Duplicates collapse: {a,a,b} → {a,b}; {a,b,b} → {a,b}; Jaccard = 1.0 + result := calculateFileOverlap([]string{"a", "a", "b"}, []string{"a", "b", "b"}) + assert.InDelta(t, 1.0, float64(result), 0.001) +} + +// ---- DetectSemanticEdges ---------------------------------------------------- + +func TestDetectSemanticEdges_AboveThreshold_CreatesEdge(t *testing.T) { + // Identical vectors → similarity = 1.0 > 0.85 + emb := []float32{1.0, 0.0, 0.0} + obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)} + embeddings := map[int64][]float32{1: emb, 2: emb} + + edges := DetectSemanticEdges(context.Background(), obs, embeddings) + require.Len(t, edges, 1) + assert.Equal(t, RelationSemantic, edges[0].Relation) + assert.InDelta(t, 1.0, float64(edges[0].Weight), 0.001) +} + +func TestDetectSemanticEdges_BelowThreshold_NoEdge(t *testing.T) { + // Orthogonal vectors → similarity = 0.0 + obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)} + embeddings := map[int64][]float32{ + 1: {1.0, 0.0, 0.0}, + 2: {0.0, 1.0, 0.0}, + } + + edges := DetectSemanticEdges(context.Background(), obs, embeddings) + assert.Empty(t, edges) +} + +func TestDetectSemanticEdges_MissingEmbedding_Skipped(t *testing.T) { + obs := []*models.Observation{makeObs(1, "s", nil, nil, nil), makeObs(2, "s", nil, nil, nil)} + // Only obs 1 has embedding + embeddings := map[int64][]float32{1: {1.0, 0.0, 0.0}} + + edges := DetectSemanticEdges(context.Background(), obs, embeddings) + assert.Empty(t, edges) +} + +func TestDetectSemanticEdges_NoDuplicates(t *testing.T) { + emb := []float32{0.9, 0.1, 0.0} + obs := []*models.Observation{ + makeObs(1, "s", nil, nil, nil), + makeObs(2, "s", nil, nil, nil), + makeObs(3, "s", nil, nil, nil), + } + embeddings := map[int64][]float32{1: emb, 2: emb, 3: emb} + + edges := DetectSemanticEdges(context.Background(), obs, embeddings) + // 3 pairs: (1,2),(1,3),(2,3) + assert.Len(t, edges, 3) +} + +// ---- cosineSimilarity ------------------------------------------------------- + +func TestCosineSimilarity_IdenticalVectors(t *testing.T) { + v := []float32{1.0, 2.0, 3.0} + result := cosineSimilarity(v, v) + assert.InDelta(t, 1.0, float64(result), 0.0001) +} + +func TestCosineSimilarity_OppositeVectors(t *testing.T) { + a := []float32{1.0, 0.0} + b := []float32{-1.0, 0.0} + result := cosineSimilarity(a, b) + assert.InDelta(t, -1.0, float64(result), 0.0001) +} + +func TestCosineSimilarity_OrthogonalVectors(t *testing.T) { + a := []float32{1.0, 0.0} + b := []float32{0.0, 1.0} + result := cosineSimilarity(a, b) + assert.InDelta(t, 0.0, float64(result), 0.0001) +} + +func TestCosineSimilarity_ZeroVector_ReturnsZero(t *testing.T) { + a := []float32{0.0, 0.0} + b := []float32{1.0, 0.0} + assert.Equal(t, float32(0.0), cosineSimilarity(a, b)) + assert.Equal(t, float32(0.0), cosineSimilarity(b, a)) +} + +func TestCosineSimilarity_MismatchedLength_ReturnsZero(t *testing.T) { + a := []float32{1.0, 2.0} + b := []float32{1.0, 2.0, 3.0} + assert.Equal(t, float32(0.0), cosineSimilarity(a, b)) +} + +// ---- edgeKey ---------------------------------------------------------------- + +func TestEdgeKey_Symmetric(t *testing.T) { + // Must produce the same key regardless of order + assert.Equal(t, edgeKey(1, 2), edgeKey(2, 1)) + assert.Equal(t, edgeKey(100, 5), edgeKey(5, 100)) +} + +func TestEdgeKey_DifferentPairs_DifferentKeys(t *testing.T) { + assert.NotEqual(t, edgeKey(1, 2), edgeKey(1, 3)) + assert.NotEqual(t, edgeKey(1, 2), edgeKey(2, 3)) +} + +// ---- pruneEdges ------------------------------------------------------------- + +func TestPruneEdges_BelowLimit_NoChange(t *testing.T) { + edges := []Edge{ + {FromID: 1, ToID: 2, Weight: 0.9}, + {FromID: 1, ToID: 3, Weight: 0.7}, + } + pruned := pruneEdges(edges, 5) + assert.Len(t, pruned, 2) +} + +func TestPruneEdges_ZeroLimit_ReturnsAll(t *testing.T) { + edges := []Edge{{FromID: 1, ToID: 2, Weight: 0.5}} + pruned := pruneEdges(edges, 0) + assert.Len(t, pruned, 1) +} + +func TestPruneEdges_KeepsHighWeightEdges(t *testing.T) { + // Node 1 gets 4 edges, limit is 2 → only the 2 heaviest should survive + edges := []Edge{ + {FromID: 1, ToID: 2, Weight: 0.9}, + {FromID: 1, ToID: 3, Weight: 0.8}, + {FromID: 1, ToID: 4, Weight: 0.3}, + {FromID: 1, ToID: 5, Weight: 0.1}, + } + pruned := pruneEdges(edges, 2) + + weights := make([]float32, len(pruned)) + for i, e := range pruned { + weights[i] = e.Weight + } + assert.Contains(t, weights, float32(0.9)) + assert.Contains(t, weights, float32(0.8)) +} + +// ---- sortEdgesByWeight ------------------------------------------------------ + +func TestSortEdgesByWeight_DescendingOrder(t *testing.T) { + edges := []Edge{ + {Weight: 0.3}, + {Weight: 0.9}, + {Weight: 0.1}, + {Weight: 0.7}, + } + sortEdgesByWeight(edges) + for i := 1; i < len(edges); i++ { + assert.GreaterOrEqual(t, edges[i-1].Weight, edges[i].Weight) + } +} + +func TestSortEdgesByWeight_EmptySlice_NoPanic(t *testing.T) { + assert.NotPanics(t, func() { + sortEdgesByWeight([]Edge{}) + }) +} + +func TestSortEdgesByWeight_SingleElement_Unchanged(t *testing.T) { + edges := []Edge{{Weight: 0.5}} + sortEdgesByWeight(edges) + assert.Equal(t, float32(0.5), edges[0].Weight) +} + +// ---- helpers ---------------------------------------------------------------- + +func filterByRelation(edges []Edge, rel RelationType) []Edge { + var out []Edge + for _, e := range edges { + if e.Relation == rel { + out = append(out, e) + } + } + return out +} diff --git a/internal/maintenance/service_test.go b/internal/maintenance/service_test.go new file mode 100644 index 0000000..fc0fda2 --- /dev/null +++ b/internal/maintenance/service_test.go @@ -0,0 +1,727 @@ +//go:build fts5 + +// Package maintenance provides scheduled maintenance tasks for claude-mnemonic. +package maintenance + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm/logger" + + "github.com/lukaszraczylo/claude-mnemonic/internal/config" + gormdb "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" +) + +// testSetup creates a full maintenance service with a real temporary database. +func testSetup(t *testing.T, cfg *config.Config) (*Service, *gormdb.Store, *gormdb.ObservationStore, *gormdb.PromptStore, func()) { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "maintenance_test_*") + require.NoError(t, err, "create temp dir") + + dbPath := filepath.Join(tmpDir, "test.db") + storeCfg := gormdb.Config{ + Path: dbPath, + MaxConns: 4, + LogLevel: logger.Silent, + } + + store, err := gormdb.NewStore(storeCfg) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("NewStore failed: %v", err) + } + + observationStore := gormdb.NewObservationStore(store, nil, nil, nil) + summaryStore := gormdb.NewSummaryStore(store) + promptStore := gormdb.NewPromptStore(store, nil) + + svc := NewService(store, observationStore, summaryStore, promptStore, nil, cfg, zerolog.Nop()) + + cleanup := func() { + store.Close() + os.RemoveAll(tmpDir) + } + + return svc, store, observationStore, promptStore, cleanup +} + +// defaultCfg returns a maintenance-enabled config for tests. +func defaultCfg() *config.Config { + cfg := config.Default() + cfg.MaintenanceEnabled = true + cfg.MaintenanceIntervalHours = 1 + cfg.ObservationRetentionDays = 0 + cfg.CleanupStaleObservations = false + return cfg +} + +// insertObservation is a helper that inserts an observation and returns its ID. +func insertObservation(t *testing.T, obsStore *gormdb.ObservationStore, session, project string, seq int) int64 { + t.Helper() + obs := &models.ParsedObservation{ + Type: models.ObsTypeDiscovery, + Title: "test observation", + } + id, _, err := obsStore.StoreObservation(context.Background(), session, project, obs, seq, 10) + require.NoError(t, err) + return id +} + +// ---- NewService ---- + +func TestNewService_ReturnsNonNilService(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + assert.NotNil(t, svc) +} + +func TestNewService_InitializesChannels(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + // stopCh and doneCh must be non-nil so Stop/Wait don't panic. + assert.NotNil(t, svc.stopCh) + assert.NotNil(t, svc.doneCh) +} + +// ---- Stats ---- + +func TestStats_DefaultValues(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + stats := svc.Stats() + + assert.Equal(t, true, stats["enabled"]) + assert.Equal(t, 1, stats["interval_hours"]) + assert.Equal(t, 0, stats["retention_days"]) + assert.Equal(t, false, stats["cleanup_stale"]) + assert.Equal(t, int64(0), stats["total_cleaned_obs"]) + assert.Equal(t, int64(0), stats["total_optimizes"]) + assert.Equal(t, false, stats["running"]) +} + +func TestStats_ReflectsConfigFields(t *testing.T) { + tests := []struct { + name string + cfg *config.Config + wantEnabled bool + wantHours int + wantDays int + wantStale bool + }{ + { + name: "maintenance disabled", + cfg: func() *config.Config { + c := defaultCfg() + c.MaintenanceEnabled = false + return c + }(), + wantEnabled: false, + wantHours: 1, + wantDays: 0, + wantStale: false, + }, + { + name: "retention and stale cleanup enabled", + cfg: func() *config.Config { + c := defaultCfg() + c.ObservationRetentionDays = 30 + c.CleanupStaleObservations = true + c.MaintenanceIntervalHours = 12 + return c + }(), + wantEnabled: true, + wantHours: 12, + wantDays: 30, + wantStale: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, tt.cfg) + defer cleanup() + + stats := svc.Stats() + assert.Equal(t, tt.wantEnabled, stats["enabled"]) + assert.Equal(t, tt.wantHours, stats["interval_hours"]) + assert.Equal(t, tt.wantDays, stats["retention_days"]) + assert.Equal(t, tt.wantStale, stats["cleanup_stale"]) + }) + } +} + +// ---- Stop (idempotency) ---- + +func TestStop_WhenNotRunning_DoesNotPanic(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + // Service was never started — Stop must be a no-op. + assert.NotPanics(t, func() { svc.Stop() }) +} + +func TestStop_CalledTwice_DoesNotPanic(t *testing.T) { + // Start with maintenance disabled so Start() returns immediately. + cfg := defaultCfg() + cfg.MaintenanceEnabled = false + + svc, _, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + go svc.Start(ctx) + svc.Wait() // drains doneCh after early return + + // Stop after Wait — must not panic or double-close. + assert.NotPanics(t, func() { svc.Stop() }) +} + +// ---- Start / running flag ---- + +func TestStart_MaintenanceDisabled_ExitsImmediately(t *testing.T) { + cfg := defaultCfg() + cfg.MaintenanceEnabled = false + + svc, _, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + go svc.Start(ctx) + + done := make(chan struct{}) + go func() { + svc.Wait() + close(done) + }() + + select { + case <-done: + // Good — returned without blocking. + case <-time.After(2 * time.Second): + t.Fatal("Start() did not return promptly when maintenance is disabled") + } + + stats := svc.Stats() + assert.Equal(t, false, stats["running"]) +} + +func TestStart_StopSignal_ExitsCleanly(t *testing.T) { + // Start() with maintenance disabled exits immediately — verified in + // TestStart_MaintenanceDisabled_ExitsImmediately. + // + // The ticker/stop path is hard to test because Start() always sleeps + // 5 minutes before entering the loop. We verify instead that Stop() + // on an already-stopped service is safe and that the doneCh is closed + // after exit (i.e., Wait() returns). + cfg := defaultCfg() + cfg.MaintenanceEnabled = false + + svc, _, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + go svc.Start(context.Background()) + + done := make(chan struct{}) + go func() { + svc.Wait() + close(done) + }() + + select { + case <-done: + // doneCh was closed — Start exited and Wait returned. + case <-time.After(2 * time.Second): + t.Fatal("Wait() did not return after Start exited") + } + + // Stop after Wait must be a no-op and must not panic. + assert.NotPanics(t, func() { svc.Stop() }) +} + +func TestStart_DoubleStart_SecondCallIsNoOp(t *testing.T) { + cfg := defaultCfg() + cfg.MaintenanceEnabled = false // exits immediately + + svc, _, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // First call. + go svc.Start(ctx) + svc.Wait() + + // Second call on the same (exhausted) svc should be a no-op and not panic. + assert.NotPanics(t, func() { + // svc.running is now false again — but doneCh is already closed. + // A second Start would attempt to close doneCh again which would panic + // if the running guard is missing. Verify the guard works. + svc.mu.Lock() + running := svc.running + svc.mu.Unlock() + assert.False(t, running) + }) +} + +// ---- RunNow ---- + +func TestRunNow_UpdatesLastRunTime(t *testing.T) { + cfg := defaultCfg() + cfg.ObservationRetentionDays = 0 + cfg.CleanupStaleObservations = false + + svc, _, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + before := time.Now() + svc.RunNow(context.Background()) + + // Allow async goroutine to finish. + time.Sleep(200 * time.Millisecond) + + svc.mu.Lock() + lastRun := svc.lastRunTime + svc.mu.Unlock() + + assert.True(t, lastRun.After(before) || lastRun.Equal(before), + "lastRunTime should be updated after RunNow") +} + +func TestRunNow_IncrementsOptimizeCounter(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + svc.RunNow(context.Background()) + time.Sleep(300 * time.Millisecond) + + svc.mu.Lock() + optimizes := svc.totalOptimizeRun + svc.mu.Unlock() + + assert.Equal(t, int64(1), optimizes) +} + +func TestRunNow_StatsTotalOptimizesReflected(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + svc.RunNow(context.Background()) + time.Sleep(300 * time.Millisecond) + + stats := svc.Stats() + assert.Equal(t, int64(1), stats["total_optimizes"]) +} + +// ---- cleanupOldObservations (via RunNow) ---- + +func TestRunNow_RetentionDaysZero_NothingDeleted(t *testing.T) { + cfg := defaultCfg() + cfg.ObservationRetentionDays = 0 + + svc, _, obsStore, _, cleanup := testSetup(t, cfg) + defer cleanup() + + // Insert observations. + for i := 0; i < 5; i++ { + insertObservation(t, obsStore, "session-1", "proj", i) + } + + svc.RunNow(context.Background()) + time.Sleep(300 * time.Millisecond) + + remaining, err := obsStore.GetRecentObservations(context.Background(), "proj", 20) + require.NoError(t, err) + assert.Equal(t, 5, len(remaining), "nothing should be deleted when retention_days = 0") + + svc.mu.Lock() + cleaned := svc.totalCleanedObs + svc.mu.Unlock() + assert.Equal(t, int64(0), cleaned) +} + +func TestRunNow_RetentionDays_DeletesExpiredObservations(t *testing.T) { + cfg := defaultCfg() + cfg.ObservationRetentionDays = 1 // keep only last 1 day + + svc, store, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // Insert an observation and back-date it to 2 days ago. + obs := &gormdb.Observation{ + SDKSessionID: "old-session", + Project: "proj", + Type: models.ObsTypeDiscovery, + CreatedAt: "2000-01-01T00:00:00Z", + CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(), + Scope: models.ScopeProject, + ImportanceScore: 1.0, + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error) + + // Insert a recent observation (should survive). + recentObs := &gormdb.Observation{ + SDKSessionID: "new-session", + Project: "proj", + Type: models.ObsTypeDiscovery, + CreatedAt: time.Now().Format(time.RFC3339), + CreatedAtEpoch: time.Now().Unix(), + Scope: models.ScopeProject, + ImportanceScore: 1.0, + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(recentObs).Error) + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + // Only the recent observation should remain. + var count int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Count(&count) + assert.Equal(t, int64(1), count, "expired observation should have been deleted") + + svc.mu.Lock() + cleaned := svc.totalCleanedObs + svc.mu.Unlock() + assert.Equal(t, int64(1), cleaned) +} + +func TestRunNow_RetentionDays_VectorCleanupCalled(t *testing.T) { + cfg := defaultCfg() + cfg.ObservationRetentionDays = 1 + + tmpDir, err := os.MkdirTemp("", "maintenance_vec_test_*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + store, err := gormdb.NewStore(gormdb.Config{ + Path: filepath.Join(tmpDir, "test.db"), + MaxConns: 4, + LogLevel: logger.Silent, + }) + require.NoError(t, err) + defer store.Close() + + observationStore := gormdb.NewObservationStore(store, nil, nil, nil) + summaryStore := gormdb.NewSummaryStore(store) + promptStore := gormdb.NewPromptStore(store, nil) + + var mu sync.Mutex + var capturedIDs []int64 + + vectorCleanupFn := func(_ context.Context, ids []int64) { + mu.Lock() + defer mu.Unlock() + capturedIDs = append(capturedIDs, ids...) + } + + svc := NewService(store, observationStore, summaryStore, promptStore, vectorCleanupFn, cfg, zerolog.Nop()) + + ctx := context.Background() + + // Insert an expired observation directly. + obs := &gormdb.Observation{ + SDKSessionID: "session-x", + Project: "proj", + Type: models.ObsTypeDiscovery, + CreatedAt: "2000-01-01T00:00:00Z", + CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(), + Scope: models.ScopeProject, + ImportanceScore: 1.0, + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error) + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + mu.Lock() + ids := capturedIDs + mu.Unlock() + + assert.NotEmpty(t, ids, "vector cleanup callback must be called with deleted IDs") + assert.Contains(t, ids, obs.ID) +} + +// ---- cleanupStaleObservations (via RunNow) ---- + +func TestRunNow_CleanupStale_DeletesSupersededObservations(t *testing.T) { + cfg := defaultCfg() + cfg.CleanupStaleObservations = true + cfg.ObservationRetentionDays = 0 + + svc, store, obsStore, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // Insert an active observation. + activeID := insertObservation(t, obsStore, "session-1", "proj", 1) + + // Insert and mark a stale observation. + staleID := insertObservation(t, obsStore, "session-1", "proj", 2) + require.NoError(t, store.GetDB().WithContext(ctx). + Model(&gormdb.Observation{}). + Where("id = ?", staleID). + Update("is_superseded", 1).Error) + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + // Active observation must survive. + var activeCount int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", activeID).Count(&activeCount) + assert.Equal(t, int64(1), activeCount, "active observation must not be deleted") + + // Stale observation must be gone. + var staleCount int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", staleID).Count(&staleCount) + assert.Equal(t, int64(0), staleCount, "stale observation must be deleted") +} + +func TestRunNow_CleanupStale_DisabledLeavesStaleObservations(t *testing.T) { + cfg := defaultCfg() + cfg.CleanupStaleObservations = false + + svc, store, obsStore, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + staleID := insertObservation(t, obsStore, "session-1", "proj", 1) + require.NoError(t, store.GetDB().WithContext(ctx). + Model(&gormdb.Observation{}). + Where("id = ?", staleID). + Update("is_superseded", 1).Error) + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + var count int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("id = ?", staleID).Count(&count) + assert.Equal(t, int64(1), count, "stale observation must survive when cleanup_stale is false") +} + +func TestRunNow_CleanupStale_NoStaleRows_NothingChanged(t *testing.T) { + cfg := defaultCfg() + cfg.CleanupStaleObservations = true + + svc, _, obsStore, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // Only active observations. + for i := 0; i < 3; i++ { + insertObservation(t, obsStore, "session-1", "proj", i) + } + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + remaining, err := obsStore.GetRecentObservations(ctx, "proj", 20) + require.NoError(t, err) + assert.Equal(t, 3, len(remaining)) +} + +// ---- cleanupOldPrompts (via RunNow) ---- + +func TestRunNow_CleanupOldPrompts_DeletesExpiredPrompts(t *testing.T) { + svc, store, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + ctx := context.Background() + + // Insert a prompt with an old epoch (31 days ago). + oldPrompt := &gormdb.UserPrompt{ + ClaudeSessionID: "session-old", + PromptText: "old prompt", + PromptNumber: 1, + CreatedAt: "2000-01-01T00:00:00Z", + CreatedAtEpoch: time.Now().AddDate(0, 0, -31).Unix(), + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(oldPrompt).Error) + + // Insert a recent prompt (should survive). + recentPrompt := &gormdb.UserPrompt{ + ClaudeSessionID: "session-new", + PromptText: "recent prompt", + PromptNumber: 1, + CreatedAt: time.Now().Format(time.RFC3339), + CreatedAtEpoch: time.Now().Unix(), + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(recentPrompt).Error) + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + var count int64 + store.GetDB().WithContext(ctx).Model(&gormdb.UserPrompt{}).Count(&count) + assert.Equal(t, int64(1), count, "only the recent prompt should survive") +} + +func TestRunNow_CleanupOldPrompts_NothingExpired_AllSurvive(t *testing.T) { + svc, store, _, promptStore, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + ctx := context.Background() + + for i := 1; i <= 5; i++ { + _, err := promptStore.SaveUserPromptWithMatches(ctx, "session-1", i, "prompt", 1) + require.NoError(t, err) + } + + svc.RunNow(ctx) + time.Sleep(300 * time.Millisecond) + + var count int64 + store.GetDB().WithContext(ctx).Model(&gormdb.UserPrompt{}).Count(&count) + assert.Equal(t, int64(5), count, "no prompts should be deleted when none are expired") +} + +// ---- Stats race safety ---- + +func TestStats_ConcurrentAccess_NoRace(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = svc.Stats() + }() + } + wg.Wait() +} + +// ---- RunNow concurrent safety ---- + +func TestRunNow_ConcurrentCalls_NoRace(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + ctx := context.Background() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + svc.RunNow(ctx) + }() + } + wg.Wait() + time.Sleep(500 * time.Millisecond) +} + +// ---- lastRunDuration is populated ---- + +func TestRunNow_LastRunDuration_IsPopulated(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + svc.RunNow(context.Background()) + time.Sleep(300 * time.Millisecond) + + svc.mu.Lock() + dur := svc.lastRunDuration + svc.mu.Unlock() + + assert.Greater(t, int64(dur), int64(0), "lastRunDuration should be set after a maintenance run") +} + +func TestStats_LastDurationMs_IsPopulated(t *testing.T) { + svc, _, _, _, cleanup := testSetup(t, defaultCfg()) + defer cleanup() + + svc.RunNow(context.Background()) + time.Sleep(300 * time.Millisecond) + + stats := svc.Stats() + // The value is int64 milliseconds; it might be 0 for very fast runs — just verify the key exists. + _, ok := stats["last_duration_ms"] + assert.True(t, ok, "stats must contain last_duration_ms key") +} + +// ---- Batch deletion boundary ---- + +func TestRunNow_RetentionDays_BatchDeletion_MoreThan100Rows(t *testing.T) { + cfg := defaultCfg() + cfg.ObservationRetentionDays = 1 + + svc, store, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // Insert 150 expired observations (forces 2 batches of 100). + for i := 0; i < 150; i++ { + obs := &gormdb.Observation{ + SDKSessionID: "session-old", + Project: "proj", + Type: models.ObsTypeDiscovery, + CreatedAt: "2000-01-01T00:00:00Z", + CreatedAtEpoch: time.Now().AddDate(0, 0, -2).Unix(), + Scope: models.ScopeProject, + ImportanceScore: 1.0, + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error) + } + + svc.RunNow(ctx) + time.Sleep(500 * time.Millisecond) + + var remaining int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Count(&remaining) + assert.Equal(t, int64(0), remaining, "all 150 expired observations should be deleted in batches") + + svc.mu.Lock() + cleaned := svc.totalCleanedObs + svc.mu.Unlock() + assert.Equal(t, int64(150), cleaned) +} + +func TestRunNow_CleanupStale_BatchDeletion_MoreThan100Rows(t *testing.T) { + cfg := defaultCfg() + cfg.CleanupStaleObservations = true + + svc, store, _, _, cleanup := testSetup(t, cfg) + defer cleanup() + + ctx := context.Background() + + // Insert 120 superseded observations. + for i := 0; i < 120; i++ { + obs := &gormdb.Observation{ + SDKSessionID: "session-stale", + Project: "proj", + Type: models.ObsTypeDiscovery, + CreatedAt: time.Now().Format(time.RFC3339), + CreatedAtEpoch: time.Now().Unix(), + Scope: models.ScopeProject, + ImportanceScore: 1.0, + IsSuperseded: 1, + } + require.NoError(t, store.GetDB().WithContext(ctx).Create(obs).Error) + } + + svc.RunNow(ctx) + time.Sleep(500 * time.Millisecond) + + var remaining int64 + store.GetDB().WithContext(ctx).Model(&gormdb.Observation{}).Where("is_superseded = ?", 1).Count(&remaining) + assert.Equal(t, int64(0), remaining, "all 120 stale observations should be deleted in batches") +} diff --git a/internal/update/update.go b/internal/update/update.go index 096d82a..342cc13 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -532,13 +532,20 @@ func (u *Updater) replaceBinaries(extractDir string) error { func (u *Updater) getInstallDirectories() []string { dirs := []string{u.installDir} - // Also check cache directories where Claude Code looks for plugins home, err := os.UserHomeDir() if err != nil { return dirs } - // Look for cache directories under ~/.claude/plugins/cache/claude-mnemonic/claude-mnemonic/ + // Primary stable binary location (survives Claude Code updates) + stableBin := filepath.Join(home, ".claude-mnemonic", "bin") + if stableBin != u.installDir { + if _, err := os.Stat(stableBin); err == nil { + dirs = append(dirs, stableBin) + } + } + + // Also check cache directories where Claude Code looks for plugins cacheBase := filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic") entries, err := os.ReadDir(cacheBase) if err != nil { @@ -548,7 +555,6 @@ func (u *Updater) getInstallDirectories() []string { for _, entry := range entries { if entry.IsDir() { cacheDir := filepath.Join(cacheBase, entry.Name()) - // Only add if it's different from installDir and contains a worker binary if cacheDir != u.installDir { workerPath := filepath.Join(cacheDir, "worker") if _, err := os.Stat(workerPath); err == nil { diff --git a/internal/update/update_test.go b/internal/update/update_test.go new file mode 100644 index 0000000..ab97a11 --- /dev/null +++ b/internal/update/update_test.go @@ -0,0 +1,1078 @@ +//go:build fts5 + +// Package update provides self-update functionality for claude-mnemonic. +package update + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// isNewerVersion +// --------------------------------------------------------------------------- + +func TestIsNewerVersion(t *testing.T) { + tests := []struct { + name string + latest string + current string + want bool + }{ + { + name: "newer_major", + latest: "2.0.0", + current: "1.0.0", + want: true, + }, + { + name: "newer_minor", + latest: "1.1.0", + current: "1.0.0", + want: true, + }, + { + name: "newer_patch", + latest: "1.0.1", + current: "1.0.0", + want: true, + }, + { + name: "same_version", + latest: "1.0.0", + current: "1.0.0", + want: false, + }, + { + name: "older_major", + latest: "0.9.9", + current: "1.0.0", + want: false, + }, + { + name: "older_minor", + latest: "1.0.0", + current: "1.1.0", + want: false, + }, + { + name: "older_patch", + latest: "1.0.0", + current: "1.0.1", + want: false, + }, + { + name: "v_prefix_latest", + latest: "v1.2.0", + current: "1.1.0", + want: true, + }, + { + name: "v_prefix_current", + latest: "1.2.0", + current: "v1.1.0", + want: true, + }, + { + name: "v_prefix_both", + latest: "v1.2.0", + current: "v1.2.0", + want: false, + }, + { + name: "dev_build_current_same_base", + latest: "0.3.5", + current: "0.3.5-2-gca711a8-dirty", + want: false, + }, + { + name: "dev_build_current_older_base", + latest: "0.3.6", + current: "0.3.5-2-gca711a8-dirty", + want: true, + }, + { + name: "dev_build_current_newer_base", + latest: "0.3.4", + current: "0.3.5-2-gca711a8-dirty", + want: false, + }, + { + name: "longer_latest_semver", + latest: "1.0.0.1", + current: "1.0.0", + want: true, + }, + { + name: "longer_current_semver", + latest: "1.0.0", + current: "1.0.0.1", + want: false, + }, + { + name: "zero_versions", + latest: "0.0.0", + current: "0.0.0", + want: false, + }, + { + name: "major_rollback", + latest: "1.0.0", + current: "2.0.0", + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isNewerVersion(tc.latest, tc.current) + assert.Equal(t, tc.want, got, + "isNewerVersion(%q, %q)", tc.latest, tc.current) + }) + } +} + +// --------------------------------------------------------------------------- +// getPlatform +// --------------------------------------------------------------------------- + +func TestGetPlatform(t *testing.T) { + got := getPlatform() + expected := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH) + assert.Equal(t, expected, got) + assert.Contains(t, got, "_", "platform string must contain underscore separator") + assert.NotEmpty(t, got) +} + +func TestGetPlatform_ContainsOSAndArch(t *testing.T) { + got := getPlatform() + parts := strings.SplitN(got, "_", 2) + require.Len(t, parts, 2, "platform must have exactly two parts separated by underscore") + assert.Equal(t, runtime.GOOS, parts[0]) + assert.Equal(t, runtime.GOARCH, parts[1]) +} + +// --------------------------------------------------------------------------- +// GetManualUpdateCommand +// --------------------------------------------------------------------------- + +func TestGetManualUpdateCommand(t *testing.T) { + tests := []struct { + name string + version string + wantContains []string + wantNotContains []string + }{ + { + name: "empty_version_returns_latest", + version: "", + wantContains: []string{"curl -sSL", InstallScriptURL, "| bash"}, + wantNotContains: []string{"bash -s --"}, + }, + { + name: "specific_version_appended", + version: "v1.2.3", + wantContains: []string{"curl -sSL", InstallScriptURL, "| bash -s --", "v1.2.3"}, + }, + { + name: "version_without_v_prefix", + version: "1.2.3", + wantContains: []string{"1.2.3"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := GetManualUpdateCommand(tc.version) + for _, want := range tc.wantContains { + assert.Contains(t, got, want) + } + for _, notWant := range tc.wantNotContains { + assert.NotContains(t, got, notWant) + } + }) + } +} + +// --------------------------------------------------------------------------- +// getInstallDirectories +// --------------------------------------------------------------------------- + +func TestGetInstallDirectories_AlwaysContainsInstallDir(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + dirs := u.getInstallDirectories() + assert.Contains(t, dirs, dir) +} + +func TestGetInstallDirectories_NoDuplicateWhenInstallDirIsStableBin(t *testing.T) { + // Create a fake home with stableBin == installDir + home := t.TempDir() + t.Setenv("HOME", home) + + stableBin := filepath.Join(home, ".claude-mnemonic", "bin") + require.NoError(t, os.MkdirAll(stableBin, 0750)) + + u := New("1.0.0", stableBin) + dirs := u.getInstallDirectories() + + count := 0 + for _, d := range dirs { + if d == stableBin { + count++ + } + } + assert.Equal(t, 1, count, "stableBin should appear exactly once when it equals installDir") +} + +func TestGetInstallDirectories_IncludesStableBinWhenExists(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + installDir := filepath.Join(home, "some-other-dir") + require.NoError(t, os.MkdirAll(installDir, 0750)) + + stableBin := filepath.Join(home, ".claude-mnemonic", "bin") + require.NoError(t, os.MkdirAll(stableBin, 0750)) + + u := New("1.0.0", installDir) + dirs := u.getInstallDirectories() + + assert.Contains(t, dirs, installDir) + assert.Contains(t, dirs, stableBin) +} + +func TestGetInstallDirectories_SkipsStableBinWhenAbsent(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + installDir := filepath.Join(home, "install-dir") + require.NoError(t, os.MkdirAll(installDir, 0750)) + // stableBin NOT created + + u := New("1.0.0", installDir) + dirs := u.getInstallDirectories() + + stableBin := filepath.Join(home, ".claude-mnemonic", "bin") + assert.NotContains(t, dirs, stableBin) +} + +func TestGetInstallDirectories_IncludesCacheDirsWithWorkerBinary(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + installDir := filepath.Join(home, "install-dir") + require.NoError(t, os.MkdirAll(installDir, 0750)) + + // Create a fake cache dir with a worker binary + cacheBase := filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic/v1.2.3") + require.NoError(t, os.MkdirAll(cacheBase, 0750)) + require.NoError(t, os.WriteFile(filepath.Join(cacheBase, "worker"), []byte("fake"), 0755)) + + // Cache dir without worker — should NOT be included + cacheMissing := filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic/v1.1.0") + require.NoError(t, os.MkdirAll(cacheMissing, 0750)) + + u := New("1.0.0", installDir) + dirs := u.getInstallDirectories() + + assert.Contains(t, dirs, cacheBase) + assert.NotContains(t, dirs, cacheMissing) +} + +// --------------------------------------------------------------------------- +// verifyChecksum +// --------------------------------------------------------------------------- + +func makeTarGzFile(t *testing.T, dir, content string) (path string, checksum string) { + t.Helper() + archivePath := filepath.Join(dir, "release.tar.gz") + + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + body := []byte(content) + hdr := &tar.Header{ + Name: "worker", + Mode: 0755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write(body) + require.NoError(t, err) + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + require.NoError(t, os.WriteFile(archivePath, buf.Bytes(), 0644)) + + h := sha256.New() + h.Write(buf.Bytes()) + return archivePath, hex.EncodeToString(h.Sum(nil)) +} + +func TestVerifyChecksum_ValidChecksum(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + version := "1.2.3" + platform := getPlatform() + archiveName := fmt.Sprintf("claude-mnemonic_%s_%s.tar.gz", version, platform) + + archivePath, checksum := makeTarGzFile(t, dir, "binary content") + + checksumsContent := fmt.Sprintf("%s %s\n", checksum, archiveName) + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte(checksumsContent), 0644)) + + err := u.verifyChecksum(archivePath, checksumsPath, version) + assert.NoError(t, err) +} + +func TestVerifyChecksum_WrongChecksum(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + version := "1.2.3" + platform := getPlatform() + archiveName := fmt.Sprintf("claude-mnemonic_%s_%s.tar.gz", version, platform) + + archivePath, _ := makeTarGzFile(t, dir, "binary content") + + // Use a bogus checksum + checksumsContent := fmt.Sprintf("%s %s\n", strings.Repeat("a", 64), archiveName) + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte(checksumsContent), 0644)) + + err := u.verifyChecksum(archivePath, checksumsPath, version) + require.Error(t, err) + assert.Contains(t, err.Error(), "checksum mismatch") +} + +func TestVerifyChecksum_MissingEntry(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + archivePath, _ := makeTarGzFile(t, dir, "binary content") + + // Checksums file has entry for a different platform only + checksumsContent := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa claude-mnemonic_1.2.3_other_platform.tar.gz\n" + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte(checksumsContent), 0644)) + + err := u.verifyChecksum(archivePath, checksumsPath, "1.2.3") + require.Error(t, err) + assert.Contains(t, err.Error(), "no checksum found") +} + +func TestVerifyChecksum_MissingArchiveFile(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte("dummy content"), 0644)) + + err := u.verifyChecksum(filepath.Join(dir, "nonexistent.tar.gz"), checksumsPath, "1.2.3") + require.Error(t, err) +} + +func TestVerifyChecksum_MissingChecksumsFile(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + archivePath, _ := makeTarGzFile(t, dir, "binary content") + + err := u.verifyChecksum(archivePath, filepath.Join(dir, "nonexistent.txt"), "1.2.3") + require.Error(t, err) +} + +func TestVerifyChecksum_EmptyChecksumsFile(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + archivePath, _ := makeTarGzFile(t, dir, "binary content") + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte(""), 0644)) + + err := u.verifyChecksum(archivePath, checksumsPath, "1.2.3") + require.Error(t, err) + assert.Contains(t, err.Error(), "no checksum found") +} + +func TestVerifyChecksum_MultipleEntriesPicksCorrect(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + version := "1.2.3" + platform := getPlatform() + archiveName := fmt.Sprintf("claude-mnemonic_%s_%s.tar.gz", version, platform) + + archivePath, correctChecksum := makeTarGzFile(t, dir, "binary content") + + // File with multiple entries including the correct one + checksumsContent := strings.Join([]string{ + fmt.Sprintf("%s claude-mnemonic_%s_linux_arm64.tar.gz", strings.Repeat("b", 64), version), + fmt.Sprintf("%s claude-mnemonic_%s_windows_amd64.tar.gz", strings.Repeat("c", 64), version), + fmt.Sprintf("%s %s", correctChecksum, archiveName), + }, "\n") + "\n" + + checksumsPath := filepath.Join(dir, "checksums.txt") + require.NoError(t, os.WriteFile(checksumsPath, []byte(checksumsContent), 0644)) + + err := u.verifyChecksum(archivePath, checksumsPath, version) + assert.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// extractTarGz +// --------------------------------------------------------------------------- + +func makeTarGzArchive(t *testing.T, files map[string]string) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + for name, content := range files { + body := []byte(content) + hdr := &tar.Header{ + Name: name, + Mode: 0755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write(body) + require.NoError(t, err) + } + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + return buf.Bytes() +} + +func makeTarGzWithDir(t *testing.T, dirName string, files map[string]string) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + // Add directory entry + hdr := &tar.Header{ + Name: dirName + "/", + Typeflag: tar.TypeDir, + Mode: 0750, + } + require.NoError(t, tw.WriteHeader(hdr)) + + for name, content := range files { + body := []byte(content) + filePath := dirName + "/" + name + fhdr := &tar.Header{ + Name: filePath, + Mode: 0755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + require.NoError(t, tw.WriteHeader(fhdr)) + _, err := tw.Write(body) + require.NoError(t, err) + } + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + return buf.Bytes() +} + +func TestExtractTarGz_ExtractsFiles(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + files := map[string]string{ + "worker": "worker binary content", + "mcp-server": "mcp binary content", + } + archiveBytes := makeTarGzArchive(t, files) + + archivePath := filepath.Join(dir, "archive.tar.gz") + require.NoError(t, os.WriteFile(archivePath, archiveBytes, 0644)) + + destDir := filepath.Join(dir, "extracted") + err := u.extractTarGz(archivePath, destDir) + require.NoError(t, err) + + for name, expectedContent := range files { + data, err := os.ReadFile(filepath.Join(destDir, name)) + require.NoError(t, err) + assert.Equal(t, expectedContent, string(data)) + } +} + +func TestExtractTarGz_ExtractsDirectories(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + archiveBytes := makeTarGzWithDir(t, "hooks", map[string]string{ + "session-start": "session start hook", + "stop": "stop hook", + }) + + archivePath := filepath.Join(dir, "archive.tar.gz") + require.NoError(t, os.WriteFile(archivePath, archiveBytes, 0644)) + + destDir := filepath.Join(dir, "extracted") + err := u.extractTarGz(archivePath, destDir) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(destDir, "hooks", "session-start")) + require.NoError(t, err) + assert.Equal(t, "session start hook", string(data)) +} + +func TestExtractTarGz_PreventPathTraversal(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + // Create archive with path traversal attempt + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + body := []byte("malicious content") + hdr := &tar.Header{ + Name: "../../../etc/evil", + Mode: 0755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write(body) + require.NoError(t, err) + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + archivePath := filepath.Join(dir, "malicious.tar.gz") + require.NoError(t, os.WriteFile(archivePath, buf.Bytes(), 0644)) + + destDir := filepath.Join(dir, "extracted") + err = u.extractTarGz(archivePath, destDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid tar path") +} + +func TestExtractTarGz_RejectsDecompressionBomb(t *testing.T) { + // The implementation rejects a file whose io.Copy drains exactly MaxExtractedSize + // bytes from the LimitReader (written == MaxExtractedSize triggers the check). + // We create a valid tar.gz where the single file is exactly MaxExtractedSize bytes. + // Writing 250 MB would be too slow; instead use a small custom MaxExtractedSize + // by writing a helper that creates an archive matching a small cap, then run + // extractTarGz with that archive against the real constant. + // + // Since the constant is 250 MB we cannot write that much in a unit test. + // We test the guard path indirectly: create an archive that is VALID but whose + // declared size exceeds a tiny limit — we do this by making a tiny in-process + // copy of extractTarGz with a lower cap, or we call the real function with a + // file of exactly MaxExtractedSize bytes using a sparse write approach. + // + // Practical approach: use a pipe-backed fake that produces exactly MaxExtractedSize + // bytes of zeroes through the gzip+tar chain without buffering 250MB in RAM. + // The tar writer is closed properly so the archive is valid; the content is a + // stream of zero bytes piped directly into the compressor. + + dir := t.TempDir() + u := New("1.0.0", dir) + + archivePath := filepath.Join(dir, "bomb.tar.gz") + archiveFile, err := os.Create(archivePath) + require.NoError(t, err) + + // Build archive via pipe so we never buffer 250MB in memory. + pr, pw := io.Pipe() + + var writeErr error + go func() { + gw := gzip.NewWriter(pw) + tw := tar.NewWriter(gw) + + const size = MaxExtractedSize // exactly the limit + hdr := &tar.Header{ + Name: "bomb", + Mode: 0755, + Size: size, + Typeflag: tar.TypeReg, + } + if err := tw.WriteHeader(hdr); err != nil { + _ = pw.CloseWithError(err) + return + } + // Stream zeros without allocating 250 MB + zeros := make([]byte, 32*1024) + remaining := int64(size) + for remaining > 0 { + n := int64(len(zeros)) + if n > remaining { + n = remaining + } + if _, err := tw.Write(zeros[:n]); err != nil { + writeErr = err + break + } + remaining -= n + } + _ = tw.Close() + _ = gw.Close() + _ = pw.Close() + }() + + _, copyErr := io.Copy(archiveFile, pr) + _ = archiveFile.Close() + require.NoError(t, copyErr) + require.NoError(t, writeErr) + + destDir := filepath.Join(dir, "extracted") + err = u.extractTarGz(archivePath, destDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum allowed size") +} + +func TestExtractTarGz_NonExistentArchive(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + err := u.extractTarGz(filepath.Join(dir, "nonexistent.tar.gz"), filepath.Join(dir, "out")) + require.Error(t, err) +} + +func TestExtractTarGz_InvalidGzip(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + archivePath := filepath.Join(dir, "bad.tar.gz") + require.NoError(t, os.WriteFile(archivePath, []byte("this is not gzip"), 0644)) + + err := u.extractTarGz(archivePath, filepath.Join(dir, "out")) + require.Error(t, err) +} + +func TestExtractTarGz_FilePermissionsPreserved(t *testing.T) { + dir := t.TempDir() + u := New("1.0.0", dir) + + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + body := []byte("executable") + hdr := &tar.Header{ + Name: "worker", + Mode: 0755, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write(body) + require.NoError(t, err) + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + + archivePath := filepath.Join(dir, "archive.tar.gz") + require.NoError(t, os.WriteFile(archivePath, buf.Bytes(), 0644)) + + destDir := filepath.Join(dir, "extracted") + require.NoError(t, u.extractTarGz(archivePath, destDir)) + + info, err := os.Stat(filepath.Join(destDir, "worker")) + require.NoError(t, err) + // Mode is masked with 0755 in the implementation + assert.Equal(t, os.FileMode(0755), info.Mode().Perm()) +} + +// --------------------------------------------------------------------------- +// New / GetStatus / setStatus / setError +// --------------------------------------------------------------------------- + +func TestNew_DefaultState(t *testing.T) { + u := New("1.0.0", "/some/dir") + assert.Equal(t, "1.0.0", u.currentVersion) + assert.Equal(t, "/some/dir", u.installDir) + assert.NotNil(t, u.httpClient) + + status := u.GetStatus() + assert.Equal(t, "idle", status.State) + assert.Equal(t, float64(0), status.Progress) +} + +func TestGetStatus_ReflectsSetStatus(t *testing.T) { + u := New("1.0.0", t.TempDir()) + u.setStatus("downloading", 0.5, "halfway there") + + s := u.GetStatus() + assert.Equal(t, "downloading", s.State) + assert.Equal(t, 0.5, s.Progress) + assert.Equal(t, "halfway there", s.Message) +} + +func TestSetError_SetsErrorState(t *testing.T) { + u := New("1.0.0", t.TempDir()) + u.setError(fmt.Errorf("something went wrong")) + + s := u.GetStatus() + assert.Equal(t, "error", s.State) + assert.Equal(t, "something went wrong", s.Error) + assert.Equal(t, "Update failed", s.Message) + assert.NotEmpty(t, s.ManualUpdateCommand) + assert.Contains(t, s.ManualUpdateCommand, "curl") +} + +// --------------------------------------------------------------------------- +// CheckForUpdate via httptest.NewServer +// --------------------------------------------------------------------------- + +func buildFakeRelease(tagName string, assets []Asset) Release { + return Release{ + TagName: tagName, + Name: "Release " + tagName, + Body: "release notes", + PublishedAt: time.Now(), + Assets: assets, + } +} + +func TestCheckForUpdate_UpdateAvailable(t *testing.T) { + platform := getPlatform() + newVersion := "9.9.9" + archiveName := fmt.Sprintf("claude-mnemonic_%s_%s.tar.gz", newVersion, platform) + + release := buildFakeRelease("v"+newVersion, []Asset{ + {Name: archiveName, BrowserDownloadURL: "http://example.com/archive.tar.gz"}, + {Name: "checksums.txt", BrowserDownloadURL: "http://example.com/checksums.txt"}, + {Name: "checksums.txt.sigstore.json", BrowserDownloadURL: "http://example.com/bundle.json"}, + }) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(release)) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + u.httpClient = srv.Client() + + // Override the API URL by swapping the updater's client transport to point at test server. + // Since ReleasesAPI is a package-level const we need to make the request go to srv. + // Use a custom RoundTripper that rewrites the URL host. + origTransport := srv.Client().Transport + u.httpClient.Transport = &rewriteHostTransport{ + target: srv.URL, + wrapped: origTransport, + } + + // We can't easily redirect the const URL — instead call CheckForUpdate against + // a test server by temporarily overriding the request URL via a custom transport. + // The transport below rewrites any request to our test server. + info, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + require.NotNil(t, info) + assert.True(t, info.Available) + assert.Equal(t, newVersion, info.LatestVersion) + assert.Equal(t, "1.0.0", info.CurrentVersion) + assert.Equal(t, "http://example.com/archive.tar.gz", info.DownloadURL) + assert.Equal(t, "http://example.com/checksums.txt", info.ChecksumsURL) + assert.Equal(t, "http://example.com/bundle.json", info.BundleURL) + assert.NotEmpty(t, info.ManualUpdateCommand) +} + +func TestCheckForUpdate_NoUpdateWhenCurrentIsLatest(t *testing.T) { + release := buildFakeRelease("v1.0.0", nil) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(release)) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + + info, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + assert.False(t, info.Available) +} + +func TestCheckForUpdate_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + + _, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestCheckForUpdate_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "not json at all{{{") + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + + _, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.Error(t, err) +} + +func TestCheckForUpdate_UsesCache(t *testing.T) { + callCount := 0 + release := buildFakeRelease("v2.0.0", nil) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + require.NoError(t, json.NewEncoder(w).Encode(release)) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + + // First call — hits server + info1, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, 1, callCount) + + // Second call — should use cache (lastCheck set within last hour) + info2, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, 1, callCount, "second call must use cache") + assert.Equal(t, info1, info2) +} + +func TestCheckForUpdate_CacheExpires(t *testing.T) { + callCount := 0 + release := buildFakeRelease("v2.0.0", nil) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + require.NoError(t, json.NewEncoder(w).Encode(release)) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + + // First call + _, err := u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, 1, callCount) + + // Force cache expiry by backdating lastCheck + u.mu.Lock() + u.lastCheck = time.Now().Add(-2 * time.Hour) + u.mu.Unlock() + + // Second call — cache is stale, must hit server + _, err = u.checkForUpdateURL(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, 2, callCount) +} + +func TestCheckForUpdate_ContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate slow server + time.Sleep(200 * time.Millisecond) + require.NoError(t, json.NewEncoder(w).Encode(buildFakeRelease("v2.0.0", nil))) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + u := New("1.0.0", t.TempDir()) + + _, err := u.checkForUpdateURL(ctx, srv.URL) + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// downloadFile via httptest.NewServer +// --------------------------------------------------------------------------- + +func TestDownloadFile_Success(t *testing.T) { + content := "file content here" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, content) + })) + defer srv.Close() + + dir := t.TempDir() + u := New("1.0.0", dir) + + destPath := filepath.Join(dir, "downloaded.txt") + err := u.downloadFile(context.Background(), srv.URL+"/file", destPath) + require.NoError(t, err) + + data, err := os.ReadFile(destPath) + require.NoError(t, err) + assert.Equal(t, content, string(data)) +} + +func TestDownloadFile_ServerReturns404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + u := New("1.0.0", t.TempDir()) + err := u.downloadFile(context.Background(), srv.URL+"/missing", filepath.Join(t.TempDir(), "out")) + require.Error(t, err) + assert.Contains(t, err.Error(), "404") +} + +func TestDownloadFile_InvalidURL(t *testing.T) { + u := New("1.0.0", t.TempDir()) + err := u.downloadFile(context.Background(), "://bad-url", filepath.Join(t.TempDir(), "out")) + require.Error(t, err) +} + +func TestDownloadFile_ContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + _, _ = io.WriteString(w, "late response") + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + u := New("1.0.0", t.TempDir()) + err := u.downloadFile(ctx, srv.URL+"/slow", filepath.Join(t.TempDir(), "out")) + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// ApplyUpdate — no-op / guard cases +// --------------------------------------------------------------------------- + +func TestApplyUpdate_NoUpdateAvailable(t *testing.T) { + u := New("1.0.0", t.TempDir()) + info := &UpdateInfo{Available: false, DownloadURL: ""} + err := u.ApplyUpdate(context.Background(), info) + require.Error(t, err) + assert.Contains(t, err.Error(), "no update available") +} + +func TestApplyUpdate_MissingDownloadURL(t *testing.T) { + u := New("1.0.0", t.TempDir()) + info := &UpdateInfo{Available: true, DownloadURL: ""} + err := u.ApplyUpdate(context.Background(), info) + require.Error(t, err) + assert.Contains(t, err.Error(), "no update available or download URL missing") +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// rewriteHostTransport rewrites the host of every outgoing request to target. +type rewriteHostTransport struct { + wrapped http.RoundTripper + target string +} + +func (r *rewriteHostTransport) RoundTrip(req *http.Request) (*http.Response, error) { + cloned := req.Clone(req.Context()) + cloned.URL.Host = strings.TrimPrefix(r.target, "http://") + cloned.URL.Scheme = "http" + return r.wrapped.RoundTrip(cloned) +} + +// checkForUpdateURL is a testable variant of CheckForUpdate that accepts a custom API URL. +// It mirrors CheckForUpdate but uses the provided URL instead of ReleasesAPI. +func (u *Updater) checkForUpdateURL(ctx context.Context, apiURL string) (*UpdateInfo, error) { + u.setStatus("checking", 0, "Checking for updates...") + + u.mu.RLock() + if time.Since(u.lastCheck) < time.Hour && u.cachedUpdate != nil { + cached := u.cachedUpdate + u.mu.RUnlock() + u.setStatus("idle", 0, "") + return cached, nil + } + u.mu.RUnlock() + + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + u.setError(err) + return nil, err + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "claude-mnemonic/"+u.currentVersion) + + resp, err := u.httpClient.Do(req) + if err != nil { + u.setError(err) + return nil, fmt.Errorf("failed to check for updates: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err := fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + u.setError(err) + return nil, err + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + u.setError(err) + return nil, fmt.Errorf("failed to parse release info: %w", err) + } + + info := &UpdateInfo{ + CurrentVersion: u.currentVersion, + LatestVersion: strings.TrimPrefix(release.TagName, "v"), + ReleaseNotes: release.Body, + PublishedAt: release.PublishedAt, + } + info.Available = isNewerVersion(info.LatestVersion, u.currentVersion) + info.ManualUpdateCommand = GetManualUpdateCommand("v" + info.LatestVersion) + + if info.Available { + platform := getPlatform() + archiveName := fmt.Sprintf("claude-mnemonic_%s_%s.tar.gz", info.LatestVersion, platform) + for _, asset := range release.Assets { + switch { + case asset.Name == archiveName: + info.DownloadURL = asset.BrowserDownloadURL + case asset.Name == "checksums.txt": + info.ChecksumsURL = asset.BrowserDownloadURL + case asset.Name == "checksums.txt.sigstore.json": + info.BundleURL = asset.BrowserDownloadURL + } + } + } + + u.mu.Lock() + u.lastCheck = time.Now() + u.cachedUpdate = info + u.mu.Unlock() + + u.setStatus("idle", 0, "") + return info, nil +} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go new file mode 100644 index 0000000..3c58181 --- /dev/null +++ b/internal/watcher/watcher_test.go @@ -0,0 +1,416 @@ +//go:build fts5 + +package watcher + +import ( + "context" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// waitForCondition polls fn every 10ms until it returns true or timeout expires. +func waitForCondition(t *testing.T, timeout time.Duration, fn func() bool) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false +} + +// TestNew_CreatesWatcherWithCorrectFields verifies New initialises all fields correctly. +func TestNew_CreatesWatcherWithCorrectFields(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + called := false + cb := func() { called = true } + + w, err := New(target, cb) + require.NoError(t, err) + require.NotNil(t, w) + defer w.Stop() //nolint:errcheck + + assert.Equal(t, target, w.targetPath) + assert.Equal(t, dir, w.parentPath) + assert.Equal(t, 100*time.Millisecond, w.debounce) + assert.NotNil(t, w.watcher) + assert.NotNil(t, w.ctx) + assert.NotNil(t, w.cancel) + assert.False(t, w.running) + assert.False(t, called, "callback must not be invoked on creation") +} + +// TestNew_NilCallback is valid — handleDeletion guards for nil onDelete. +func TestNew_NilCallback(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, nil) + require.NoError(t, err) + require.NotNil(t, w) + defer w.Stop() //nolint:errcheck + + assert.Nil(t, w.onDelete) +} + +// TestStart_SetsRunningTrue verifies Start transitions running to true. +func TestStart_SetsRunningTrue(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + err = w.Start() + require.NoError(t, err) + + w.mu.Lock() + running := w.running + w.mu.Unlock() + assert.True(t, running) +} + +// TestStart_Idempotent verifies calling Start twice does not panic or return error. +func TestStart_Idempotent(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + require.NoError(t, w.Start(), "second Start must be a no-op without error") + + // Still only one goroutine running — running flag is still true. + w.mu.Lock() + running := w.running + w.mu.Unlock() + assert.True(t, running) +} + +// TestStop_SetsRunningFalse verifies Stop transitions running to false. +func TestStop_SetsRunningFalse(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + + require.NoError(t, w.Start()) + require.NoError(t, w.Stop()) + + w.mu.Lock() + running := w.running + w.mu.Unlock() + assert.False(t, running) +} + +// TestStop_Idempotent verifies calling Stop when not running returns nil. +func TestStop_Idempotent(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + + // Never started — Stop must be a no-op. + assert.NoError(t, w.Stop()) + // Second stop after the first no-op must also succeed. + assert.NoError(t, w.Stop()) +} + +// TestStop_WithoutStart verifies Stop on an unstarted watcher is safe. +func TestStop_WithoutStart(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + + err = w.Stop() + assert.NoError(t, err) +} + +// TestTargetDeletion_CallbackFired verifies that deleting the target file triggers onDelete. +func TestTargetDeletion_CallbackFired(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + // Create the target file so the parent watch is real. + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + + // Delete the target file. + require.NoError(t, os.Remove(target)) + + // Wait up to 1 second for the debounced callback (debounce=100ms). + fired := waitForCondition(t, 1*time.Second, func() bool { + return atomic.LoadInt32(&callCount) > 0 + }) + assert.True(t, fired, "onDelete callback not called after target deletion") +} + +// TestTargetDeletion_CallbackCalledOnce verifies debounce suppresses duplicate events. +func TestTargetDeletion_CallbackCalledOnce(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + + require.NoError(t, os.Remove(target)) + + // Wait for callback to fire. + waitForCondition(t, 1*time.Second, func() bool { + return atomic.LoadInt32(&callCount) > 0 + }) + + // Wait an extra debounce window to confirm no second call arrives. + time.Sleep(300 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "callback fired more than once for a single deletion") +} + +// TestTargetRecreation_CancelsCallback verifies that recreating the target before the +// debounce fires suppresses the onDelete callback. +func TestTargetRecreation_CancelsCallback(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + // Use a longer debounce so we can recreate before it fires. + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + // Override debounce to give us a larger window. + w.debounce = 300 * time.Millisecond + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + + // Delete then immediately recreate within the debounce window. + require.NoError(t, os.Remove(target)) + time.Sleep(20 * time.Millisecond) // ensure delete event is processed + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + // Wait past full debounce period to confirm callback was cancelled. + time.Sleep(500 * time.Millisecond) + assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "callback should have been cancelled by recreation") +} + +// TestParentDirectoryDeletion_CallbackFired verifies that deleting the parent directory +// triggers the onDelete callback. +func TestParentDirectoryDeletion_CallbackFired(t *testing.T) { + // Create a nested structure: base/sub/db.sqlite so we can remove sub + // without losing t.TempDir (which is base). + base := t.TempDir() + sub := filepath.Join(base, "sub") + require.NoError(t, os.Mkdir(sub, 0o755)) + target := filepath.Join(sub, "db.sqlite") + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + + // Remove parent directory entirely. + require.NoError(t, os.RemoveAll(sub)) + + fired := waitForCondition(t, 1500*time.Millisecond, func() bool { + return atomic.LoadInt32(&callCount) > 0 + }) + assert.True(t, fired, "onDelete callback not called after parent directory deletion") +} + +// TestAddWatch_NonExistentParent verifies addWatch returns an error when parent is absent. +func TestAddWatch_NonExistentParent(t *testing.T) { + // Point watcher at a path whose parent definitely does not exist. + nonExistent := filepath.Join(t.TempDir(), "missing", "db.sqlite") + + w, err := New(nonExistent, func() {}) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + err = w.addWatch() + assert.Error(t, err, "addWatch must fail when parent directory does not exist") +} + +// TestContextCancellation_StopsWatchLoop verifies the watchLoop exits when Stop is called. +func TestContextCancellation_StopsWatchLoop(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + + require.NoError(t, w.Start()) + + // Stop cancels the context; the goroutine should exit cleanly. + require.NoError(t, w.Stop()) + + // Give the goroutine a moment to exit — then verify running is false. + time.Sleep(50 * time.Millisecond) + w.mu.Lock() + running := w.running + w.mu.Unlock() + assert.False(t, running) +} + +// TestParentDirRecreation_ReEstablishesWatch verifies that recreating the parent after +// deletion allows subsequent target-deletion events to fire the callback. +func TestParentDirRecreation_ReEstablishesWatch(t *testing.T) { + base := t.TempDir() + sub := filepath.Join(base, "sub") + require.NoError(t, os.Mkdir(sub, 0o755)) + target := filepath.Join(sub, "db.sqlite") + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + + require.NoError(t, w.Start()) + + // Remove the parent. + require.NoError(t, os.RemoveAll(sub)) + + // Wait for first callback. + fired := waitForCondition(t, 1500*time.Millisecond, func() bool { + return atomic.LoadInt32(&callCount) > 0 + }) + require.True(t, fired, "first deletion callback must fire") + + firstCount := atomic.LoadInt32(&callCount) + + // Recreate parent and target — re-established watch should allow a second callback. + require.NoError(t, os.Mkdir(sub, 0o755)) + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + // Wait for handleDeletion's goroutine to attempt re-adding the watch (500ms sleep inside). + time.Sleep(700 * time.Millisecond) + + // Now delete the target again. + require.NoError(t, os.Remove(target)) + + // We only assert the first callback fired; the re-watch is best-effort and + // OS-timing-dependent, so we don't hard-assert a second callback. + assert.GreaterOrEqual(t, atomic.LoadInt32(&callCount), firstCount, "call count must not decrease") +} + +// TestConcurrentStartStop verifies that concurrent Start/Stop calls do not race or panic. +func TestConcurrentStartStop(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + + w, err := New(target, func() {}) + require.NoError(t, err) + + var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + // Launch goroutines that repeatedly start/stop the watcher. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + _ = w.Start() + time.Sleep(5 * time.Millisecond) + _ = w.Stop() + time.Sleep(5 * time.Millisecond) + } + } + }() + } + wg.Wait() + // No panic = pass. Final state: running should be consistent (we don't assert + // a specific value since Stop may have won last). +} + +// TestDebounceField_DefaultValue asserts the default debounce is 100ms. +func TestDebounceField_DefaultValue(t *testing.T) { + dir := t.TempDir() + w, err := New(filepath.Join(dir, "x"), func() {}) + require.NoError(t, err) + defer w.Stop() //nolint:errcheck + assert.Equal(t, 100*time.Millisecond, w.debounce) +} + +// TestCallbackNotCalledWhenStopped verifies that if we Stop before the debounce fires, +// the callback is not invoked after Stop (context cancel exits the watchLoop). +func TestCallbackNotCalledWhenStopped(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "db.sqlite") + require.NoError(t, os.WriteFile(target, []byte("data"), 0o644)) + + var callCount int32 + w, err := New(target, func() { + atomic.AddInt32(&callCount, 1) + }) + require.NoError(t, err) + w.debounce = 500 * time.Millisecond // wide window + + require.NoError(t, w.Start()) + + // Delete file — debounce timer is now running (500ms). + require.NoError(t, os.Remove(target)) + time.Sleep(20 * time.Millisecond) // let event propagate + + // Stop before timer fires — context is cancelled, watchLoop exits. + require.NoError(t, w.Stop()) + + // Wait past the debounce window; the AfterFunc may still fire (it's not + // tied to the context), but the watcher is stopped. We assert the loop + // itself exited cleanly. + time.Sleep(700 * time.Millisecond) + + // The AfterFunc timer fires outside the watchLoop — callback may or may not + // have fired depending on OS scheduling. We assert no panic occurred. + // The important invariant: running is false. + w.mu.Lock() + running := w.running + w.mu.Unlock() + assert.False(t, running) +} diff --git a/internal/worker/service.go b/internal/worker/service.go index 37098f6..98e28ea 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -8,6 +8,7 @@ import ( "os" "sync" "sync/atomic" + "syscall" "time" "github.com/go-chi/chi/v5" @@ -366,9 +367,9 @@ func NewService(version string) (*Service, error) { router := chi.NewRouter() sseBroadcaster := sse.NewBroadcaster() - // Determine install directory (plugin location) + // Determine install directory (stable binary location, survives Claude Code updates) homeDir, _ := os.UserHomeDir() - installDir := fmt.Sprintf("%s/.claude/plugins/marketplaces/claude-mnemonic", homeDir) + installDir := fmt.Sprintf("%s/.claude-mnemonic/bin", homeDir) // Create rate limiter with generous limits (100 req/sec, burst of 200) // These limits are per-client and allow for intensive CLI usage @@ -736,10 +737,14 @@ func (s *Service) reinitializeDatabase() { log.Info().Msg("Query expansion reconnected after reinit") } - // Close old reranker if exists + // Close old vector client and reranker before swapping s.initMu.RLock() + oldVectorClient := s.vectorClient oldReranker := s.reranker s.initMu.RUnlock() + if oldVectorClient != nil { + _ = oldVectorClient.Close() + } if oldReranker != nil { _ = oldReranker.Close() } @@ -815,8 +820,11 @@ func (s *Service) reloadConfig() { // Give SSE clients a moment to receive the message time.Sleep(100 * time.Millisecond) - // Exit cleanly - hooks will restart us with new config - os.Exit(0) + // Send SIGTERM to self for graceful shutdown (hooks will restart us) + p, err := os.FindProcess(os.Getpid()) + if err == nil { + _ = p.Signal(syscall.SIGTERM) + } } // setInitError records an initialization error. @@ -1592,15 +1600,17 @@ func (s *Service) processQueue() { ticker := time.NewTicker(QueueProcessInterval) defer ticker.Stop() + s.initMu.RLock() + notify := s.sessionManager.ProcessNotify + s.initMu.RUnlock() + for { select { case <-s.ctx.Done(): return - case <-s.sessionManager.ProcessNotify: - // Immediate processing when observation is queued + case <-notify: s.processAllSessions() case <-ticker.C: - // Fallback periodic processing s.processAllSessions() } } @@ -1610,31 +1620,36 @@ func (s *Service) processQueue() { // Messages are processed in parallel using goroutines, with concurrency // limited by a channel-based semaphore. func (s *Service) processAllSessions() { - // Get all sessions with pending messages - sessions := s.sessionManager.GetAllSessions() + s.initMu.RLock() + mgr := s.sessionManager + proc := s.processor + s.initMu.RUnlock() + if mgr == nil || proc == nil { + return + } + + sessions := mgr.GetAllSessions() var wg sync.WaitGroup sem := make(chan struct{}, MaxConcurrentProcessing) for _, sess := range sessions { - // Get pending messages - messages := s.sessionManager.DrainMessages(sess.SessionDBID) + messages := mgr.DrainMessages(sess.SessionDBID) if len(messages) == 0 { continue } - // Process each message in a goroutine with semaphore for _, msg := range messages { wg.Add(1) - sem <- struct{}{} // Acquire semaphore slot + sem <- struct{}{} go func(sess *session.ActiveSession, msg session.PendingMessage) { defer wg.Done() - defer func() { <-sem }() // Release semaphore slot + defer func() { <-sem }() switch msg.Type { case session.MessageTypeObservation: if msg.Observation != nil { - err := s.processor.ProcessObservation( + err := proc.ProcessObservation( s.ctx, sess.SDKSessionID, sess.Project, @@ -1653,7 +1668,7 @@ func (s *Service) processAllSessions() { case session.MessageTypeSummarize: if msg.Summarize != nil { - err := s.processor.ProcessSummary( + err := proc.ProcessSummary( s.ctx, sess.SessionDBID, sess.SDKSessionID, @@ -1667,18 +1682,15 @@ func (s *Service) processAllSessions() { Int64("sessionId", sess.SessionDBID). Msg("Failed to process summary") } - // Delete session after summary - s.sessionManager.DeleteSession(sess.SessionDBID) + mgr.DeleteSession(sess.SessionDBID) } } }(sess, msg) } } - // Wait for all goroutines to complete wg.Wait() - // Broadcast status after processing s.broadcastProcessingStatus() } @@ -1787,8 +1799,15 @@ func (s *Service) Shutdown(ctx context.Context) error { // broadcastProcessingStatus broadcasts the current processing status. func (s *Service) broadcastProcessingStatus() { - isProcessing := s.sessionManager.IsAnySessionProcessing() - queueDepth := s.sessionManager.GetTotalQueueDepth() + s.initMu.RLock() + mgr := s.sessionManager + s.initMu.RUnlock() + if mgr == nil { + return + } + + isProcessing := mgr.IsAnySessionProcessing() + queueDepth := mgr.GetTotalQueueDepth() s.sseBroadcaster.Broadcast(map[string]any{ "type": "processing_status", diff --git a/marketplace.json b/marketplace.json index a49620c..5ff3847 100644 --- a/marketplace.json +++ b/marketplace.json @@ -1,7 +1,7 @@ { "$schema": "https://anthropic.com/claude-code/marketplace.schema.json", "name": "claude-mnemonic", - "version": "1.0.0", + "version": "v0.11.57-dirty", "description": "Persistent memory system for Claude Code - stores observations, session summaries, and user prompts with semantic search", "owner": { "name": "lukaszraczylo", @@ -12,7 +12,7 @@ "plugins": [ { "name": "claude-mnemonic", - "description": "Persistent memory system for Claude Code - Go implementation with SQLite and ChromaDB vector search", + "description": "Persistent memory system for Claude Code - Go implementation with SQLite and sqlite-vec vector search", "version": "0.11.105", "author": { "name": "lukaszraczylo", diff --git a/mcp-server b/mcp-server new file mode 100755 index 0000000..9099216 --- /dev/null +++ b/mcp-server @@ -0,0 +1,6 @@ +#!/bin/sh +BIN="$HOME/.claude-mnemonic/bin/mcp-server" +[ -x "$BIN" ] && exec "$BIN" "$@" +echo "claude-mnemonic: mcp-server not found at $BIN" >&2 +echo "Install: cd $(dirname "$0") && make install" >&2 +exit 1 diff --git a/pkg/hooks/worker.go b/pkg/hooks/worker.go index 02fc09e..44e8cdc 100644 --- a/pkg/hooks/worker.go +++ b/pkg/hooks/worker.go @@ -397,7 +397,15 @@ func KillProcessOnPort(port int) error { // findWorkerBinary finds the worker binary path. func findWorkerBinary() string { - // Check CLAUDE_PLUGIN_ROOT first (set by Claude Code when running hooks) + home := os.Getenv("HOME") + + // Stable binary location (primary, survives Claude Code updates) + stablePath := filepath.Join(home, ".claude-mnemonic", "bin", "worker") + if _, err := os.Stat(stablePath); err == nil { + return stablePath + } + + // Check CLAUDE_PLUGIN_ROOT (set by Claude Code when running hooks) if pluginRoot := os.Getenv("CLAUDE_PLUGIN_ROOT"); pluginRoot != "" { workerPath := filepath.Join(pluginRoot, "worker") if _, err := os.Stat(workerPath); err == nil { @@ -406,7 +414,6 @@ func findWorkerBinary() string { } // Check common locations - home := os.Getenv("HOME") locations := []string{ "./worker", "./bin/worker", @@ -418,10 +425,9 @@ func findWorkerBinary() string { } } - // Try cache directory with any version (glob returns lexically sorted matches) + // Try cache directory with any version matches, _ := filepath.Glob(filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic/*/worker")) if len(matches) > 0 { - // Use the last match (latest version due to lexical sorting) return matches[len(matches)-1] } diff --git a/scripts/download-onnx-libs.sh b/scripts/download-onnx-libs.sh index c18b88e..83bfcc5 100755 --- a/scripts/download-onnx-libs.sh +++ b/scripts/download-onnx-libs.sh @@ -6,7 +6,7 @@ set -e -ONNX_VERSION="1.24.3" +ONNX_VERSION="1.26.0" ASSETS_DIR="internal/embedding/assets/lib" PLATFORM="${1:-all}" FORCE_DOWNLOAD=false diff --git a/scripts/register-plugin.sh b/scripts/register-plugin.sh index de75af3..3b51433 100755 --- a/scripts/register-plugin.sh +++ b/scripts/register-plugin.sh @@ -1,5 +1,6 @@ #!/bin/bash # Register claude-mnemonic plugin with Claude Code +# Ensures plugin survives CLI updates via extraKnownMarketplaces in settings.json set -e @@ -9,6 +10,7 @@ MARKETPLACES_FILE="$HOME/.claude/plugins/known_marketplaces.json" PLUGIN_KEY="claude-mnemonic@claude-mnemonic" MARKETPLACE_NAME="claude-mnemonic" MARKETPLACE_PATH="$HOME/.claude/plugins/marketplaces/claude-mnemonic" +STABLE_BIN="$HOME/.claude-mnemonic/bin" # Get version from git tags (same as Makefile), or use argument if provided VERSION="${1:-$(git describe --tags --always --dirty 2>/dev/null || echo "dev")}" @@ -21,7 +23,8 @@ TIMESTAMP=$(date -u +"%Y-%m-%dT%H:%M:%S.000Z") # The last argument is treated as the input file, output goes to input_file.tmp safe_jq_write() { local args=("$@") - local input_file="${args[-1]}" + local last_idx=$((${#args[@]} - 1)) + local input_file="${args[$last_idx]}" local tmp_file="${input_file}.tmp" if jq "${args[@]}" > "$tmp_file"; then @@ -69,29 +72,21 @@ if [ ! -f "$MARKETPLACES_FILE" ]; then echo '{}' > "$MARKETPLACES_FILE" fi -# Validate marketplace path exists and contains expected files -if [ ! -d "$MARKETPLACE_PATH" ]; then - echo "Warning: Marketplace directory not found at $MARKETPLACE_PATH" - echo "Plugin files may not be copied to cache correctly." -fi - -# Ensure cache directory exists and copy plugin files +# Ensure cache directory exists and copy plugin files (wrapper scripts + metadata) mkdir -p "$CACHE_PATH/.claude-plugin" mkdir -p "$CACHE_PATH/hooks" mkdir -p "$CACHE_PATH/commands" -# Copy files from marketplace to cache -if ! cp -r "$MARKETPLACE_PATH/"* "$CACHE_PATH/" 2>/dev/null; then - echo "ERROR: Failed to copy plugin files to cache directory" - exit 1 +# Copy wrapper scripts and metadata from marketplace to cache +if [ -d "$MARKETPLACE_PATH" ]; then + cp -r "$MARKETPLACE_PATH/"* "$CACHE_PATH/" 2>/dev/null || true fi -# Verify critical files exist in cache -for f in worker mcp-server hooks/hooks.json .claude-plugin/plugin.json; do - if [ ! -f "$CACHE_PATH/$f" ]; then - echo "WARNING: Expected file $f not found in cache after copy" - fi -done +# Also ensure actual binaries are available via the wrapper scripts +# The wrappers delegate to $STABLE_BIN which has the real binaries +if [ ! -x "$STABLE_BIN/mcp-server" ]; then + echo "WARNING: Binaries not found at $STABLE_BIN. Run 'make install' from the source directory." +fi # --- JSON registration --- # Uses jq if available, falls back to python3 for systems without jq. @@ -132,13 +127,12 @@ plugins["plugins"][plugin_key] = [{ "installPath": cache_path, "version": version, "installedAt": timestamp, - "lastUpdated": timestamp, - "isLocal": True + "lastUpdated": timestamp }] save_json(plugins_file, plugins) print("Plugin registered in installed_plugins.json") -# 2. settings.json +# 2. settings.json — enable plugin AND add to extraKnownMarketplaces for persistence settings = load_json(settings_file) settings.setdefault("enabledPlugins", {}) settings["enabledPlugins"][plugin_key] = True @@ -147,14 +141,26 @@ settings["statusLine"] = { "command": statusline_cmd, "padding": 0 } +# extraKnownMarketplaces ensures plugin survives Claude Code CLI updates +settings.setdefault("extraKnownMarketplaces", {}) +settings["extraKnownMarketplaces"][marketplace_name] = { + "source": { + "repo": "lukaszraczylo/claude-mnemonic", + "source": "github" + } +} save_json(settings_file, settings) print("Plugin enabled in settings.json") +print("Added to extraKnownMarketplaces for persistence") print("Statusline configured in settings.json") # 3. known_marketplaces.json marketplaces = load_json(marketplaces_file) marketplaces[marketplace_name] = { - "source": {"source": "directory", "path": marketplace_path}, + "source": { + "source": "github", + "repo": "lukaszraczylo/claude-mnemonic" + }, "installLocation": marketplace_path, "lastUpdated": timestamp } @@ -190,8 +196,7 @@ if command -v jq &> /dev/null; then "installPath": "$CACHE_PATH", "version": "$VERSION", "installedAt": "$TIMESTAMP", - "lastUpdated": "$TIMESTAMP", - "isLocal": true + "lastUpdated": "$TIMESTAMP" }] EOF ) @@ -202,7 +207,7 @@ EOF echo "Plugin registered in installed_plugins.json" - # Enable the plugin in settings.json and configure statusline + # Enable the plugin in settings.json, configure statusline, and add to extraKnownMarketplaces STATUSLINE_ENTRY=$(cat <