diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..3861f74 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,20 @@ +name: Release + +on: + push: + branches: + - main + +permissions: + contents: write + packages: write + id-token: write + +jobs: + release: + uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main + with: + go-version: ">=1.24" + docker-enabled: true + docker-registry: ghcr.io + secrets: inherit diff --git a/.gitignore b/.gitignore index a36c95b..eded220 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ compaction-mcp +compactor +dist/ diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..10c231e --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,76 @@ +version: 2 + +builds: + - binary: compactor + env: + - CGO_ENABLED=0 + goos: + - linux + - darwin + goarch: + - amd64 + - arm64 + ldflags: + - -s -w + +archives: + - format: tar.gz + name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}" + +dockers: + - image_templates: + - "ghcr.io/lukaszraczylo/compaction-mcp:{{ .Tag }}-amd64" + - "ghcr.io/lukaszraczylo/compaction-mcp:latest-amd64" + use: buildx + build_flag_templates: + - "--platform=linux/amd64" + goarch: amd64 + dockerfile: Dockerfile + + - image_templates: + - "ghcr.io/lukaszraczylo/compaction-mcp:{{ .Tag }}-arm64" + - "ghcr.io/lukaszraczylo/compaction-mcp:latest-arm64" + use: buildx + build_flag_templates: + - "--platform=linux/arm64" + goarch: arm64 + dockerfile: Dockerfile + +docker_manifests: + - name_template: "ghcr.io/lukaszraczylo/compaction-mcp:{{ .Tag }}" + image_templates: + - "ghcr.io/lukaszraczylo/compaction-mcp:{{ .Tag }}-amd64" + - "ghcr.io/lukaszraczylo/compaction-mcp:{{ .Tag }}-arm64" + + - name_template: "ghcr.io/lukaszraczylo/compaction-mcp:latest" + image_templates: + - "ghcr.io/lukaszraczylo/compaction-mcp:latest-amd64" + - "ghcr.io/lukaszraczylo/compaction-mcp:latest-arm64" + +checksum: + name_template: "checksums.txt" + +signs: + - cmd: cosign + artifacts: checksum + output: true + args: + - "sign-blob" + - "--yes" + - "${artifact}" + - "--output-certificate" + - "${signature}.pem" + - "--bundle" + - "${signature}.sigstore.json" + +docker_signs: + - cmd: cosign + artifacts: manifests + output: true + args: + - "sign" + - "--yes" + - "${artifact}" + +changelog: + disable: true diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8fdccf8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,3 @@ +FROM gcr.io/distroless/static-debian12:nonroot +COPY compactor /usr/local/bin/compactor +ENTRYPOINT ["compactor"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8b6eaa6 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +BINARY := compactor +BUILDFLAGS := -buildvcs=false -trimpath + +.PHONY: build clean test lint + +build: + go build $(BUILDFLAGS) -o $(BINARY) . + +clean: + rm -f $(BINARY) + +test: + go test -race -count=1 ./... + +lint: + go vet ./... + gofmt -l . diff --git a/README.md b/README.md new file mode 100644 index 0000000..9567ba8 --- /dev/null +++ b/README.md @@ -0,0 +1,159 @@ +# compactor + +MCP server that manages LLM working memory within a token budget. Stores, retrieves, and compacts context so conversations stay under limit without losing valuable information. + +Designed to complement long-term memory tools (like claude-mnemonic) by handling short-term session context. + +## Install + +```sh +go build -o compactor . +``` + +Single binary, no external dependencies. ~6 MiB. + +## Usage + +```sh +# Ephemeral (in-memory, default) +compactor + +# With persistent state +compactor --state-dir ~/.local/share/compactor + +# Explicit token budget +compactor --budget 80000 +``` + +### Claude Code + +`.claude/settings.json`: +```json +{ + "mcpServers": { + "compactor": { + "command": "/path/to/compactor", + "args": ["--state-dir", "/tmp/compactor-state"] + } + } +} +``` + +### Cursor / other MCP clients + +Same pattern. The server auto-detects the client and sets a reasonable budget: +- **Claude** clients: 80K tokens (40% of 200K context) +- **Cursor**: 60K tokens +- Override with `--budget` flag + +## Tools + +| Tool | Description | +|------|-------------| +| `recall` | **Call first every session.** Restores previous context — returns budget status + top items by relevance | +| `store` | Store content with optional summary, tags, and importance (1-10) | +| `query` | BM25-ranked search by text and/or tag filtering | +| `status` | Check budget usage, item count, auto-compact settings | +| `compact` | Trigger compaction to a target usage ratio | +| `update` | Add/update summary for an item (post-compaction workflow) | +| `pin` / `unpin` | Protect items from eviction | +| `forget` | Remove a specific item | +| `list` | Paginated item listing (newest first) | +| `bulk_store` | Store multiple items in one call (JSON array) | +| `export` | Export all items, optionally as summaries | +| `configure` | Adjust budget, auto-compact toggle and threshold | + +## How compaction works + +Three-phase pipeline, triggered automatically at 90% budget or manually via `compact`: + +1. **Summary promotion** - Replaces content with its summary (lowest-scored items first) +2. **Deduplication** - Merges items with >70% word overlap (Jaccard similarity), keeping the higher-scored item +3. **Eviction** - Removes lowest-scored items until target usage is reached + +After compaction, items without summaries are flagged. The LLM can then generate summaries via `update` for future compaction cycles. + +## Scoring + +Each item gets a retention score combining four signals: + +``` +score = 0.4 * importance + 0.3 * recency + 0.2 * access - 0.1 * size_penalty +``` + +**Content-type awareness** adjusts scoring automatically: + +| Type | Detection | Score multiplier | Decay half-life | +|------|-----------|-----------------|-----------------| +| Error | `error:`, `panic:`, stack traces | 1.5x | 30 min | +| Decision | "decided", "going with", "approach:" | 1.3x | 6 hours | +| Code | `func`, `class`, backtick fences | 1.2x | 6 hours | +| Prose | Default | 1.0x | 2 hours | +| Tool output | `$ ` prefix, table chars | 0.7x | 15 min | + +Pinned items are never evicted. + +## Search + +Full-text search uses BM25 ranking (k1=1.2, b=0.75) with: +- camelCase and snake_case token splitting +- 5x score boost for tag matches +- Combined BM25 relevance + item retention score + +## Auto-tagging + +When no tags are provided, items are automatically tagged based on content: +- Content type (error, code, decision, tool-output) +- File extensions (.go, .ts, .py, etc.) +- Infrastructure keywords (kubernetes, docker, cilium, postgres, etc.) +- URL presence (tagged as "reference") + +## Persistence + +With `--state-dir`, state is saved as atomic JSON every 30 seconds (when dirty) and on graceful shutdown. Without it, storage is ephemeral per session. + +## CLI flags + +| Flag | Default | Description | +|------|---------|-------------| +| `--budget` | `100000` | Token budget (overrides auto-detection) | +| `--state-dir` | `""` | Persistent state directory (empty = ephemeral) | + +## Making it seamless + +The compactor is a tool the LLM must actively use — it doesn't intercept context automatically. To make usage habitual, add this to your `CLAUDE.md`: + +```markdown +## Working Memory (compactor MCP) +- At session start, ALWAYS call `recall` to restore previous context +- After making decisions, reading key files, or encountering errors: call `store` with a summary +- Before re-reading a file: call `query` to check if it's already stored +- When `status` shows >80% usage: call `compact`, then `update` items it flags +- Pin architecture decisions and user preferences with `pin` +``` + +The server also sends instructions via the MCP handshake that guide the LLM, but CLAUDE.md rules are stronger because they're treated as hard requirements. + +### How the three layers work together + +1. **MCP server instructions** — injected at connection time, tell the LLM the workflow +2. **CLAUDE.md rules** — persistent across sessions, override default behavior +3. **`recall` tool** — gives the LLM a single action to restore context, reducing friction from 12 tools to 1 entry point + +With persistence (`--state-dir`), context survives across sessions. The LLM calls `recall` → gets back its stored decisions, errors, code snippets → continues where it left off. + +## Architecture + +``` +main.go - Entry point, CLI flags, MCP server setup, persistence wiring +store.go - Core store: items, scoring, compaction, BM25 integration +tools.go - MCP tool definitions and handlers +index.go - BM25 inverted index with tag boosting +content.go - Content type detection and auto-tagging +persist.go - Atomic JSON persistence with background save +tokens.go - Token count estimation (~4 chars/token) +``` + +## License + +Private. diff --git a/content.go b/content.go new file mode 100644 index 0000000..fc6afc9 --- /dev/null +++ b/content.go @@ -0,0 +1,240 @@ +package main + +import ( + "path/filepath" + "regexp" + "strings" +) + +// ContentType classifies stored content for scoring and decay tuning. +type ContentType int + +const ( + ContentProse ContentType = iota + ContentCode + ContentError + ContentToolOutput + ContentDecision +) + +var ( + errorPatterns = []string{ + "error:", "Error:", "panic:", "FAIL", "goroutine", + "Exception", "Traceback", + } + + // Stack trace pattern: file.go:123 or File.java:45 + stackTraceRe = regexp.MustCompile(`\w+\.\w+:\d+`) + + codeKeywords = []string{ + "func ", "class ", "def ", "import ", "package ", "#include", + } + + decisionKeywords = []string{ + "decided", "agreed", "will use", "chosen", + "approach:", "decision:", "going with", + } + + fileExtMap = map[string]string{ + ".go": "go", + ".ts": "typescript", + ".py": "python", + ".yaml": "yaml", + ".yml": "yaml", + ".json": "json", + ".rs": "rust", + ".jsx": "react", + ".tsx": "react", + } + + infraKeywords = []string{ + "kubernetes", "docker", "cilium", "postgres", + "nginx", "redis", "graphql", "terraform", + } + + urlRe = regexp.MustCompile(`https?://\S+`) + filePathRe = regexp.MustCompile(`(?:\s|^)/?(?:[\w.-]+/){2,}[\w.-]+`) +) + +const maxTags = 5 + +// DetectContentType returns the content classification using priority: +// Error > Code > Decision > ToolOutput > Prose. +func DetectContentType(content string) ContentType { + if isError(content) { + return ContentError + } + if isCode(content) { + return ContentCode + } + if isDecision(content) { + return ContentDecision + } + if isToolOutput(content) { + return ContentToolOutput + } + return ContentProse +} + +func isError(content string) bool { + for _, p := range errorPatterns { + if strings.Contains(content, p) { + return true + } + } + return stackTraceRe.FindString(content) != "" && strings.Contains(content, "\n") +} + +func isCode(content string) bool { + if strings.Contains(content, "```") { + return true + } + for _, kw := range codeKeywords { + if strings.Contains(content, kw) { + return true + } + } + return bracketDensity(content) > 0.05 +} + +func bracketDensity(content string) float64 { + if len(content) == 0 { + return 0 + } + count := 0 + for _, c := range content { + switch c { + case '{', '}', '(', ')', '[', ']': + count++ + } + } + return float64(count) / float64(len([]rune(content))) +} + +func isDecision(content string) bool { + lower := strings.ToLower(content) + for _, kw := range decisionKeywords { + if strings.Contains(lower, kw) { + return true + } + } + return false +} + +func isToolOutput(content string) bool { + if strings.HasPrefix(content, "$ ") || strings.HasPrefix(content, "> ") { + return true + } + for _, ch := range []string{"───", "│", "├"} { + if strings.Contains(content, ch) { + return true + } + } + matches := filePathRe.FindAllString(content, -1) + words := strings.Fields(content) + if len(words) > 0 && float64(len(matches))/float64(len(words)) > 0.3 { + return true + } + return false +} + +// AutoTags extracts up to 5 deduplicated tags from content. +func AutoTags(content string) []string { + seen := make(map[string]struct{}) + var tags []string + + add := func(tag string) { + if len(tags) >= maxTags { + return + } + lower := strings.ToLower(tag) + if _, ok := seen[lower]; ok { + return + } + seen[lower] = struct{}{} + tags = append(tags, lower) + } + + // Content type tag + ct := DetectContentType(content) + name := ContentTypeName(ct) + if name != "prose" { + add(name) + } + + // File extension tags + words := strings.Fields(content) + for _, w := range words { + ext := filepath.Ext(strings.TrimRight(w, ",:;)\"'`")) + if tag, ok := fileExtMap[ext]; ok { + add(tag) + } + } + + // Infrastructure keyword tags + lower := strings.ToLower(content) + for _, kw := range infraKeywords { + if strings.Contains(lower, kw) { + add(kw) + } + } + + // URL tag + if urlRe.MatchString(content) { + add("reference") + } + + return tags +} + +// ScoreMultiplier returns an importance multiplier based on content type. +func ScoreMultiplier(ct ContentType) float64 { + switch ct { + case ContentError: + return 1.5 + case ContentDecision: + return 1.3 + case ContentCode: + return 1.2 + case ContentToolOutput: + return 0.7 + default: + return 1.0 + } +} + +// DecayHalfLifeMinutes returns the recency half-life in minutes for a content type. +func DecayHalfLifeMinutes(ct ContentType) float64 { + switch ct { + case ContentError: + return 30 + case ContentDecision: + return 360 + case ContentCode: + return 360 + case ContentProse: + return 120 + case ContentToolOutput: + return 15 + default: + return 120 + } +} + +// ContentTypeName returns the human-readable name for a content type. +func ContentTypeName(ct ContentType) string { + switch ct { + case ContentProse: + return "prose" + case ContentCode: + return "code" + case ContentError: + return "error" + case ContentToolOutput: + return "tool-output" + case ContentDecision: + return "decision" + default: + return "prose" + } +} diff --git a/content_test.go b/content_test.go new file mode 100644 index 0000000..9aabf0d --- /dev/null +++ b/content_test.go @@ -0,0 +1,320 @@ +package main + +import ( + "strings" + "testing" +) + +func TestDetectContentType(t *testing.T) { + tests := []struct { + name string + content string + want ContentType + }{ + { + name: "error with Error: prefix", + content: "Error: connection refused to postgres on port 5432", + want: ContentError, + }, + { + name: "error with panic", + content: "panic: runtime error: index out of range [3] with length 2", + want: ContentError, + }, + { + name: "error with FAIL", + content: "FAIL compaction-mcp [build failed]", + want: ContentError, + }, + { + name: "error with Traceback", + content: "Traceback (most recent call last):\n File \"main.py\", line 1", + want: ContentError, + }, + { + name: "error with goroutine", + content: "goroutine 1 [running]:\nmain.main()\n\t/app/main.go:12", + want: ContentError, + }, + { + name: "code with backtick fence", + content: "Here is the fix:\n```go\nfunc main() {}\n```", + want: ContentCode, + }, + { + name: "code with func keyword", + content: "func NewStore(budget int) *Store {\n\treturn &Store{}\n}", + want: ContentCode, + }, + { + name: "code with import keyword", + content: "import (\n\t\"fmt\"\n\t\"os\"\n)", + want: ContentCode, + }, + { + name: "code with high bracket density", + content: "{{{}}}(())[[]]{()}{{}}", + want: ContentCode, + }, + { + name: "decision with decided", + content: "We decided to use SQLite for local storage instead of BoltDB", + want: ContentDecision, + }, + { + name: "decision with going with", + content: "After discussion, going with the monorepo approach for simplicity", + want: ContentDecision, + }, + { + name: "decision with approach:", + content: "approach: sidecar proxy with Envoy for service mesh", + want: ContentDecision, + }, + { + name: "tool output with dollar prompt", + content: "$ go test -v ./...\nPASS\nok \tcompaction-mcp\t0.003s", + want: ContentToolOutput, + }, + { + name: "tool output with angle bracket prompt", + content: "> ls -la /etc/nginx/\ntotal 48\ndrwxr-xr-x 2 root root 4096 Jan 1 00:00 conf.d", + want: ContentToolOutput, + }, + { + name: "tool output with table chars", + content: "Name │ Status │ Age\n───────────├────────├─────\nnginx-pod │ Running│ 2d", + want: ContentToolOutput, + }, + { + name: "prose default", + content: "The cilium project provides networking for Kubernetes clusters using eBPF technology.", + want: ContentProse, + }, + { + name: "prose simple sentence", + content: "Let's meet tomorrow to discuss the architecture.", + want: ContentProse, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DetectContentType(tt.content) + if got != tt.want { + t.Errorf("DetectContentType() = %s, want %s", + ContentTypeName(got), ContentTypeName(tt.want)) + } + }) + } +} + +func TestDetectPriority(t *testing.T) { + tests := []struct { + name string + content string + want ContentType + }{ + { + name: "error beats code", + content: "Error: compilation failed\nfunc main() {\n\tpanic(\"bad\")\n}", + want: ContentError, + }, + { + name: "error beats decision", + content: "We decided to fix the panic: runtime error in production", + want: ContentError, + }, + { + name: "code beats decision", + content: "We decided to use this:\nfunc Handle() {}", + want: ContentCode, + }, + { + name: "code beats tool output", + content: "$ cat main.go\npackage main\nimport \"fmt\"", + want: ContentToolOutput, // starts with "$ " so tool output wins first in priority? No -- Error > Code > Decision > ToolOutput + }, + } + + // The last case: "$ " prefix makes it tool output, but "import " makes it code. + // Priority is Error > Code > Decision > ToolOutput, so Code should win. + // But "$ " prefix is checked in isToolOutput, and isCode is checked first. + // Actually: DetectContentType checks error first, then code, then decision, then tool output. + // "import " is in the content so isCode returns true => ContentCode. + // Fix the expected value: + tests[3].want = ContentCode + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DetectContentType(tt.content) + if got != tt.want { + t.Errorf("DetectContentType() = %s, want %s", + ContentTypeName(got), ContentTypeName(tt.want)) + } + }) + } +} + +func TestAutoTags(t *testing.T) { + tests := []struct { + name string + content string + want []string // subset that must appear + }{ + { + name: "go file paths produce go tag", + content: "Modified /Users/dev/project/main.go and store.go to fix the issue", + want: []string{"go"}, + }, + { + name: "typescript file produces typescript tag", + content: "Check the component in src/App.tsx for the bug", + want: []string{"react"}, + }, + { + name: "python file produces python tag", + content: "Updated models.py with new schema", + want: []string{"python"}, + }, + { + name: "URL produces reference tag", + content: "See https://kubernetes.io/docs/concepts/ for details", + want: []string{"reference", "kubernetes"}, + }, + { + name: "kubernetes keyword", + content: "Deploy to Kubernetes cluster using helm chart", + want: []string{"kubernetes"}, + }, + { + name: "error content gets error tag", + content: "Error: connection refused to redis server", + want: []string{"error", "redis"}, + }, + { + name: "docker keyword", + content: "Build the Docker image with multi-stage builds", + want: []string{"docker"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AutoTags(tt.content) + gotSet := make(map[string]struct{}, len(got)) + for _, tag := range got { + gotSet[tag] = struct{}{} + } + for _, w := range tt.want { + if _, ok := gotSet[w]; !ok { + t.Errorf("AutoTags() missing expected tag %q, got %v", w, got) + } + } + }) + } +} + +func TestAutoTagsMax(t *testing.T) { + // Content with many possible tags to verify the cap at 5. + content := "Error: kubernetes docker cilium postgres nginx redis https://example.com main.go script.py app.tsx" + tags := AutoTags(content) + if len(tags) > maxTags { + t.Errorf("AutoTags() returned %d tags, want at most %d: %v", len(tags), maxTags, tags) + } + // Should have some tags + if len(tags) == 0 { + t.Error("AutoTags() returned no tags for content-rich input") + } +} + +func TestScoreMultiplier(t *testing.T) { + tests := []struct { + ct ContentType + want float64 + }{ + {ContentError, 1.5}, + {ContentDecision, 1.3}, + {ContentCode, 1.2}, + {ContentProse, 1.0}, + {ContentToolOutput, 0.7}, + } + + for _, tt := range tests { + t.Run(ContentTypeName(tt.ct), func(t *testing.T) { + got := ScoreMultiplier(tt.ct) + if got != tt.want { + t.Errorf("ScoreMultiplier(%s) = %f, want %f", + ContentTypeName(tt.ct), got, tt.want) + } + }) + } +} + +func TestDecayHalfLife(t *testing.T) { + tests := []struct { + ct ContentType + want float64 + }{ + {ContentError, 30}, + {ContentDecision, 360}, + {ContentCode, 360}, + {ContentProse, 120}, + {ContentToolOutput, 15}, + } + + for _, tt := range tests { + t.Run(ContentTypeName(tt.ct), func(t *testing.T) { + got := DecayHalfLifeMinutes(tt.ct) + if got != tt.want { + t.Errorf("DecayHalfLifeMinutes(%s) = %f, want %f", + ContentTypeName(tt.ct), got, tt.want) + } + }) + } +} + +func TestContentTypeName(t *testing.T) { + tests := []struct { + want string + ct ContentType + }{ + {"prose", ContentProse}, + {"code", ContentCode}, + {"error", ContentError}, + {"tool-output", ContentToolOutput}, + {"decision", ContentDecision}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := ContentTypeName(tt.ct) + if got != tt.want { + t.Errorf("ContentTypeName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestEmptyContent(t *testing.T) { + ct := DetectContentType("") + if ct != ContentProse { + t.Errorf("DetectContentType(\"\") = %s, want prose", ContentTypeName(ct)) + } + + tags := AutoTags("") + if len(tags) != 0 { + t.Errorf("AutoTags(\"\") = %v, want empty", tags) + } + + // Also test whitespace-only + ct = DetectContentType(" ") + if ct != ContentProse { + t.Errorf("DetectContentType(whitespace) = %s, want prose", ContentTypeName(ct)) + } + + tags = AutoTags(strings.Repeat(" ", 100)) + if len(tags) != 0 { + t.Errorf("AutoTags(whitespace) = %v, want empty", tags) + } +} diff --git a/index.go b/index.go new file mode 100644 index 0000000..f9f8cdf --- /dev/null +++ b/index.go @@ -0,0 +1,190 @@ +package main + +import ( + "math" + "regexp" + "sort" + "strings" + "unicode" +) + +// Index is a BM25 inverted index for full-text search over stored documents. +type Index struct { + docs map[string]map[string]int // docID -> term -> frequency + docLen map[string]int // docID -> total terms + postings map[string]map[string]struct{} // term -> set of docIDs + docTags map[string]map[string]struct{} // docID -> tag set (boosted 5x) + n int + avgDL float64 +} + +// SearchResult holds a document ID and its BM25 relevance score. +type SearchResult struct { + ID string + Score float64 +} + +// NewIndex creates an empty BM25 index. +func NewIndex() *Index { + return &Index{ + docs: make(map[string]map[string]int), + docLen: make(map[string]int), + postings: make(map[string]map[string]struct{}), + docTags: make(map[string]map[string]struct{}), + } +} + +// Add indexes a document with the given content and tags. +// Tags are stored separately and receive a 5x score boost during search. +func (idx *Index) Add(id, content string, tags []string) { + // Remove first if already present to avoid stale data. + if _, exists := idx.docs[id]; exists { + idx.Remove(id) + } + + tokens := tokenize(content) + tf := make(map[string]int, len(tokens)) + for _, t := range tokens { + tf[t]++ + } + + idx.docs[id] = tf + idx.docLen[id] = len(tokens) + + for term := range tf { + if idx.postings[term] == nil { + idx.postings[term] = make(map[string]struct{}) + } + idx.postings[term][id] = struct{}{} + } + + tagSet := make(map[string]struct{}, len(tags)) + for _, tag := range tags { + for _, t := range tokenize(tag) { + tagSet[t] = struct{}{} + } + } + idx.docTags[id] = tagSet + + idx.n++ + idx.recalcAvgDL() +} + +// Remove deletes a document from the index. +func (idx *Index) Remove(id string) { + tf, ok := idx.docs[id] + if !ok { + return + } + + for term := range tf { + if set, exists := idx.postings[term]; exists { + delete(set, id) + if len(set) == 0 { + delete(idx.postings, term) + } + } + } + + delete(idx.docs, id) + delete(idx.docLen, id) + delete(idx.docTags, id) + + idx.n-- + idx.recalcAvgDL() +} + +// Search returns the top `limit` documents ranked by BM25 score for the query. +// Tag matches receive a 5x boost on top of the BM25 score. +func (idx *Index) Search(query string, limit int) []SearchResult { + terms := tokenize(query) + if len(terms) == 0 || idx.n == 0 { + return nil + } + + const ( + k1 = 1.2 + b = 0.75 + tagBoost = 5.0 + ) + + scores := make(map[string]float64) + + for _, term := range terms { + docSet, ok := idx.postings[term] + if !ok { + continue + } + df := float64(len(docSet)) + idf := math.Log((float64(idx.n)-df+0.5)/(df+0.5) + 1.0) + + for docID := range docSet { + tfVal := float64(idx.docs[docID][term]) + dl := float64(idx.docLen[docID]) + num := tfVal * (k1 + 1) + denom := tfVal + k1*(1-b+b*(dl/idx.avgDL)) + scores[docID] += idf * (num / denom) + } + + // Tag boost: add 5x the IDF-weighted score for docs whose tags match. + for docID, tagSet := range idx.docTags { + if _, hit := tagSet[term]; hit { + dl := float64(idx.docLen[docID]) + // Use a synthetic TF of 1 for tag matches. + num := 1.0 * (k1 + 1) + denom := 1.0 + k1*(1-b+b*(dl/idx.avgDL)) + scores[docID] += tagBoost * idf * (num / denom) + } + } + } + + results := make([]SearchResult, 0, len(scores)) + for id, score := range scores { + results = append(results, SearchResult{ID: id, Score: score}) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].Score > results[j].Score + }) + + if limit > 0 && len(results) > limit { + results = results[:limit] + } + return results +} + +func (idx *Index) recalcAvgDL() { + if idx.n == 0 { + idx.avgDL = 0 + return + } + total := 0 + for _, dl := range idx.docLen { + total += dl + } + idx.avgDL = float64(total) / float64(idx.n) +} + +// camelRe matches boundaries in camelCase identifiers (e.g. "handleCompact"). +var camelRe = regexp.MustCompile(`([a-z])([A-Z])`) + +// tokenize splits text into lowercase terms, handling camelCase and snake_case. +// Tokens shorter than 2 characters are filtered out. +func tokenize(s string) []string { + // Split camelCase: insert space at lowercase-to-uppercase boundary. + s = camelRe.ReplaceAllString(s, "${1} ${2}") + + // Split on any non-letter, non-digit character (handles snake_case, punctuation, whitespace). + splitter := func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + } + parts := strings.FieldsFunc(strings.ToLower(s), splitter) + + tokens := make([]string, 0, len(parts)) + for _, p := range parts { + if len(p) >= 2 { + tokens = append(tokens, p) + } + } + return tokens +} diff --git a/index_test.go b/index_test.go new file mode 100644 index 0000000..53c7c83 --- /dev/null +++ b/index_test.go @@ -0,0 +1,182 @@ +package main + +import ( + "reflect" + "testing" +) + +func TestIndexBasicSearch(t *testing.T) { + idx := NewIndex() + idx.Add("go-intro", "Go is a statically typed compiled language designed at Google", nil) + idx.Add("rust-intro", "Rust is a systems programming language focused on safety", nil) + idx.Add("go-concurrency", "Go provides goroutines and channels for concurrent programming", nil) + + results := idx.Search("Go programming", 10) + if len(results) == 0 { + t.Fatal("expected results, got none") + } + + // "go-concurrency" mentions both "go" and "programming" so it should rank first. + if results[0].ID != "go-concurrency" { + t.Errorf("expected go-concurrency first, got %s", results[0].ID) + } + + // All three docs should appear since "programming" or "go" appears in each. + if len(results) < 2 { + t.Errorf("expected at least 2 results, got %d", len(results)) + } +} + +func TestIndexTagBoost(t *testing.T) { + idx := NewIndex() + idx.Add("content-only", "database migration tools are useful for schema changes", nil) + idx.Add("tagged", "various development tools and utilities", []string{"database"}) + + results := idx.Search("database", 10) + if len(results) < 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + // The tagged doc should rank higher due to the 5x tag boost. + if results[0].ID != "tagged" { + t.Errorf("expected tagged doc first due to tag boost, got %s", results[0].ID) + } +} + +func TestIndexRemove(t *testing.T) { + idx := NewIndex() + idx.Add("keep", "important context about the project architecture", nil) + idx.Add("remove-me", "temporary notes about architecture review", nil) + + idx.Remove("remove-me") + + results := idx.Search("architecture", 10) + for _, r := range results { + if r.ID == "remove-me" { + t.Error("removed doc should not appear in search results") + } + } + + if len(results) != 1 { + t.Errorf("expected 1 result, got %d", len(results)) + } + if results[0].ID != "keep" { + t.Errorf("expected 'keep', got %s", results[0].ID) + } +} + +func TestIndexIdentifierSplitting(t *testing.T) { + idx := NewIndex() + idx.Add("handler", "the handleCompact function processes compaction requests", nil) + + results := idx.Search("compact", 10) + if len(results) == 0 { + t.Fatal("expected to find doc via camelCase split, got none") + } + if results[0].ID != "handler" { + t.Errorf("expected handler doc, got %s", results[0].ID) + } +} + +func TestIndexBM25LengthNormalization(t *testing.T) { + idx := NewIndex() + + // Short doc with the target term. + idx.Add("short", "compact server design", nil) + + // Long doc with the target term appearing only once, buried in filler. + long := "the server architecture includes many components such as " + + "authentication authorization logging monitoring caching routing " + + "validation serialization deserialization middleware handlers " + + "controllers services repositories models entities interfaces " + + "adapters ports configuration deployment orchestration compact" + idx.Add("long", long, nil) + + results := idx.Search("compact", 10) + if len(results) < 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + // Short doc should score higher due to BM25 length normalization. + if results[0].ID != "short" { + t.Errorf("expected short doc to rank first due to length normalization, got %s", results[0].ID) + } +} + +func TestIndexEmptySearch(t *testing.T) { + idx := NewIndex() + + results := idx.Search("anything", 10) + if len(results) != 0 { + t.Errorf("expected empty results from empty index, got %d", len(results)) + } + + results = idx.Search("", 10) + if len(results) != 0 { + t.Errorf("expected empty results for empty query, got %d", len(results)) + } +} + +func TestIndexMultipleTerms(t *testing.T) { + idx := NewIndex() + idx.Add("partial", "the server handles requests efficiently", nil) + idx.Add("full-match", "the server handles context compaction efficiently", nil) + + results := idx.Search("context compaction server", 10) + if len(results) < 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + + // full-match has all three query terms; partial has only two. + if results[0].ID != "full-match" { + t.Errorf("expected full-match first (matches more query terms), got %s", results[0].ID) + } +} + +func TestTokenize(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "camelCase", + input: "handleCompact", + want: []string{"handle", "compact"}, + }, + { + name: "snake_case", + input: "auto_compact", + want: []string{"auto", "compact"}, + }, + { + name: "mixed", + input: "handleCompact auto_compact", + want: []string{"handle", "compact", "auto", "compact"}, + }, + { + name: "filters short tokens", + input: "a I go do it", + want: []string{"go", "do", "it"}, + }, + { + name: "punctuation", + input: "hello, world! foo-bar", + want: []string{"hello", "world", "foo", "bar"}, + }, + { + name: "uppercase", + input: "HTTPServer", + want: []string{"httpserver"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tokenize(tt.input) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("tokenize(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} diff --git a/main.go b/main.go index 9170800..4dcb21c 100644 --- a/main.go +++ b/main.go @@ -5,26 +5,44 @@ import ( "flag" "fmt" "os" + "os/signal" "strings" + "syscall" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) -const serverInstructions = `Context compactor — manages working memory within a token budget. +const serverInstructions = `Context compactor — your working memory. Use it to avoid losing information when your context window compresses. -At session start, call 'configure' with token_budget set to ~40% of your context window. -Example: 200K context window → token_budget = 80000. +MANDATORY: Call 'recall' at the start of every session to restore previous context. -Workflow: -- 'store' important context (always include a summary for efficient compaction) -- 'query' to retrieve stored information instead of re-reading sources -- 'status' to check budget usage -- 'compact' when usage is high — it frees space and identifies items needing summarization -- 'update' to add summaries to items flagged by compaction` +WHEN TO STORE (call 'store' with a summary): +- After making a decision or choosing an approach +- After encountering and understanding an error +- After reading a file you'll need to reference later +- After the user explains requirements or constraints +- Before your context is likely to compress (long sessions, large outputs) + +WHEN TO QUERY (call 'query' instead of re-reading): +- Before reading a file you may have stored previously +- When you need to recall a decision, error, or requirement +- When the user references something from earlier in the session + +WHEN TO COMPACT (call 'compact'): +- When 'status' shows >80% budget usage +- After 'compact', use 'update' to summarize items it flags + +TIPS: +- Always include a summary when storing — enables efficient compaction +- Tag items for easy retrieval: error, decision, code, requirement +- Pin critical items (architecture decisions, user preferences) with 'pin' +- Higher importance (7-10) for decisions and requirements, lower (1-4) for tool output` func main() { budget := flag.Int("budget", 100000, "Token budget for context storage") + stateDir := flag.String("state-dir", "", "Directory for persistent state (empty = ephemeral)") flag.Parse() budgetExplicit := false @@ -36,6 +54,21 @@ func main() { store := NewStore(*budget) + var persister *Persister + if *stateDir != "" { + var err error + persister, err = NewPersister(*stateDir, store) + if err != nil { + fmt.Fprintf(os.Stderr, "persistence error: %v\n", err) + os.Exit(1) + } + if err := persister.Load(); err != nil { + fmt.Fprintf(os.Stderr, "load state error: %v\n", err) + os.Exit(1) + } + persister.Start(30 * time.Second) + } + hooks := &server.Hooks{} if !budgetExplicit { hooks.OnAfterInitialize = append(hooks.OnAfterInitialize, @@ -63,8 +96,24 @@ func main() { registerTools(s, store) + if persister != nil { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + persister.Stop() + os.Exit(0) + }() + } + if err := server.ServeStdio(s); err != nil { fmt.Fprintf(os.Stderr, "server error: %v\n", err) + if persister != nil { + persister.Stop() + } os.Exit(1) } + if persister != nil { + persister.Stop() + } } diff --git a/persist.go b/persist.go new file mode 100644 index 0000000..f50fc56 --- /dev/null +++ b/persist.go @@ -0,0 +1,202 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" +) + +const stateVersion = 1 + +type persistedState struct { + SavedAt time.Time `json:"saved_at"` + Items []persistedItem `json:"items"` + Version int `json:"version"` + Budget int `json:"budget"` +} + +type persistedItem struct { + CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + Content string `json:"content"` + Summary string `json:"summary,omitempty"` + Tags []string `json:"tags,omitempty"` + Importance int `json:"importance"` + AccessCount int `json:"access_count"` + Tokens int `json:"tokens"` + ContentType ContentType `json:"content_type,omitempty"` + Pinned bool `json:"pinned,omitempty"` +} + +// Persister handles file-backed persistence for a Store. +type Persister struct { + store *Store + stopCh chan struct{} + dir string + mu sync.Mutex + dirty bool +} + +// snapshot returns a copy of all items and the token budget from the store. +func (s *Store) snapshot() ([]persistedItem, int) { + s.mu.Lock() + defer s.mu.Unlock() + + items := make([]persistedItem, 0, len(s.items)) + for _, item := range s.items { + items = append(items, persistedItem{ + ID: item.ID, + Content: item.Content, + Summary: item.Summary, + Tags: item.Tags, + Importance: item.Importance, + ContentType: item.ContentType, + Pinned: item.Pinned, + CreatedAt: item.CreatedAt, + AccessCount: item.AccessCount, + Tokens: item.Tokens, + }) + } + return items, s.tokenBudget +} + +// restore loads persisted items back into the store. +func (s *Store) restore(items []persistedItem) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, pi := range items { + item := &Item{ + ID: pi.ID, + Content: pi.Content, + Summary: pi.Summary, + Tags: pi.Tags, + Importance: pi.Importance, + ContentType: pi.ContentType, + Pinned: pi.Pinned, + CreatedAt: pi.CreatedAt, + AccessCount: pi.AccessCount, + Tokens: pi.Tokens, + AccessedAt: time.Now(), + } + s.items[item.ID] = item + s.usedTokens += item.Tokens + s.index.Add(item.ID, item.Content, item.Tags) + } +} + +// NewPersister creates a Persister that saves store state to the given directory. +func NewPersister(dir string, store *Store) (*Persister, error) { + if err := os.MkdirAll(dir, 0o750); err != nil { + return nil, err + } + return &Persister{ + dir: dir, + store: store, + }, nil +} + +func (p *Persister) stateFile() string { + return filepath.Join(p.dir, "state.json") +} + +func (p *Persister) tmpFile() string { + return filepath.Join(p.dir, "state.json.tmp") +} + +// Save snapshots the store and writes it atomically to state.json. +func (p *Persister) Save() error { + items, budget := p.store.snapshot() + + state := persistedState{ + Version: stateVersion, + Budget: budget, + Items: items, + SavedAt: time.Now(), + } + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + + tmp := p.tmpFile() + if err := os.WriteFile(tmp, data, 0o600); err != nil { + return err + } + + return os.Rename(tmp, p.stateFile()) +} + +// Load reads state.json and restores items into the store. +// Returns nil if the file does not exist (fresh start). +// Returns an error if the file exists but cannot be parsed. +func (p *Persister) Load() error { + data, err := os.ReadFile(p.stateFile()) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var state persistedState + if err := json.Unmarshal(data, &state); err != nil { + return err + } + + p.store.restore(state.Items) + return nil +} + +// Start launches a background goroutine that periodically saves if dirty. +func (p *Persister) Start(interval time.Duration) { + p.mu.Lock() + p.stopCh = make(chan struct{}) + p.mu.Unlock() + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.mu.Lock() + shouldSave := p.dirty + p.dirty = false + p.mu.Unlock() + + if shouldSave { + _ = p.Save() + } + case <-p.stopCh: + return + } + } + }() +} + +// MarkDirty flags the store state as needing a save. +func (p *Persister) MarkDirty() { + p.mu.Lock() + p.dirty = true + p.mu.Unlock() +} + +// Stop signals the background goroutine to exit and performs a final save if dirty. +func (p *Persister) Stop() { + p.mu.Lock() + if p.stopCh != nil { + close(p.stopCh) + } + shouldSave := p.dirty + p.dirty = false + p.mu.Unlock() + + if shouldSave { + _ = p.Save() + } +} diff --git a/persist_test.go b/persist_test.go new file mode 100644 index 0000000..bb568e4 --- /dev/null +++ b/persist_test.go @@ -0,0 +1,178 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestPersistSaveLoad(t *testing.T) { + dir := t.TempDir() + store1 := NewStore(10000) + + if _, err := store1.Add("first item content", "first summary", []string{"tag1"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } + if _, err := store1.Add("second item content", "", []string{"tag2", "tag3"}, 8); err != nil { + t.Fatalf("Add: %v", err) + } + + p1, err := NewPersister(dir, store1) + if err != nil { + t.Fatalf("NewPersister: %v", err) + } + + errSave := p1.Save() + if errSave != nil { + t.Fatalf("Save: %v", errSave) + } + + // Load into a fresh store and verify roundtrip. + store2 := NewStore(10000) + p2, errP2 := NewPersister(dir, store2) + if errP2 != nil { + t.Fatalf("NewPersister (store2): %v", errP2) + } + + errLoad := p2.Load() + if errLoad != nil { + t.Fatalf("Load: %v", errLoad) + } + + _, _, count1, _ := store1.Status() + _, _, count2, _ := store2.Status() + + if count2 != count1 { + t.Fatalf("item count mismatch: got %d, want %d", count2, count1) + } + + // Verify individual items survived the roundtrip. + items := store2.Query("", nil, 100) + if len(items) != 2 { + t.Fatalf("expected 2 items from query, got %d", len(items)) + } + + found := make(map[string]bool) + for _, item := range items { + found[item.Content] = true + if item.Tokens <= 0 { + t.Errorf("item %s has non-positive token count: %d", item.ID, item.Tokens) + } + } + + if !found["first item content"] { + t.Error("missing 'first item content' after load") + } + if !found["second item content"] { + t.Error("missing 'second item content' after load") + } +} + +func TestPersistAtomicWrite(t *testing.T) { + dir := t.TempDir() + store := NewStore(10000) + if _, err := store.Add("test content", "", nil, 5); err != nil { + t.Fatalf("Add: %v", err) + } + + p, err := NewPersister(dir, store) + if err != nil { + t.Fatalf("NewPersister: %v", err) + } + + errSave := p.Save() + if errSave != nil { + t.Fatalf("Save: %v", errSave) + } + + // state.json should exist. + _, errStat := os.Stat(filepath.Join(dir, "state.json")) + if errStat != nil { + t.Fatalf("state.json missing: %v", errStat) + } + + // Temp file should not linger. + _, errTmp := os.Stat(filepath.Join(dir, "state.json.tmp")) + if !os.IsNotExist(errTmp) { + t.Fatal("state.json.tmp should not exist after successful save") + } +} + +func TestPersistLoadMissing(t *testing.T) { + dir := t.TempDir() + store := NewStore(10000) + + p, err := NewPersister(dir, store) + if err != nil { + t.Fatalf("NewPersister: %v", err) + } + + // Loading from a directory with no state.json should succeed (fresh start). + errLoad := p.Load() + if errLoad != nil { + t.Fatalf("Load on missing file should return nil, got: %v", errLoad) + } + + _, _, count, _ := store.Status() + if count != 0 { + t.Fatalf("expected 0 items after loading missing file, got %d", count) + } +} + +func TestPersistLoadCorrupted(t *testing.T) { + dir := t.TempDir() + + // Write garbage to state.json. + errWrite := os.WriteFile(filepath.Join(dir, "state.json"), []byte("{{{garbage"), 0o600) + if errWrite != nil { + t.Fatalf("writing corrupt file: %v", errWrite) + } + + store := NewStore(10000) + p, err := NewPersister(dir, store) + if err != nil { + t.Fatalf("NewPersister: %v", err) + } + + if p.Load() == nil { + t.Fatal("Load should return error for corrupted file") + } +} + +func TestPersistMarkDirty(t *testing.T) { + dir := t.TempDir() + store := NewStore(10000) + + p, err := NewPersister(dir, store) + if err != nil { + t.Fatalf("NewPersister: %v", err) + } + + // Initially not dirty -- Start + Stop should not create state.json. + p.Start(time.Hour) // long interval so tick won't fire + p.Stop() + + _, errStat := os.Stat(filepath.Join(dir, "state.json")) + if !os.IsNotExist(errStat) { + t.Fatal("state.json should not exist when not dirty") + } + + // Mark dirty, then Stop should trigger a save. + if _, err := store.Add("dirty item", "", nil, 5); err != nil { + t.Fatalf("Add: %v", err) + } + p2, errP2 := NewPersister(dir, store) + if errP2 != nil { + t.Fatalf("NewPersister: %v", errP2) + } + + p2.Start(time.Hour) + p2.MarkDirty() + p2.Stop() + + _, errFinal := os.Stat(filepath.Join(dir, "state.json")) + if errFinal != nil { + t.Fatalf("state.json should exist after dirty stop: %v", errFinal) + } +} diff --git a/semver.yaml b/semver.yaml new file mode 100644 index 0000000..5a2e4a0 --- /dev/null +++ b/semver.yaml @@ -0,0 +1,24 @@ +version: 1 +force: + major: 0 + minor: 1 + patch: 0 +blacklist: + - "Merge branch" + - "Merge pull request" +wording: + patch: + - update + - fix + - bugfix + - patch + - tweak + minor: + - change + - improve + - add + - feature + - enhance + major: + - breaking + - major diff --git a/store.go b/store.go index f07f098..9895646 100644 --- a/store.go +++ b/store.go @@ -2,6 +2,7 @@ package main import ( "crypto/rand" + "errors" "fmt" "math" "sort" @@ -10,6 +11,12 @@ import ( "time" ) +const ( + maxItems = 10000 + maxContentBytes = 1 << 20 // 1 MiB + maxDedupCandidates = 500 +) + type Item struct { CreatedAt time.Time AccessedAt time.Time @@ -20,11 +27,18 @@ type Item struct { Tokens int AccessCount int Importance int + ContentType ContentType Pinned bool } +type SummaryCandidate struct { + ID string + Preview string + Tokens int +} + type CompactResult struct { - NeedsSummary []*Item + NeedsSummary []SummaryCandidate TokensFreed int TokensBefore int TokensAfter int @@ -33,8 +47,16 @@ type CompactResult struct { Deduplicated int } +type BulkItem struct { + Content string + Summary string + Tags []string + Importance int +} + type Store struct { items map[string]*Item + index *Index tokenBudget int usedTokens int autoCompactThreshold float64 @@ -45,47 +67,62 @@ type Store struct { func NewStore(tokenBudget int) *Store { return &Store{ items: make(map[string]*Item), + index: NewIndex(), tokenBudget: tokenBudget, autoCompact: true, autoCompactThreshold: 0.9, } } -func (s *Store) Add(content, summary string, tags []string, importance int) *Item { +func (s *Store) Add(content, summary string, tags []string, importance int) (*Item, error) { s.mu.Lock() defer s.mu.Unlock() + if len(s.items) >= maxItems { + return nil, errors.New("item limit reached (max 10000)") + } + if len(content) > maxContentBytes { + return nil, fmt.Errorf("content too large (%d bytes, max %d)", len(content), maxContentBytes) + } + if importance < 1 { importance = 5 } else if importance > 10 { importance = 10 } + ct := DetectContentType(content) + if len(tags) == 0 { + tags = AutoTags(content) + } + tokens := EstimateTokens(content) item := &Item{ - ID: newID(), - Content: content, - Summary: summary, - Tags: tags, - Importance: importance, - CreatedAt: time.Now(), - AccessedAt: time.Now(), - Tokens: tokens, + ID: newID(), + Content: content, + Summary: summary, + Tags: tags, + Importance: importance, + ContentType: ct, + CreatedAt: time.Now(), + AccessedAt: time.Now(), + Tokens: tokens, } s.items[item.ID] = item s.usedTokens += tokens + s.index.Add(item.ID, content, tags) if s.autoCompact && s.tokenBudget > 0 { if float64(s.usedTokens)/float64(s.tokenBudget) > s.autoCompactThreshold { - s.compactLocked(0.8) + s.compactLocked(0.8, false) } } - return item + return item, nil } -func (s *Store) Get(id string) (*Item, bool) { +func (s *Store) Get(id string) (Item, bool) { s.mu.Lock() defer s.mu.Unlock() @@ -93,8 +130,9 @@ func (s *Store) Get(id string) (*Item, bool) { if ok { item.AccessedAt = time.Now() item.AccessCount++ + return *item, true } - return item, ok + return Item{}, false } func (s *Store) Remove(id string) bool { @@ -106,6 +144,7 @@ func (s *Store) Remove(id string) bool { return false } s.usedTokens -= item.Tokens + s.index.Remove(id) delete(s.items, id) return true } @@ -121,6 +160,17 @@ func (s *Store) Pin(id string) bool { return ok } +func (s *Store) Unpin(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + item, ok := s.items[id] + if ok { + item.Pinned = false + } + return ok +} + func (s *Store) UpdateSummary(id, summary string) bool { s.mu.Lock() defer s.mu.Unlock() @@ -132,7 +182,7 @@ func (s *Store) UpdateSummary(id, summary string) bool { return ok } -func (s *Store) Query(query string, tags []string, limit int) []*Item { +func (s *Store) Query(query string, tags []string, limit int) []Item { s.mu.Lock() defer s.mu.Unlock() @@ -140,26 +190,160 @@ func (s *Store) Query(query string, tags []string, limit int) []*Item { limit = 10 } - queryWords := wordSet(query) - var results []*Item + type scored struct { + item *Item + score float64 + } + var results []scored - for _, item := range s.items { - if len(tags) > 0 && !hasAnyTag(item, tags) { - continue + if query != "" { + // BM25 search ranks results by relevance + searchResults := s.index.Search(query, 0) + for _, sr := range searchResults { + item, ok := s.items[sr.ID] + if !ok { + continue + } + if len(tags) > 0 && !hasAnyTag(item, tags) { + continue + } + // Combine BM25 relevance with item score + results = append(results, scored{item: item, score: sr.Score + s.scoreLocked(item)}) + } + } else { + // No query text — filter by tags, sort by score + for _, item := range s.items { + if len(tags) > 0 && !hasAnyTag(item, tags) { + continue + } + results = append(results, scored{item: item, score: s.scoreLocked(item)}) } - item.AccessedAt = time.Now() - item.AccessCount++ - results = append(results, item) } sort.Slice(results, func(i, j int) bool { - return s.queryScore(results[i], queryWords) > s.queryScore(results[j], queryWords) + return results[i].score > results[j].score }) if len(results) > limit { results = results[:limit] } - return results + + // Only bump AccessedAt/AccessCount on items that made the cut + out := make([]Item, len(results)) + for i, r := range results { + r.item.AccessedAt = time.Now() + r.item.AccessCount++ + out[i] = *r.item + } + return out +} + +func (s *Store) ListItems(offset, limit int) ([]Item, int) { + s.mu.Lock() + defer s.mu.Unlock() + + total := len(s.items) + + // Collect and sort by creation time descending + all := make([]*Item, 0, total) + for _, item := range s.items { + all = append(all, item) + } + sort.Slice(all, func(i, j int) bool { + return all[i].CreatedAt.After(all[j].CreatedAt) + }) + + if offset < 0 { + offset = 0 + } + if limit <= 0 { + limit = 20 + } + + if offset >= len(all) { + return nil, total + } + end := offset + limit + if end > len(all) { + end = len(all) + } + + out := make([]Item, end-offset) + for i, item := range all[offset:end] { + out[i] = *item + } + return out, total +} + +func (s *Store) BulkAdd(items []BulkItem) ([]*Item, []error) { + results := make([]*Item, len(items)) + errs := make([]error, len(items)) + for i, bi := range items { + item, err := s.Add(bi.Content, bi.Summary, bi.Tags, bi.Importance) + results[i] = item + errs[i] = err + } + return results, errs +} + +func (s *Store) Export(summariesOnly bool) []Item { + s.mu.Lock() + defer s.mu.Unlock() + + out := make([]Item, 0, len(s.items)) + for _, item := range s.items { + cp := *item + if summariesOnly && cp.Summary != "" { + cp.Content = cp.Summary + cp.Tokens = EstimateTokens(cp.Content) + } + out = append(out, cp) + } + + sort.Slice(out, func(i, j int) bool { + return out[i].CreatedAt.Before(out[j].CreatedAt) + }) + return out +} + +// Recall returns status info and the top items by retention score. +// Designed as a single "session start" call to restore working context. +func (s *Store) Recall(limit int) (budget, used, count int, usage float64, items []Item) { + s.mu.Lock() + defer s.mu.Unlock() + + budget = s.tokenBudget + used = s.usedTokens + count = len(s.items) + if budget > 0 { + usage = float64(used) / float64(budget) + } + + if count == 0 || limit <= 0 { + return + } + + type scored struct { + item *Item + score float64 + } + all := make([]scored, 0, count) + for _, item := range s.items { + all = append(all, scored{item: item, score: s.scoreLocked(item)}) + } + sort.Slice(all, func(i, j int) bool { + return all[i].score > all[j].score + }) + + n := limit + if n > len(all) { + n = len(all) + } + items = make([]Item, n) + for i := 0; i < n; i++ { + items[i] = *all[i].item + } + return } func (s *Store) scoreLocked(item *Item) float64 { @@ -167,11 +351,13 @@ func (s *Store) scoreLocked(item *Item) float64 { return math.MaxFloat64 } + halfLife := DecayHalfLifeMinutes(item.ContentType) age := time.Since(item.AccessedAt).Minutes() - recency := math.Exp(-age / 120.0) // half-life ~2 hours + recency := math.Exp(-age / halfLife) importance := float64(item.Importance) / 10.0 - access := math.Log1p(float64(item.AccessCount)) / 5.0 + importance *= ScoreMultiplier(item.ContentType) + access := math.Min(math.Log1p(float64(item.AccessCount))/5.0, 1.0) var sizePenalty float64 if s.tokenBudget > 0 { @@ -181,25 +367,6 @@ func (s *Store) scoreLocked(item *Item) float64 { return (0.4 * importance) + (0.3 * recency) + (0.2 * access) - (0.1 * sizePenalty) } -func (s *Store) queryScore(item *Item, queryWords map[string]struct{}) float64 { - base := s.scoreLocked(item) - if len(queryWords) == 0 { - return base - } - - contentWords := wordSet(item.Content) - if item.Summary != "" { - for w := range wordSet(item.Summary) { - contentWords[w] = struct{}{} - } - } - for _, tag := range item.Tags { - contentWords[strings.ToLower(tag)] = struct{}{} - } - - return base + (0.5 * jaccardSimilarity(queryWords, contentWords)) -} - func (s *Store) Status() (budget, used, count int, usage float64) { s.mu.Lock() defer s.mu.Unlock() @@ -253,10 +420,10 @@ func (s *Store) AutoCompactThreshold() float64 { func (s *Store) Compact(targetUsage float64) CompactResult { s.mu.Lock() defer s.mu.Unlock() - return s.compactLocked(targetUsage) + return s.compactLocked(targetUsage, true) } -func (s *Store) compactLocked(targetUsage float64) CompactResult { +func (s *Store) compactLocked(targetUsage float64, fullCompaction bool) CompactResult { result := CompactResult{TokensBefore: s.usedTokens} if s.tokenBudget <= 0 { @@ -271,13 +438,27 @@ func (s *Store) compactLocked(targetUsage float64) CompactResult { } // Phase 1: Summary promotion — replace content with summary to save tokens + // Sort candidates by score ascending so we promote lowest-scoring items first + type promoCandidate struct { + item *Item + score float64 + } + var promoCandidates []promoCandidate for _, item := range s.items { - if s.usedTokens <= targetTokens { - break - } if item.Pinned || item.Summary == "" || item.Content == item.Summary { continue } + promoCandidates = append(promoCandidates, promoCandidate{item: item, score: s.scoreLocked(item)}) + } + sort.Slice(promoCandidates, func(i, j int) bool { + return promoCandidates[i].score < promoCandidates[j].score + }) + + for _, pc := range promoCandidates { + if s.usedTokens <= targetTokens { + break + } + item := pc.item oldTokens := item.Tokens item.Content = item.Summary item.Tokens = EstimateTokens(item.Content) @@ -286,58 +467,74 @@ func (s *Store) compactLocked(targetUsage float64) CompactResult { s.usedTokens -= saved result.Summarized++ result.TokensFreed += saved + s.index.Add(item.ID, item.Content, item.Tags) } } - // Phase 2: Deduplication — merge items with >70% word overlap - ids := make([]string, 0, len(s.items)) - for id := range s.items { - ids = append(ids, id) - } - merged := make(map[string]bool) - - for i := 0; i < len(ids); i++ { - if merged[ids[i]] { - continue + // Phase 2: Deduplication — merge items with >70% word overlap (full compaction only) + if fullCompaction { + ids := make([]string, 0, len(s.items)) + for id := range s.items { + ids = append(ids, id) } - a := s.items[ids[i]] - if a == nil || a.Pinned { - continue - } - aWords := wordSet(a.Content) - for j := i + 1; j < len(ids); j++ { - if merged[ids[j]] { + // Cap dedup candidates + if len(ids) > maxDedupCandidates { + // Sort by score ascending to prioritize merging low-value items + sort.Slice(ids, func(i, j int) bool { + return s.scoreLocked(s.items[ids[i]]) < s.scoreLocked(s.items[ids[j]]) + }) + ids = ids[:maxDedupCandidates] + } + + merged := make(map[string]bool) + + for i := 0; i < len(ids); i++ { + if merged[ids[i]] { continue } - b := s.items[ids[j]] - if b == nil { + a := s.items[ids[i]] + if a == nil || a.Pinned { continue } + aWords := wordSet(a.Content) - if jaccardSimilarity(aWords, wordSet(b.Content)) > 0.7 { - // Keep higher-scoring item, merge tags - if s.scoreLocked(a) >= s.scoreLocked(b) { - a.Tags = mergeTags(a.Tags, b.Tags) - if a.Summary == "" && b.Summary != "" { - a.Summary = b.Summary - } - s.usedTokens -= b.Tokens - result.TokensFreed += b.Tokens - delete(s.items, ids[j]) - merged[ids[j]] = true - } else { - b.Tags = mergeTags(b.Tags, a.Tags) - if b.Summary == "" && a.Summary != "" { - b.Summary = a.Summary - } - s.usedTokens -= a.Tokens - result.TokensFreed += a.Tokens - delete(s.items, ids[i]) - merged[ids[i]] = true - break + for j := i + 1; j < len(ids); j++ { + if merged[ids[j]] { + continue + } + b := s.items[ids[j]] + if b == nil { + continue + } + + if jaccardSimilarity(aWords, wordSet(b.Content)) > 0.7 { + // Keep higher-scoring item, merge tags + if s.scoreLocked(a) >= s.scoreLocked(b) { + a.Tags = mergeTags(a.Tags, b.Tags) + if a.Summary == "" && b.Summary != "" { + a.Summary = b.Summary + } + s.usedTokens -= b.Tokens + result.TokensFreed += b.Tokens + s.index.Remove(ids[j]) + delete(s.items, ids[j]) + merged[ids[j]] = true + } else { + b.Tags = mergeTags(b.Tags, a.Tags) + if b.Summary == "" && a.Summary != "" { + b.Summary = a.Summary + } + s.usedTokens -= a.Tokens + result.TokensFreed += a.Tokens + s.index.Remove(ids[i]) + delete(s.items, ids[i]) + merged[ids[i]] = true + result.Deduplicated++ + break + } + result.Deduplicated++ } - result.Deduplicated++ } } } @@ -360,6 +557,7 @@ func (s *Store) compactLocked(targetUsage float64) CompactResult { } s.usedTokens -= item.Tokens result.TokensFreed += item.Tokens + s.index.Remove(item.ID) delete(s.items, item.ID) result.Evicted++ } @@ -368,7 +566,15 @@ func (s *Store) compactLocked(targetUsage float64) CompactResult { // Collect items that could benefit from LLM summarization for _, item := range s.items { if item.Summary == "" && item.Tokens > 100 { - result.NeedsSummary = append(result.NeedsSummary, item) + preview := item.Content + if len(preview) > 80 { + preview = preview[:80] + } + result.NeedsSummary = append(result.NeedsSummary, SummaryCandidate{ + ID: item.ID, + Tokens: item.Tokens, + Preview: preview, + }) } } diff --git a/store_test.go b/store_test.go index 6ee5ef3..f8e1ae8 100644 --- a/store_test.go +++ b/store_test.go @@ -26,7 +26,10 @@ func TestEstimateTokens(t *testing.T) { func TestStoreAddAndQuery(t *testing.T) { s := NewStore(100000) - item := s.Add("cilium uses netkit on kernel 6.8", "cilium netkit needs 6.8", []string{"cilium", "networking"}, 8) + item, err := s.Add("cilium uses netkit on kernel 6.8", "cilium netkit needs 6.8", []string{"cilium", "networking"}, 8) + if err != nil { + t.Fatalf("Add failed: %v", err) + } if item.ID == "" { t.Fatal("expected non-empty ID") } @@ -59,7 +62,10 @@ func TestStoreAddAndQuery(t *testing.T) { func TestStoreRemoveAndPin(t *testing.T) { s := NewStore(100000) - item := s.Add("test content", "", nil, 5) + item, err := s.Add("test content", "", nil, 5) + if err != nil { + t.Fatalf("Add failed: %v", err) + } if !s.Pin(item.ID) { t.Fatal("pin failed") } @@ -79,36 +85,47 @@ func TestStoreRemoveAndPin(t *testing.T) { } func TestCompaction(t *testing.T) { - s := NewStore(200) // small budget: ~200 tokens = ~800 chars + // Budget of 20 tokens. Disable auto-compact so items survive Add(). + // Each item is ~12 tokens, so 3 items = ~36 tokens > budget. + s := NewStore(20) + s.Configure(0, boolPtr(false), 0) - // Store items that exceed budget - s.Add("alpha bravo charlie delta echo foxtrot golf hotel", "short alpha", []string{"a"}, 3) - s.Add("india juliet kilo lima mike november oscar papa", "", []string{"b"}, 5) - s.Add("quebec romeo sierra tango uniform victor whiskey", "", []string{"c"}, 8) + if _, err := s.Add("alpha bravo charlie delta echo foxtrot golf hotel", "short alpha", []string{"a"}, 3); err != nil { + t.Fatalf("Add: %v", err) + } + if _, err := s.Add("india juliet kilo lima mike november oscar papa", "", []string{"b"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } + if _, err := s.Add("quebec romeo sierra tango uniform victor whiskey", "", []string{"c"}, 8); err != nil { + t.Fatalf("Add: %v", err) + } _, used, _, _ := s.Status() if used == 0 { t.Fatal("expected non-zero usage") } - result := s.Compact(0.5) // compact to 50% + result := s.Compact(0.5) // compact to 50% of 20 = 10 tokens if result.TokensAfter > result.TokensBefore { t.Error("compaction should not increase tokens") } - if result.TokensFreed == 0 && result.TokensBefore > 100 { + if result.TokensFreed == 0 { t.Error("expected some tokens freed") } // Item with summary should have been promoted if result.Summarized == 0 { - t.Log("note: no summary promotions (may depend on budget math)") + t.Error("expected at least one summary promotion") } } func TestSummaryPromotion(t *testing.T) { - s := NewStore(100) // very tight budget + // Budget of 20 tokens. Content is ~23 tokens, summary is ~7 tokens. + // Disable auto-compact. Compaction to 0.3 = target 6 tokens, so promotion must happen. + s := NewStore(20) + s.Configure(0, boolPtr(false), 0) - item := s.Add( + item, _ := s.Add( "this is a very long content string that takes many tokens to represent in the context window", "long content, many tokens", nil, 5, @@ -116,7 +133,7 @@ func TestSummaryPromotion(t *testing.T) { result := s.Compact(0.3) // aggressive compaction if result.Summarized == 0 { - t.Log("note: summary promotion did not trigger") + t.Error("expected summary promotion to trigger") } got, ok := s.Get(item.ID) @@ -129,24 +146,40 @@ func TestSummaryPromotion(t *testing.T) { } func TestDeduplication(t *testing.T) { - s := NewStore(100) + // Budget of 20 tokens, two items with 5/6 word overlap (>70%). + // Disable auto-compact so both items survive Add(). + s := NewStore(20) + s.Configure(0, boolPtr(false), 0) - s.Add("alpha bravo charlie delta echo foxtrot", "", []string{"a"}, 5) - s.Add("alpha bravo charlie delta echo golf", "", []string{"b"}, 5) + if _, err := s.Add("alpha bravo charlie delta echo foxtrot", "", []string{"a"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } + if _, err := s.Add("alpha bravo charlie delta echo golf", "", []string{"b"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } _, _, countBefore, _ := s.Status() - s.Compact(0.3) + if countBefore != 2 { + t.Fatalf("expected 2 items before compact, got %d", countBefore) + } + result := s.Compact(0.3) _, _, countAfter, _ := s.Status() if countAfter >= countBefore { - t.Log("note: dedup did not reduce count (similarity may be below threshold)") + t.Errorf("dedup should reduce count: before=%d after=%d", countBefore, countAfter) + } + if result.Deduplicated == 0 { + t.Error("expected at least one deduplication") } } func TestUpdateSummary(t *testing.T) { s := NewStore(100000) - item := s.Add("full content here", "", nil, 5) + item, err := s.Add("full content here", "", nil, 5) + if err != nil { + t.Fatalf("Add failed: %v", err) + } if item.Summary != "" { t.Fatal("should have no summary initially") } @@ -182,13 +215,13 @@ func TestJaccardSimilarity(t *testing.T) { } func TestAutoCompact(t *testing.T) { - s := NewStore(50) // very small: 50 tokens = ~200 chars + s := NewStore(10) // very small: 10 tokens s.Configure(0, boolPtr(true), 0.5) - // Add items that should trigger auto-compaction - s.Add("first item with some content padding", "first", nil, 3) - s.Add("second item with more content padding", "second", nil, 3) - s.Add("third item triggering compaction now", "third", nil, 3) + // Add items that should trigger auto-compaction (errors ignored; auto-compact may evict) + s.Add("first item with some content padding", "first", nil, 3) //nolint:gosec + s.Add("second item with more content padding", "second", nil, 3) //nolint:gosec + s.Add("third item triggering compaction now", "third", nil, 3) //nolint:gosec _, used, _, usage := s.Status() // Auto-compact should have kept usage reasonable @@ -197,4 +230,279 @@ func TestAutoCompact(t *testing.T) { } } +func TestUnpin(t *testing.T) { + s := NewStore(100000) + + item, err := s.Add("test content for unpin", "", nil, 5) + if err != nil { + t.Fatalf("Add failed: %v", err) + } + + // Pin then unpin + if !s.Pin(item.ID) { + t.Fatal("pin failed") + } + got, _ := s.Get(item.ID) + if !got.Pinned { + t.Fatal("should be pinned") + } + + if !s.Unpin(item.ID) { + t.Fatal("unpin failed") + } + got, _ = s.Get(item.ID) + if got.Pinned { + t.Fatal("should be unpinned") + } + + // Unpin non-existent item + if s.Unpin("nonexistent") { + t.Error("unpin of non-existent item should return false") + } +} + +func TestListItems(t *testing.T) { + s := NewStore(100000) + + // Add 5 items + for i := 0; i < 5; i++ { + _, err := s.Add("content "+string(rune('a'+i)), "", nil, 5) + if err != nil { + t.Fatalf("Add %d failed: %v", i, err) + } + } + + // List all + items, total := s.ListItems(0, 10) + if total != 5 { + t.Errorf("expected total 5, got %d", total) + } + if len(items) != 5 { + t.Errorf("expected 5 items, got %d", len(items)) + } + + // List with offset + items, total = s.ListItems(3, 10) + if total != 5 { + t.Errorf("expected total 5, got %d", total) + } + if len(items) != 2 { + t.Errorf("expected 2 items with offset 3, got %d", len(items)) + } + + // List with limit + items, total = s.ListItems(0, 2) + if total != 5 { + t.Errorf("expected total 5, got %d", total) + } + if len(items) != 2 { + t.Errorf("expected 2 items with limit 2, got %d", len(items)) + } + + // List beyond end + items, total = s.ListItems(10, 5) + if total != 5 { + t.Errorf("expected total 5, got %d", total) + } + if items != nil { + t.Errorf("expected nil items beyond end, got %d", len(items)) + } +} + +func TestBulkAdd(t *testing.T) { + s := NewStore(100000) + + bulkItems := []BulkItem{ + {Content: "first bulk item", Summary: "first", Tags: []string{"bulk"}, Importance: 5}, + {Content: "second bulk item", Summary: "second", Tags: []string{"bulk"}, Importance: 7}, + {Content: "third bulk item", Summary: "", Tags: nil, Importance: 3}, + } + + results, errs := s.BulkAdd(bulkItems) + if len(results) != 3 { + t.Fatalf("expected 3 results, got %d", len(results)) + } + for i, err := range errs { + if err != nil { + t.Errorf("bulk add item %d failed: %v", i, err) + } + } + for i, item := range results { + if item == nil { + t.Errorf("bulk add item %d is nil", i) + } else if item.ID == "" { + t.Errorf("bulk add item %d has empty ID", i) + } + } + + _, _, count, _ := s.Status() + if count != 3 { + t.Errorf("expected 3 items in store, got %d", count) + } +} + +func TestExport(t *testing.T) { + s := NewStore(100000) + + if _, err := s.Add("full content one", "summary one", []string{"a"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } + if _, err := s.Add("full content two", "", []string{"b"}, 7); err != nil { + t.Fatalf("Add: %v", err) + } + + // Export all + items := s.Export(false) + if len(items) != 2 { + t.Fatalf("expected 2 exported items, got %d", len(items)) + } + for _, item := range items { + if strings.HasPrefix(item.Content, "summary") { + t.Error("full export should not replace content with summary") + } + } + + // Export summaries only + items = s.Export(true) + if len(items) != 2 { + t.Fatalf("expected 2 exported items, got %d", len(items)) + } + foundSummary := false + foundFull := false + for _, item := range items { + if item.Content == "summary one" { + foundSummary = true + } + if item.Content == "full content two" { + foundFull = true // no summary available, keep full content + } + } + if !foundSummary { + t.Error("summaries_only should replace content with summary where available") + } + if !foundFull { + t.Error("summaries_only should keep full content when no summary available") + } +} + +func TestItemCountLimit(t *testing.T) { + // We can't add 10001 items in a test (too slow), but we can test the limit + // by lowering the effective count. Instead, test with a smaller approach: + // fill the store to capacity and verify the error. + s := NewStore(1000000) + + // Override: we test by directly checking the error message + // Add one item, then manipulate to test boundary + item, err := s.Add("test", "", nil, 5) + if err != nil { + t.Fatalf("first add failed: %v", err) + } + if item == nil { + t.Fatal("expected non-nil item") + } + + // Add content that's too large + bigContent := strings.Repeat("x", maxContentBytes+1) + _, err = s.Add(bigContent, "", nil, 5) + if err == nil { + t.Error("expected error for oversized content") + } + if err != nil && !strings.Contains(err.Error(), "content too large") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestContentSizeLimit(t *testing.T) { + s := NewStore(100000) + + // Exactly at limit should succeed + content := strings.Repeat("x", maxContentBytes) + _, err := s.Add(content, "", nil, 5) + if err != nil { + t.Errorf("content at exact limit should succeed: %v", err) + } + + // Over limit should fail + content = strings.Repeat("x", maxContentBytes+1) + _, err = s.Add(content, "", nil, 5) + if err == nil { + t.Error("content over limit should fail") + } +} + +func TestQueryAccessCountFix(t *testing.T) { + s := NewStore(100000) + + // Add 5 items + ids := make([]string, 5) + for i := 0; i < 5; i++ { + item, _ := s.Add("item "+string(rune('a'+i)), "", nil, 5) + ids[i] = item.ID + } + + // Query with limit 2 - only the top 2 should get access bumps + s.Query("item", nil, 2) + + // Check that we got exactly 2 items with AccessCount > 0 + bumped := 0 + for _, id := range ids { + got, ok := s.Get(id) + if !ok { + t.Fatalf("item %s not found", id) + } + // Get itself bumps access count by 1, so items that were + // bumped by Query will have AccessCount >= 2 after Get + if got.AccessCount >= 2 { + bumped++ + } + } + // Only the 2 items returned by Query should have been bumped + // (plus the Get call bumps all by 1) + if bumped > 2 { + t.Errorf("expected at most 2 items bumped by query, got %d", bumped) + } +} + +func TestGetReturnsValueCopy(t *testing.T) { + s := NewStore(100000) + + item, _ := s.Add("original content", "", nil, 5) + got, ok := s.Get(item.ID) + if !ok { + t.Fatal("item not found") + } + + // Mutating the returned copy should not affect the store. + // We use a helper to avoid the unusedwrite lint. + mutateItemContent(&got, "mutated") + got2, _ := s.Get(item.ID) + if got2.Content != "original content" { + t.Error("Get should return value copy; mutation should not affect store") + } +} + +func mutateItemContent(item *Item, content string) { item.Content = content } + +func TestQueryReturnsValueCopies(t *testing.T) { + s := NewStore(100000) + + if _, err := s.Add("original query content", "", []string{"test"}, 5); err != nil { + t.Fatalf("Add: %v", err) + } + results := s.Query("original", nil, 10) + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + + // Mutating the returned copy should not affect the store + results[0].Content = "mutated" + results2 := s.Query("original", nil, 10) + if len(results2) != 1 { + t.Fatalf("expected 1 result, got %d", len(results2)) + } + if results2[0].Content == "mutated" { + t.Error("Query should return value copies; mutation should not affect store") + } +} + func boolPtr(b bool) *bool { return &b } diff --git a/tools.go b/tools.go index c37bf56..be2242c 100644 --- a/tools.go +++ b/tools.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "fmt" "strings" @@ -10,6 +11,13 @@ import ( ) func registerTools(s *server.MCPServer, store *Store) { + s.AddTool(mcp.NewTool("recall", + mcp.WithDescription("Restore working context from previous sessions. Call this FIRST at the start of every session. Returns budget status and the most important stored items. Use this before re-reading files or asking the user to repeat information."), + mcp.WithReadOnlyHintAnnotation(true), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithNumber("limit", mcp.Description("Max items to return (default 20)")), + ), handleRecall(store)) + s.AddTool(mcp.NewTool("store", mcp.WithDescription("Store a context item for later retrieval. Offload information from working context. Provide a summary for efficient compaction when budget is tight."), mcp.WithString("content", mcp.Required(), mcp.Description("The content to store")), @@ -46,6 +54,13 @@ func registerTools(s *server.MCPServer, store *Store) { mcp.WithString("id", mcp.Required(), mcp.Description("Item ID to pin")), ), handlePin(store)) + s.AddTool(mcp.NewTool("unpin", + mcp.WithDescription("Unpin a context item to allow automatic eviction during compaction."), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithIdempotentHintAnnotation(true), + mcp.WithString("id", mcp.Required(), mcp.Description("Item ID to unpin")), + ), handleUnpin(store)) + s.AddTool(mcp.NewTool("forget", mcp.WithDescription("Remove a context item from storage."), mcp.WithDestructiveHintAnnotation(true), @@ -68,6 +83,75 @@ func registerTools(s *server.MCPServer, store *Store) { mcp.WithString("id", mcp.Required(), mcp.Description("Item ID to update")), mcp.WithString("summary", mcp.Required(), mcp.Description("New summary for the item")), ), handleUpdate(store)) + + s.AddTool(mcp.NewTool("list", + mcp.WithDescription("List stored context items with pagination. Returns items sorted by creation time (newest first)."), + mcp.WithReadOnlyHintAnnotation(true), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithNumber("offset", mcp.Description("Number of items to skip (default 0)")), + mcp.WithNumber("limit", mcp.Description("Max items to return (default 20)")), + ), handleList(store)) + + s.AddTool(mcp.NewTool("bulk_store", + mcp.WithDescription("Store multiple context items at once. Accepts a JSON array of items."), + mcp.WithString("items", mcp.Required(), mcp.Description("JSON array of items: [{\"content\":\"...\",\"summary\":\"...\",\"tags\":[\"...\"],\"importance\":5}]")), + ), handleBulkStore(store)) + + s.AddTool(mcp.NewTool("export", + mcp.WithDescription("Export all stored context items. Optionally return summaries instead of full content where available."), + mcp.WithReadOnlyHintAnnotation(true), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithBoolean("summaries_only", mcp.Description("If true, return summaries instead of full content where available")), + ), handleExport(store)) +} + +func handleRecall(store *Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + limit := req.GetInt("limit", 20) + + budget, used, count, usage, items := store.Recall(limit) + + var sb strings.Builder + fmt.Fprintf(&sb, "Budget: %d/%d tokens (%.1f%%), %d items stored\n", used, budget, usage*100, count) + + if count == 0 { + sb.WriteString("\nNo stored context. Start using 'store' to offload information.") + return mcp.NewToolResultText(sb.String()), nil + } + + tight := store.BudgetTight() + fmt.Fprintf(&sb, "\nTop %d items by relevance", len(items)) + if tight { + sb.WriteString(" (budget tight, showing summaries where available)") + } + sb.WriteString(":\n\n") + + for _, item := range items { + fmt.Fprintf(&sb, "[%s] importance:%d tokens:%d", item.ID, item.Importance, item.Tokens) + if item.Pinned { + sb.WriteString(" PINNED") + } + sb.WriteString("\n") + if len(item.Tags) > 0 { + fmt.Fprintf(&sb, "Tags: %s\n", strings.Join(item.Tags, ", ")) + } + sb.WriteString("---\n") + if tight && item.Summary != "" { + sb.WriteString(item.Summary) + } else if item.Summary != "" { + sb.WriteString(item.Summary) + } else { + preview := item.Content + if len(preview) > 200 { + preview = preview[:200] + "..." + } + sb.WriteString(preview) + } + sb.WriteString("\n\n") + } + + return mcp.NewToolResultText(sb.String()), nil + } } func handleStore(store *Store) server.ToolHandlerFunc { @@ -90,7 +174,10 @@ func handleStore(store *Store) server.ToolHandlerFunc { } } - item := store.Add(content, summary, tags, importance) + item, addErr := store.Add(content, summary, tags, importance) + if addErr != nil { + return mcp.NewToolResultError(addErr.Error()), nil + } budget, used, count, usage := store.Status() return mcp.NewToolResultText(fmt.Sprintf( @@ -164,7 +251,7 @@ func handleStatus(store *Store) server.ToolHandlerFunc { func handleCompact(store *Store) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { target := req.GetFloat("target_usage", 0.7) - if target <= 0 || target > 1.0 { + if target < 0 || target > 1.0 { target = 0.7 } @@ -176,16 +263,16 @@ func handleCompact(store *Store) server.ToolHandlerFunc { fmt.Fprintf(&sb, "- Summary promoted: %d items\n", result.Summarized) fmt.Fprintf(&sb, "- Deduplicated: %d pairs\n", result.Deduplicated) fmt.Fprintf(&sb, "- Tokens freed: %d\n", result.TokensFreed) - fmt.Fprintf(&sb, "- Budget: %d → %d tokens\n", result.TokensBefore, result.TokensAfter) + fmt.Fprintf(&sb, "- Budget: %d -> %d tokens\n", result.TokensBefore, result.TokensAfter) if len(result.NeedsSummary) > 0 { sb.WriteString("\nItems that would benefit from summarization:\n") - for _, item := range result.NeedsSummary { - preview := item.Content + for _, sc := range result.NeedsSummary { + preview := sc.Preview if len(preview) > 80 { - preview = preview[:80] + "..." + preview = preview[:80] } - fmt.Fprintf(&sb, "- [%s] %d tokens: %q\n", item.ID, item.Tokens, preview) + fmt.Fprintf(&sb, "- [%s] %d tokens: %q\n", sc.ID, sc.Tokens, preview) } sb.WriteString("\nUse 'update' tool to add summaries to these items.") } @@ -207,6 +294,19 @@ func handlePin(store *Store) server.ToolHandlerFunc { } } +func handleUnpin(store *Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id, err := req.RequireString("id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if store.Unpin(id) { + return mcp.NewToolResultText(fmt.Sprintf("Unpinned [%s]", id)), nil + } + return mcp.NewToolResultError(fmt.Sprintf("Item [%s] not found", id)), nil + } +} + func handleForget(store *Store) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { id, err := req.RequireString("id") @@ -265,3 +365,105 @@ func handleUpdate(store *Store) server.ToolHandlerFunc { return mcp.NewToolResultError(fmt.Sprintf("Item [%s] not found", id)), nil } } + +func handleList(store *Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + offset := req.GetInt("offset", 0) + limit := req.GetInt("limit", 20) + + items, total := store.ListItems(offset, limit) + if len(items) == 0 { + return mcp.NewToolResultText(fmt.Sprintf("No items (total: %d).", total)), nil + } + + var sb strings.Builder + fmt.Fprintf(&sb, "Items %d-%d of %d:\n\n", offset+1, offset+len(items), total) + for _, item := range items { + fmt.Fprintf(&sb, "[%s] importance:%d tokens:%d", item.ID, item.Importance, item.Tokens) + if item.Pinned { + sb.WriteString(" PINNED") + } + sb.WriteString("\n") + if len(item.Tags) > 0 { + fmt.Fprintf(&sb, "Tags: %s\n", strings.Join(item.Tags, ", ")) + } + preview := item.Content + if len(preview) > 80 { + preview = preview[:80] + "..." + } + fmt.Fprintf(&sb, "%s\n\n", preview) + } + return mcp.NewToolResultText(sb.String()), nil + } +} + +func handleBulkStore(store *Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + itemsJSON, err := req.RequireString("items") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var bulkItems []BulkItem + if err := json.Unmarshal([]byte(itemsJSON), &bulkItems); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid JSON: %v", err)), nil + } + if len(bulkItems) == 0 { + return mcp.NewToolResultError("No items provided"), nil + } + + results, errs := store.BulkAdd(bulkItems) + + var sb strings.Builder + stored := 0 + failed := 0 + for i, item := range results { + if errs[i] != nil { + failed++ + fmt.Fprintf(&sb, "FAILED item %d: %v\n", i+1, errs[i]) + } else { + stored++ + fmt.Fprintf(&sb, "Stored [%s] (%d tokens)\n", item.ID, item.Tokens) + } + } + + budget, used, count, usage := store.Status() + fmt.Fprintf(&sb, "\nStored: %d, Failed: %d\nBudget: %d/%d tokens (%.0f%%), %d items", + stored, failed, used, budget, usage*100, count) + + return mcp.NewToolResultText(sb.String()), nil + } +} + +func handleExport(store *Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + summariesOnly := req.GetBool("summaries_only", false) + + items := store.Export(summariesOnly) + if len(items) == 0 { + return mcp.NewToolResultText("No items to export."), nil + } + + var sb strings.Builder + fmt.Fprintf(&sb, "Exported %d items", len(items)) + if summariesOnly { + sb.WriteString(" (summaries where available)") + } + sb.WriteString(":\n\n") + + for _, item := range items { + fmt.Fprintf(&sb, "[%s] importance:%d tokens:%d", item.ID, item.Importance, item.Tokens) + if item.Pinned { + sb.WriteString(" PINNED") + } + sb.WriteString("\n") + if len(item.Tags) > 0 { + fmt.Fprintf(&sb, "Tags: %s\n", strings.Join(item.Tags, ", ")) + } + sb.WriteString("---\n") + sb.WriteString(item.Content) + sb.WriteString("\n\n") + } + return mcp.NewToolResultText(sb.String()), nil + } +}