general improvements (#17)

* refactor(hooks): simplify hook execution with shared context

- [x] Extract BaseInput struct to eliminate duplicate fields across hooks
- [x] Create RunHook handler pattern for session-start and user-prompt
- [x] Create RunStatuslineHook for fast statusline rendering without worker startup
- [x] Add HookContext struct to pass port, project, CWD, SessionID to handlers
- [x] Add db/interface.go with ObservationReader/Writer interfaces
- [x] Add comprehensive conflict management tests in sqlite/conflict_test.go
- [x] Add vector client tests for Count, ModelVersion, NeedsRebuild, GetStaleVectors
- [x] Add FilterByThreshold helper tests for query result filtering
- [x] Make handlers_test more robust for network-dependent update checks
- [x] Update package versions in UI

* Move to GORM + general cleanup

* feat(mcp): add observation relations discovery and scoring integration

- [x] Add find_related_observations MCP tool for discovering related observations by confidence
- [x] Integrate scoring calculator and recalculator into MCP server initialization
- [x] Add pattern, relation, and session stores to MCP server dependencies
- [x] Register MCP server in Claude Code settings during plugin installation
- [x] Update install scripts (bash, PowerShell) to configure MCP server settings
- [x] Switch plugin manifest files to template-based versioning (plugin.json.tpl, marketplace.json.tpl)
- [x] Update all MCP server tests to pass new dependency parameters
This commit is contained in:
2026-01-07 00:26:20 +00:00
committed by GitHub
parent 92a99c7615
commit 7a061c85eb
85 changed files with 8445 additions and 8202 deletions
+11 -53
View File
@@ -2,9 +2,7 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/url"
"os"
"strings"
@@ -14,11 +12,8 @@ import (
// Input is the hook input from Claude Code.
type Input struct {
SessionID string `json:"session_id"`
CWD string `json:"cwd"`
PermissionMode string `json:"permission_mode"`
HookEventName string `json:"hook_event_name"`
Source string `json:"source"` // "startup", "resume", "clear", "compact"
hooks.BaseInput
Source string `json:"source"` // "startup", "resume", "clear", "compact"
}
// Observation represents an observation from the API.
@@ -32,53 +27,26 @@ type Observation struct {
}
func main() {
// Skip if this is an internal call (from SDK processor)
if os.Getenv("CLAUDE_MNEMONIC_INTERNAL") == "1" {
hooks.WriteResponse("SessionStart", true)
return
}
// Read input from stdin
inputData, err := io.ReadAll(os.Stdin)
if err != nil {
hooks.WriteError("SessionStart", err)
os.Exit(1)
}
var input Input
if err := json.Unmarshal(inputData, &input); err != nil {
hooks.WriteError("SessionStart", err)
os.Exit(1)
}
// Ensure worker is running
port, err := hooks.EnsureWorkerRunning()
if err != nil {
hooks.WriteError("SessionStart", err)
os.Exit(1)
}
// Generate unique project ID from CWD (dirname_hash format)
project := hooks.ProjectIDWithName(input.CWD)
hooks.RunHook("SessionStart", handleSessionStart)
}
func handleSessionStart(ctx *hooks.HookContext, input *Input) (string, error) {
// Fetch observations for context injection
endpoint := fmt.Sprintf("/api/context/inject?project=%s&cwd=%s",
url.QueryEscape(project),
url.QueryEscape(input.CWD))
url.QueryEscape(ctx.Project),
url.QueryEscape(ctx.CWD))
result, err := hooks.GET(port, endpoint)
result, err := hooks.GET(ctx.Port, endpoint)
if err != nil {
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: context fetch failed: %v\n", err)
hooks.WriteResponse("SessionStart", true)
return
return "", nil
}
// Parse observations from response
obsData, ok := result["observations"].([]interface{})
if !ok || len(obsData) == 0 {
// No observations - just continue normally
hooks.WriteResponse("SessionStart", true)
return
return "", nil
}
// Get full_count from response (how many observations get full detail)
@@ -136,17 +104,7 @@ func main() {
}
contextBuilder += "</claude-mnemonic-context>\n"
// Output context as JSON with additionalContext field
response := map[string]interface{}{
"continue": true,
"hookSpecificOutput": map[string]interface{}{
"hookEventName": "SessionStart",
"additionalContext": contextBuilder,
},
}
_ = json.NewEncoder(os.Stdout).Encode(response)
os.Exit(0)
return contextBuilder, nil
}
func getString(m map[string]interface{}, key string) string {
+46 -79
View File
@@ -5,7 +5,6 @@ package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
@@ -72,18 +71,13 @@ const (
)
func main() {
// Read input from stdin
inputData, err := io.ReadAll(os.Stdin)
if err != nil {
// On error, output minimal status
fmt.Println(formatOffline())
return
}
hooks.RunStatuslineHook(handleStatusline)
}
var input StatusInput
if err := json.Unmarshal(inputData, &input); err != nil {
fmt.Println(formatOffline())
return
func handleStatusline(input *StatusInput, port int) string {
// Handle error cases (nil input)
if input == nil {
return formatOffline()
}
// Determine project directory
@@ -102,16 +96,14 @@ func main() {
}
// Get worker stats
stats := getWorkerStats(project)
stats := getWorkerStats(port, project)
// Format and output statusline
fmt.Println(formatStatusLine(stats, input))
// Format and return statusline
return formatStatusLine(stats, *input)
}
// getWorkerStats fetches stats from the worker service.
func getWorkerStats(project string) *WorkerStats {
port := hooks.GetWorkerPort()
func getWorkerStats(port int, project string) *WorkerStats {
// Build URL with optional project parameter
endpoint := fmt.Sprintf("http://127.0.0.1:%d/api/stats", port)
if project != "" {
@@ -187,54 +179,45 @@ func formatDefault(stats *WorkerStats, useColors bool) string {
}
// Build status parts with clear labels
parts := []string{}
parts := []string{
prefix,
indicator,
}
// Total memories served to Claude this session
parts = append(parts, fmt.Sprintf("served:%d", stats.Retrieval.ObservationsServed))
// Context injections (memories auto-loaded at session start)
// Add retrieval stats if available
if stats.Retrieval.ObservationsServed > 0 {
parts = append(parts, fmt.Sprintf("served:%d", stats.Retrieval.ObservationsServed))
}
if stats.Retrieval.ContextInjections > 0 {
parts = append(parts, fmt.Sprintf("injected:%d", stats.Retrieval.ContextInjections))
}
// Semantic searches performed
if stats.Retrieval.SearchRequests > 0 {
parts = append(parts, fmt.Sprintf("searches:%d", stats.Retrieval.SearchRequests))
}
// Project-specific memory count
// Add project-specific observation count if available
if stats.ProjectObservations > 0 {
if useColors {
parts = append(parts, fmt.Sprintf("%sproject:%d memories%s", colorYellow, stats.ProjectObservations, reset))
} else {
parts = append(parts, fmt.Sprintf("project:%d memories", stats.ProjectObservations))
}
parts = append(parts, fmt.Sprintf("project:%d memories", stats.ProjectObservations))
}
// Processing indicator
if stats.IsProcessing || stats.QueueDepth > 0 {
if useColors {
parts = append(parts, colorYellow+"processing..."+colorReset)
} else {
parts = append(parts, "processing...")
}
}
result := prefix + " " + indicator
for i, part := range parts {
if i == 0 {
result += " " + part
} else {
result += " | " + part
// Join with separators
result := parts[0] + " " + parts[1]
if len(parts) > 2 {
for i := 2; i < len(parts); i++ {
if useColors {
result += colorGray + " | " + reset + parts[i]
} else {
result += " | " + parts[i]
}
}
}
return result
}
// formatCompact returns a compact status line.
// formatCompact returns a compact status line format.
func formatCompact(stats *WorkerStats, useColors bool) string {
// [m] ● 42/5/3 (28)
// [m] ● 42/5/3
var prefix, indicator string
if useColors {
prefix = colorCyan + "[m]" + colorReset
@@ -244,31 +227,16 @@ func formatCompact(stats *WorkerStats, useColors bool) string {
indicator = "●"
}
result := fmt.Sprintf("%s %s %d/%d/%d",
return fmt.Sprintf("%s %s %d/%d/%d",
prefix, indicator,
stats.Retrieval.ObservationsServed,
stats.Retrieval.ContextInjections,
stats.Retrieval.SearchRequests,
)
if stats.ProjectObservations > 0 {
result += fmt.Sprintf(" (%d)", stats.ProjectObservations)
}
if stats.IsProcessing || stats.QueueDepth > 0 {
if useColors {
result += " " + colorYellow + "⚙" + colorReset
} else {
result += " ⚙"
}
}
return result
stats.Retrieval.SearchRequests)
}
// formatMinimal returns a minimal status line.
// formatMinimal returns a minimal status line format.
func formatMinimal(stats *WorkerStats, useColors bool) string {
// ● 42 obs
// ● 28 memories
var indicator string
if useColors {
indicator = colorGreen + "●" + colorReset
@@ -276,32 +244,31 @@ func formatMinimal(stats *WorkerStats, useColors bool) string {
indicator = "●"
}
result := fmt.Sprintf("%s %d", indicator, stats.Retrieval.ObservationsServed)
if stats.ProjectObservations > 0 {
result += fmt.Sprintf("/%d", stats.ProjectObservations)
return fmt.Sprintf("%s %d memories", indicator, stats.ProjectObservations)
}
return result
return fmt.Sprintf("%s mnemonic ready", indicator)
}
// formatOffline returns the offline status.
// formatOffline returns status for when worker is offline.
func formatOffline() string {
return formatOfflineColored(true)
useColors := os.Getenv("NO_COLOR") == "" && os.Getenv("TERM") != "dumb"
return formatOfflineColored(useColors)
}
// formatOfflineColored returns the offline status with optional colors.
// formatOfflineColored returns colored offline status.
func formatOfflineColored(useColors bool) string {
if useColors {
return colorCyan + "[mnemonic]" + colorReset + " " + colorGray + "○" + colorReset
return colorGray + "[mnemonic]" + colorReset + " " + colorGray + "○" + colorReset + " offline"
}
return "[mnemonic] ○"
return "[mnemonic] ○ offline"
}
// formatStartingColored returns the starting status with optional colors.
// formatStartingColored returns colored starting status.
func formatStartingColored(useColors bool) string {
if useColors {
return colorCyan + "[mnemonic]" + colorReset + " " + colorYellow + "" + colorReset + " starting"
return colorYellow + "[mnemonic]" + colorReset + " " + colorYellow + "" + colorReset + " starting..."
}
return "[mnemonic] starting"
return "[mnemonic] starting..."
}
+19 -62
View File
@@ -2,9 +2,7 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/url"
"os"
@@ -13,53 +11,25 @@ import (
// Input is the hook input from Claude Code.
type Input struct {
SessionID string `json:"session_id"`
CWD string `json:"cwd"`
PermissionMode string `json:"permission_mode"`
HookEventName string `json:"hook_event_name"`
Prompt string `json:"prompt"`
hooks.BaseInput
Prompt string `json:"prompt"`
}
func main() {
// Skip if this is an internal call (from SDK processor)
if os.Getenv("CLAUDE_MNEMONIC_INTERNAL") == "1" {
hooks.WriteResponse("UserPromptSubmit", true)
return
}
// Read input from stdin
inputData, err := io.ReadAll(os.Stdin)
if err != nil {
hooks.WriteError("UserPromptSubmit", err)
os.Exit(1)
}
var input Input
if err := json.Unmarshal(inputData, &input); err != nil {
hooks.WriteError("UserPromptSubmit", err)
os.Exit(1)
}
// Ensure worker is running
port, err := hooks.EnsureWorkerRunning()
if err != nil {
hooks.WriteError("UserPromptSubmit", err)
os.Exit(1)
}
// Generate unique project ID from CWD
project := hooks.ProjectIDWithName(input.CWD)
hooks.RunHook("UserPromptSubmit", handleUserPrompt)
}
func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
// Search for relevant observations based on the prompt
searchURL := fmt.Sprintf("/api/context/search?project=%s&query=%s&cwd=%s",
url.QueryEscape(project),
url.QueryEscape(ctx.Project),
url.QueryEscape(input.Prompt),
url.QueryEscape(input.CWD))
url.QueryEscape(ctx.CWD))
var contextToInject string
var observationCount int
searchResult, _ := hooks.GET(port, searchURL)
searchResult, _ := hooks.GET(ctx.Port, searchURL)
if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 {
// Results are already filtered by relevance threshold and capped by max_results
// from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults)
@@ -104,27 +74,24 @@ func main() {
}
contextBuilder += "</relevant-memory>\n"
contextToInject = contextBuilder
}
// Initialize session with matched observations count
result, err := hooks.POST(port, "/api/sessions/init", map[string]interface{}{
"claudeSessionId": input.SessionID,
"project": project,
result, err := hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{
"claudeSessionId": ctx.SessionID,
"project": ctx.Project,
"prompt": input.Prompt,
"matchedObservations": observationCount,
})
if err != nil {
hooks.WriteError("UserPromptSubmit", err)
os.Exit(1)
return "", err
}
// Check if skipped due to privacy
if skipped, ok := result["skipped"].(bool); ok && skipped {
fmt.Fprintf(os.Stderr, "[user-prompt] Session skipped (private)\n")
hooks.WriteResponse("UserPromptSubmit", true)
return
return "", nil
}
sessionID := int64(result["sessionDbId"].(float64))
@@ -133,30 +100,20 @@ func main() {
fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber)
// Start SDK agent
_, err = hooks.POST(port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
_, err = hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{
"userPrompt": input.Prompt,
"promptNumber": promptNumber,
})
if err != nil {
hooks.WriteError("UserPromptSubmit", err)
os.Exit(1)
return "", err
}
// Output results - stdout with exit 0 adds context to Claude's prompt
// Return context if we found relevant observations
if observationCount > 0 {
// Show match count to user via stderr
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Found %d relevant memories for this prompt\n", observationCount)
// Output context as JSON with additionalContext field
response := map[string]interface{}{
"continue": true,
"hookSpecificOutput": map[string]interface{}{
"hookEventName": "UserPromptSubmit",
"additionalContext": contextToInject,
},
}
_ = json.NewEncoder(os.Stdout).Encode(response)
os.Exit(0)
} else {
hooks.WriteResponse("UserPromptSubmit", true)
return contextToInject, nil
}
return "", nil
}
+34 -12
View File
@@ -10,12 +10,14 @@ import (
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/mcp"
"github.com/lukaszraczylo/claude-mnemonic/internal/scoring"
"github.com/lukaszraczylo/claude-mnemonic/internal/search"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
@@ -71,22 +73,25 @@ func main() {
cancel()
}()
// Initialize SQLite store (migrations run automatically)
storeCfg := sqlite.StoreConfig{
// Initialize database store (migrations run automatically)
storeCfg := gorm.Config{
Path: dbPath,
MaxConns: cfg.MaxConns,
WALMode: true,
// WALMode is enabled automatically by GORM
}
store, err := sqlite.NewStore(storeCfg)
store, err := gorm.NewStore(storeCfg)
if err != nil {
log.Fatal().Err(err).Msg("Failed to initialize SQLite store")
log.Fatal().Err(err).Msg("Failed to initialize database store")
}
defer store.Close()
// Initialize stores
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
observationStore := gorm.NewObservationStore(store, nil, nil, nil)
summaryStore := gorm.NewSummaryStore(store)
promptStore := gorm.NewPromptStore(store, nil)
patternStore := gorm.NewPatternStore(store)
relationStore := gorm.NewRelationStore(store)
sessionStore := gorm.NewSessionStore(store)
// Initialize embedding service and vector client
var vectorClient *sqlitevec.Client
@@ -95,7 +100,7 @@ func main() {
log.Warn().Err(err).Msg("Embedding service unavailable, vector search disabled")
} else {
defer embedSvc.Close()
vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.DB()}, embedSvc)
vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.GetRawDB()}, embedSvc)
if err != nil {
log.Warn().Err(err).Msg("Vector client unavailable, vector search disabled")
} else {
@@ -103,14 +108,31 @@ func main() {
}
}
// Initialize scoring components
scoreConfig := models.DefaultScoringConfig()
scoreCalculator := scoring.NewCalculator(scoreConfig)
recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger)
go recalculator.Start(ctx)
defer recalculator.Stop()
// Initialize search manager
searchMgr := search.NewManager(observationStore, summaryStore, promptStore, vectorClient)
// Start file watchers
startWatchers(ctx, dbPath)
// Create and run MCP server
server := mcp.NewServer(searchMgr, Version)
// Create and run MCP server with all dependencies
server := mcp.NewServer(
searchMgr,
Version,
observationStore,
patternStore,
relationStore,
sessionStore,
vectorClient,
scoreCalculator,
recalculator,
)
log.Info().Str("project", *project).Str("version", Version).Msg("Starting MCP server")
if err := server.Run(ctx); err != nil {