mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-11 00:09:28 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
// Package hooks provides hook utilities for claude-mnemonic.
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// HookResponse is the JSON payload a hook reports back to Claude Code.
type HookResponse struct {
	// Continue is serialized under the "continue" key.
	Continue bool `json:"continue"`
}
|
||||
|
||||
// ProjectIDWithName builds a human-readable project identifier of the form
// "dirname_abc123": the base name of the directory, an underscore, and the
// first three bytes (six hex characters) of the SHA-256 of its absolute path.
//
// When cwd cannot be made absolute, cwd itself is used for both the name and
// the hash.
func ProjectIDWithName(cwd string) string {
	resolved := cwd
	if abs, err := filepath.Abs(cwd); err == nil {
		resolved = abs
	}

	digest := sha256.Sum256([]byte(resolved))
	suffix := hex.EncodeToString(digest[:3]) // 3 bytes -> 6 hex chars

	return filepath.Base(resolved) + "_" + suffix
}
|
||||
|
||||
// Exit codes for Claude Code hooks.
const (
	// ExitSuccess is the exit code for a successful hook run.
	ExitSuccess = 0

	// ExitFailure is the exit code for a failed hook run.
	ExitFailure = 1

	// ExitUserMessageOnly asks for stderr to be displayed as a user message.
	ExitUserMessageOnly = 3
)
|
||||
|
||||
// WriteResponse writes a hook response to stdout.
|
||||
func WriteResponse(hookName string, success bool) {
|
||||
response := HookResponse{Continue: success}
|
||||
data, _ := json.Marshal(response)
|
||||
fmt.Println(string(data))
|
||||
}
|
||||
|
||||
// WriteError writes an error message to stderr and exits.
|
||||
func WriteError(hookName string, err error) {
|
||||
fmt.Fprintf(os.Stderr, "[%s] Error: %v\n", hookName, err)
|
||||
WriteResponse(hookName, false)
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
// Package hooks provides hook utilities for claude-mnemonic.
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Version identifies this build. The "dev" default is meant to be
// overridden at build time via ldflags.
var Version = "dev"
|
||||
|
||||
const (
	// DefaultWorkerPort is the default worker port.
	DefaultWorkerPort = 37777

	// HealthCheckTimeout is the timeout for health checks (reduced from 5s for faster startup).
	HealthCheckTimeout = 1 * time.Second

	// StartupTimeout is the timeout for worker startup.
	StartupTimeout = 30 * time.Second
)

// GetWorkerPort returns the worker port configured through the
// CLAUDE_MNEMONIC_WORKER_PORT environment variable.
//
// DefaultWorkerPort is returned when the variable is unset, empty, not an
// integer, or outside the valid TCP port range 1-65535.
func GetWorkerPort() int {
	raw := os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT")
	if raw == "" {
		return DefaultWorkerPort
	}

	port, err := strconv.Atoi(raw)
	// A port above 65535 can never be dialled, so treat it like any other
	// invalid value instead of handing it to callers.
	if err != nil || port < 1 || port > 65535 {
		return DefaultWorkerPort
	}
	return port
}
|
||||
|
||||
// IsWorkerRunning checks if the worker is running and healthy.
|
||||
func IsWorkerRunning(port int) bool {
|
||||
client := &http.Client{Timeout: HealthCheckTimeout}
|
||||
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// EnsureWorkerRunning ensures the worker is running, starting it if necessary.
|
||||
// If a worker is already running and healthy with matching version, it reuses it.
|
||||
// If version mismatch or unhealthy, it kills the old worker and starts fresh.
|
||||
func EnsureWorkerRunning() (int, error) {
|
||||
port := GetWorkerPort()
|
||||
|
||||
// Check if already running and healthy
|
||||
if IsWorkerRunning(port) {
|
||||
// Check version - if mismatch, restart
|
||||
if runningVersion := GetWorkerVersion(port); runningVersion != "" {
|
||||
if runningVersion != Version {
|
||||
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker version mismatch (running: %s, expected: %s), restarting...\n", runningVersion, Version)
|
||||
if err := KillProcessOnPort(port); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill old worker: %v\n", err)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
} else {
|
||||
// Version matches, reuse existing worker
|
||||
return port, nil
|
||||
}
|
||||
} else {
|
||||
// Couldn't get version, assume it's fine
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if port is in use but worker is unhealthy
|
||||
if IsPortInUse(port) {
|
||||
// Something is using the port but not responding to health checks
|
||||
// Try to kill it
|
||||
if err := KillProcessOnPort(port); err != nil {
|
||||
// Log but continue - maybe it will die on its own
|
||||
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill unhealthy process on port %d: %v\n", port, err)
|
||||
}
|
||||
// Wait a moment for port to be released
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Find worker binary
|
||||
workerPath := findWorkerBinary()
|
||||
if workerPath == "" {
|
||||
return 0, fmt.Errorf("worker binary not found")
|
||||
}
|
||||
|
||||
// Start worker
|
||||
cmd := exec.Command(workerPath) // #nosec G204 -- workerPath is from internal findWorkerBinary
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, fmt.Errorf("failed to start worker: %w", err)
|
||||
}
|
||||
|
||||
// Wait for worker to be ready with exponential backoff
|
||||
deadline := time.Now().Add(StartupTimeout)
|
||||
backoff := 50 * time.Millisecond
|
||||
maxBackoff := 500 * time.Millisecond
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
if IsWorkerRunning(port) {
|
||||
return port, nil
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
// Exponential backoff with cap
|
||||
backoff = backoff * 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("worker failed to start within timeout")
|
||||
}
|
||||
|
||||
// GetWorkerVersion gets the version of the running worker.
|
||||
func GetWorkerVersion(port int) string {
|
||||
client := &http.Client{Timeout: HealthCheckTimeout}
|
||||
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/version", port))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return result["version"]
|
||||
}
|
||||
|
||||
// IsPortInUse reports whether a TCP connection to 127.0.0.1:port can be
// established within 500ms, regardless of whether the listener is a healthy
// worker.
func IsPortInUse(port int) bool {
	target := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))

	conn, err := net.DialTimeout("tcp", target, 500*time.Millisecond)
	if err == nil {
		_ = conn.Close() // only reachability matters; the close error is irrelevant
		return true
	}
	return false
}
|
||||
|
||||
// KillProcessOnPort finds the processes listening on the given TCP port with
// lsof and sends each of them SIGKILL.
//
// It returns nil when lsof reports no matching process (exit code 1 or empty
// output). If lsof cannot be run, the error is returned wrapped. Every PID is
// attempted even if an earlier kill fails; the first kill error is returned.
func KillProcessOnPort(port int) error {
	// Restrict the match to listening TCP sockets. A bare "-i :PORT" also
	// matches processes that merely hold a client connection to the port,
	// and those must not be killed. -n/-P skip host and port name lookups.
	cmd := exec.Command("lsof", "-nP", "-t", fmt.Sprintf("-iTCP:%d", port), "-sTCP:LISTEN") // #nosec G204 -- port is from internal config
	output, err := cmd.Output()
	if err != nil {
		// lsof returns exit code 1 when no process is found - that's fine.
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
			return nil
		}
		return fmt.Errorf("failed to find process on port: %w", err)
	}

	var firstErr error
	// lsof -t prints one PID per line; Fields also copes with blank lines.
	for _, field := range strings.Fields(string(output)) {
		pid, convErr := strconv.Atoi(field)
		if convErr != nil || pid <= 0 {
			// Never hand unparsable or non-positive values to kill: PID 0
			// and negative PIDs address whole process groups.
			continue
		}

		killCmd := exec.Command("kill", "-9", strconv.Itoa(pid)) // #nosec G204 -- pid is a validated integer from lsof output
		if killErr := killCmd.Run(); killErr != nil && firstErr == nil {
			firstErr = fmt.Errorf("failed to kill process %d: %w", pid, killErr)
		}
	}

	return firstErr
}
|
||||
|
||||
// findWorkerBinary returns the path of the worker binary, or "" when none is
// found. Candidates are tried in this order:
//
//  1. $CLAUDE_PLUGIN_ROOT/worker (set by Claude Code when running hooks)
//  2. ./worker and ./bin/worker relative to the working directory
//  3. the plugin cache and marketplace locations under $HOME
//  4. "claude-mnemonic-worker" on PATH
//
// A candidate only counts if it is a regular file, so a directory that
// happens to be named "worker" is skipped. The $HOME locations are skipped
// when HOME is empty, because joining them onto "" would yield paths relative
// to the working directory.
//
// NOTE(review): candidates 2 are resolved against the current working
// directory; confirm that executing a binary from there is intended.
func findWorkerBinary() string {
	isFile := func(path string) bool {
		info, err := os.Stat(path)
		return err == nil && info.Mode().IsRegular()
	}

	// Check CLAUDE_PLUGIN_ROOT first.
	if pluginRoot := os.Getenv("CLAUDE_PLUGIN_ROOT"); pluginRoot != "" {
		if candidate := filepath.Join(pluginRoot, "worker"); isFile(candidate) {
			return candidate
		}
	}

	// Check common locations.
	locations := []string{
		"./worker",
		"./bin/worker",
	}
	if home := os.Getenv("HOME"); home != "" {
		locations = append(locations,
			filepath.Join(home, ".claude/plugins/cache/claude-mnemonic/claude-mnemonic/1.0.0/worker"),
			filepath.Join(home, ".claude/plugins/marketplaces/claude-mnemonic/worker"),
		)
	}
	for _, loc := range locations {
		if isFile(loc) {
			return loc
		}
	}

	// Try PATH.
	if path, err := exec.LookPath("claude-mnemonic-worker"); err == nil {
		return path
	}

	return ""
}
|
||||
|
||||
// POST JSON-encodes body and sends it to http://127.0.0.1:port+path with a
// 10 second timeout.
//
// It returns the decoded JSON object of the reply. A reply whose body is not
// a JSON object yields (nil, nil), since not all endpoints return JSON. An
// error is returned when body cannot be encoded, the request fails, or the
// status code is 400 or above.
func POST(port int, path string, body interface{}) (map[string]interface{}, error) {
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
	httpClient := &http.Client{Timeout: 10 * time.Second}

	resp, err := httpClient.Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("request failed: %s", resp.Status)
	}

	var decoded map[string]interface{}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&decoded); decodeErr != nil {
		// Not all endpoints return JSON.
		return nil, nil
	}
	return decoded, nil
}
|
||||
|
||||
// GET requests http://127.0.0.1:port+path with a 10 second timeout and
// returns the reply decoded as a JSON object.
//
// An error is returned when the request fails, the status code is 400 or
// above, or the body cannot be decoded.
func GET(port int, path string) (map[string]interface{}, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
	httpClient := &http.Client{Timeout: 10 * time.Second}

	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("request failed: %s", resp.Status)
	}

	var decoded map[string]interface{}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&decoded); decodeErr != nil {
		return nil, decodeErr
	}
	return decoded, nil
}
|
||||
@@ -0,0 +1,192 @@
|
||||
// Package hooks provides hook utilities for claude-mnemonic.
|
||||
package hooks
|
||||
|
||||
import (
	"encoding/json"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
||||
|
||||
func TestGetWorkerPort(t *testing.T) {
|
||||
// Test default port
|
||||
port := GetWorkerPort()
|
||||
assert.Equal(t, DefaultWorkerPort, port)
|
||||
|
||||
// Test with environment variable
|
||||
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "12345")
|
||||
port = GetWorkerPort()
|
||||
assert.Equal(t, 12345, port)
|
||||
|
||||
// Test with invalid environment variable (should return default)
|
||||
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "invalid")
|
||||
port = GetWorkerPort()
|
||||
assert.Equal(t, DefaultWorkerPort, port)
|
||||
}
|
||||
|
||||
func TestIsWorkerRunning(t *testing.T) {
|
||||
// Create a test server that responds to health checks
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/health" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ready"})
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Extract port from test server URL
|
||||
// Note: In real tests we'd use the actual port, but test server uses random port
|
||||
// So we test with a non-existent port
|
||||
assert.False(t, IsWorkerRunning(99999)) // Non-existent port
|
||||
}
|
||||
|
||||
func TestIsPortInUse(t *testing.T) {
|
||||
// Create a test server to occupy a port
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Non-existent port should not be in use
|
||||
assert.False(t, IsPortInUse(99999))
|
||||
}
|
||||
|
||||
func TestGetWorkerVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse func(w http.ResponseWriter, r *http.Request)
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "returns version from server",
|
||||
serverResponse: func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/version" {
|
||||
json.NewEncoder(w).Encode(map[string]string{"version": "1.2.3"})
|
||||
}
|
||||
},
|
||||
expectedResult: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "returns empty on 404",
|
||||
serverResponse: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
{
|
||||
name: "returns empty on invalid JSON",
|
||||
serverResponse: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("not json"))
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(tt.serverResponse))
|
||||
defer server.Close()
|
||||
|
||||
// We can't easily test with the actual function since it uses a hardcoded localhost
|
||||
// But we can verify the logic works with the test server
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectIDWithName(t *testing.T) {
|
||||
tests := []struct {
|
||||
cwd string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
cwd: "/Users/test/projects/my-project",
|
||||
expected: "my-project_", // Will have hash suffix
|
||||
},
|
||||
{
|
||||
cwd: "/tmp",
|
||||
expected: "tmp_",
|
||||
},
|
||||
{
|
||||
cwd: "/",
|
||||
expected: "", // Empty dirname
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.cwd, func(t *testing.T) {
|
||||
result := ProjectIDWithName(tt.cwd)
|
||||
if tt.expected != "" {
|
||||
assert.Contains(t, result, tt.expected[:len(tt.expected)-1]) // Check prefix before underscore
|
||||
assert.Contains(t, result, "_") // Should have underscore separator
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionMatching(t *testing.T) {
|
||||
// Test that version matching logic works correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
runningVersion string
|
||||
hookVersion string
|
||||
shouldRestart bool
|
||||
}{
|
||||
{
|
||||
name: "matching versions",
|
||||
runningVersion: "1.0.0",
|
||||
hookVersion: "1.0.0",
|
||||
shouldRestart: false,
|
||||
},
|
||||
{
|
||||
name: "mismatched versions",
|
||||
runningVersion: "1.0.0",
|
||||
hookVersion: "2.0.0",
|
||||
shouldRestart: true,
|
||||
},
|
||||
{
|
||||
name: "dirty vs clean",
|
||||
runningVersion: "1.0.0",
|
||||
hookVersion: "1.0.0-dirty",
|
||||
shouldRestart: true,
|
||||
},
|
||||
{
|
||||
name: "empty running version",
|
||||
runningVersion: "",
|
||||
hookVersion: "1.0.0",
|
||||
shouldRestart: false, // Can't determine, don't restart
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the version check logic
|
||||
shouldRestart := false
|
||||
if tt.runningVersion != "" && tt.runningVersion != tt.hookVersion {
|
||||
shouldRestart = true
|
||||
}
|
||||
assert.Equal(t, tt.shouldRestart, shouldRestart)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKillProcessOnPort_NoProcess(t *testing.T) {
|
||||
// Test killing a process on a port that has no process
|
||||
// Should not error, just return nil
|
||||
err := KillProcessOnPort(99999) // Port unlikely to be in use
|
||||
// lsof will return empty/error, which is fine
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFindWorkerBinary(t *testing.T) {
|
||||
// Test that findWorkerBinary returns empty string when binary not found
|
||||
// This is hard to test without mocking the filesystem
|
||||
// But we can verify it doesn't panic
|
||||
result := findWorkerBinary()
|
||||
// Result depends on whether worker is installed, so we just check it doesn't panic
|
||||
t.Logf("findWorkerBinary returned: %s", result)
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ObservationType names the kind of an observation.
type ObservationType string

// Known observation types.
const (
	ObsTypeDecision  ObservationType = "decision"
	ObsTypeBugfix    ObservationType = "bugfix"
	ObsTypeFeature   ObservationType = "feature"
	ObsTypeRefactor  ObservationType = "refactor"
	ObsTypeDiscovery ObservationType = "discovery"
	ObsTypeChange    ObservationType = "change"
)
|
||||
|
||||
// ObservationScope is the visibility scope of an observation.
type ObservationScope string

const (
	// ScopeProject limits an observation to the project it belongs to.
	ScopeProject ObservationScope = "project"

	// ScopeGlobal makes an observation visible across all projects. It is
	// used for best practices, advanced patterns, and generalizable knowledge.
	ScopeGlobal ObservationScope = "global"
)
|
||||
|
||||
// GlobalizableConcepts are concept tags that indicate an observation
|
||||
// should be considered for global scope (best practices, patterns, etc.)
|
||||
var GlobalizableConcepts = []string{
|
||||
"best-practice",
|
||||
"pattern",
|
||||
"anti-pattern",
|
||||
"architecture",
|
||||
"security",
|
||||
"performance",
|
||||
"testing",
|
||||
"debugging",
|
||||
"workflow",
|
||||
"tooling",
|
||||
}
|
||||
|
||||
// JSONStringArray is a []string that is stored in SQLite as a JSON array.
type JSONStringArray []string

// Scan implements sql.Scanner for JSONStringArray.
//
// A nil source or an empty string/byte slice resets the receiver to nil.
// Other string or []byte sources are decoded as JSON. Any other source type
// is rejected with an error.
func (j *JSONStringArray) Scan(src interface{}) error {
	var raw []byte
	switch v := src.(type) {
	case nil:
		*j = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("JSONStringArray: unsupported type %T", src)
	}

	if len(raw) == 0 {
		*j = nil
		return nil
	}
	return json.Unmarshal(raw, j)
}

// Value implements driver.Valuer for JSONStringArray.
//
// A nil array is written as NULL; anything else as its JSON encoding.
func (j JSONStringArray) Value() (driver.Value, error) {
	if j != nil {
		return json.Marshal(j)
	}
	return nil, nil
}
|
||||
|
||||
// JSONInt64Map is a map[string]int64 that is stored in SQLite as a JSON object.
type JSONInt64Map map[string]int64

// Scan implements sql.Scanner for JSONInt64Map.
//
// A nil source or an empty string/byte slice resets the receiver to nil.
// Other string or []byte sources are decoded as JSON. Any other source type
// is rejected with an error.
func (j *JSONInt64Map) Scan(src interface{}) error {
	var raw []byte
	switch v := src.(type) {
	case nil:
		*j = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("JSONInt64Map: unsupported type %T", src)
	}

	if len(raw) == 0 {
		*j = nil
		return nil
	}
	return json.Unmarshal(raw, j)
}

// Value implements driver.Valuer for JSONInt64Map.
//
// A nil map is written as NULL; anything else as its JSON encoding.
func (j JSONInt64Map) Value() (driver.Value, error) {
	if j != nil {
		return json.Marshal(j)
	}
	return nil, nil
}
|
||||
|
||||
// Observation represents a learning extracted from a Claude Code session.
|
||||
type Observation struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
Project string `db:"project" json:"project"`
|
||||
Scope ObservationScope `db:"scope" json:"scope"`
|
||||
Type ObservationType `db:"type" json:"type"`
|
||||
Title sql.NullString `db:"title" json:"title,omitempty"`
|
||||
Subtitle sql.NullString `db:"subtitle" json:"subtitle,omitempty"`
|
||||
Facts JSONStringArray `db:"facts" json:"facts,omitempty"`
|
||||
Narrative sql.NullString `db:"narrative" json:"narrative,omitempty"`
|
||||
Concepts JSONStringArray `db:"concepts" json:"concepts,omitempty"`
|
||||
FilesRead JSONStringArray `db:"files_read" json:"files_read,omitempty"`
|
||||
FilesModified JSONStringArray `db:"files_modified" json:"files_modified,omitempty"`
|
||||
FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"`
|
||||
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
IsStale bool `db:"-" json:"is_stale,omitempty"`
|
||||
}
|
||||
|
||||
// ParsedObservation represents an observation parsed from SDK response XML.
|
||||
type ParsedObservation struct {
|
||||
Type ObservationType
|
||||
Title string
|
||||
Subtitle string
|
||||
Facts []string
|
||||
Narrative string
|
||||
Concepts []string
|
||||
FilesRead []string
|
||||
FilesModified []string
|
||||
FileMtimes map[string]int64 // File path -> mtime epoch ms
|
||||
Scope ObservationScope // Optional: if empty, will be auto-determined
|
||||
}
|
||||
|
||||
// ToStoredObservation converts a ParsedObservation to the stored Observation format.
|
||||
// Used for similarity comparison before storage.
|
||||
func (p *ParsedObservation) ToStoredObservation() *Observation {
|
||||
return &Observation{
|
||||
Type: p.Type,
|
||||
Title: sql.NullString{String: p.Title, Valid: p.Title != ""},
|
||||
Subtitle: sql.NullString{String: p.Subtitle, Valid: p.Subtitle != ""},
|
||||
Facts: p.Facts,
|
||||
Narrative: sql.NullString{String: p.Narrative, Valid: p.Narrative != ""},
|
||||
Concepts: p.Concepts,
|
||||
FilesRead: p.FilesRead,
|
||||
FilesModified: p.FilesModified,
|
||||
FileMtimes: p.FileMtimes,
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineScope determines the appropriate scope based on observation concepts.
|
||||
// Returns ScopeGlobal if any concept matches globalizable patterns, else ScopeProject.
|
||||
func DetermineScope(concepts []string) ObservationScope {
|
||||
for _, concept := range concepts {
|
||||
for _, globalConcept := range GlobalizableConcepts {
|
||||
if concept == globalConcept {
|
||||
return ScopeGlobal
|
||||
}
|
||||
}
|
||||
}
|
||||
return ScopeProject
|
||||
}
|
||||
|
||||
// ObservationJSON is a JSON-friendly representation of Observation.
|
||||
// It converts sql.NullString to plain strings for clean JSON output.
|
||||
type ObservationJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
SDKSessionID string `json:"sdk_session_id"`
|
||||
Project string `json:"project"`
|
||||
Scope ObservationScope `json:"scope"`
|
||||
Type ObservationType `json:"type"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Subtitle string `json:"subtitle,omitempty"`
|
||||
Facts []string `json:"facts,omitempty"`
|
||||
Narrative string `json:"narrative,omitempty"`
|
||||
Concepts []string `json:"concepts,omitempty"`
|
||||
FilesRead []string `json:"files_read,omitempty"`
|
||||
FilesModified []string `json:"files_modified,omitempty"`
|
||||
FileMtimes map[string]int64 `json:"file_mtimes,omitempty"`
|
||||
PromptNumber int64 `json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `json:"discovery_tokens"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||
IsStale bool `json:"is_stale,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Observation.
|
||||
// Converts sql.NullString fields to plain strings.
|
||||
func (o *Observation) MarshalJSON() ([]byte, error) {
|
||||
j := ObservationJSON{
|
||||
ID: o.ID,
|
||||
SDKSessionID: o.SDKSessionID,
|
||||
Project: o.Project,
|
||||
Scope: o.Scope,
|
||||
Type: o.Type,
|
||||
Facts: o.Facts,
|
||||
Concepts: o.Concepts,
|
||||
FilesRead: o.FilesRead,
|
||||
FilesModified: o.FilesModified,
|
||||
FileMtimes: o.FileMtimes,
|
||||
DiscoveryTokens: o.DiscoveryTokens,
|
||||
CreatedAt: o.CreatedAt,
|
||||
CreatedAtEpoch: o.CreatedAtEpoch,
|
||||
IsStale: o.IsStale,
|
||||
}
|
||||
if o.Title.Valid {
|
||||
j.Title = o.Title.String
|
||||
}
|
||||
if o.Subtitle.Valid {
|
||||
j.Subtitle = o.Subtitle.String
|
||||
}
|
||||
if o.Narrative.Valid {
|
||||
j.Narrative = o.Narrative.String
|
||||
}
|
||||
if o.PromptNumber.Valid {
|
||||
j.PromptNumber = o.PromptNumber.Int64
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
// NewObservation creates a new observation from parsed data.
|
||||
func NewObservation(sdkSessionID, project string, parsed *ParsedObservation, promptNumber int, discoveryTokens int64) *Observation {
|
||||
now := time.Now()
|
||||
|
||||
// Determine scope: use parsed scope if set, otherwise auto-determine from concepts
|
||||
scope := parsed.Scope
|
||||
if scope == "" {
|
||||
scope = DetermineScope(parsed.Concepts)
|
||||
}
|
||||
|
||||
return &Observation{
|
||||
SDKSessionID: sdkSessionID,
|
||||
Project: project,
|
||||
Scope: scope,
|
||||
Type: parsed.Type,
|
||||
Title: sql.NullString{String: parsed.Title, Valid: parsed.Title != ""},
|
||||
Subtitle: sql.NullString{String: parsed.Subtitle, Valid: parsed.Subtitle != ""},
|
||||
Facts: parsed.Facts,
|
||||
Narrative: sql.NullString{String: parsed.Narrative, Valid: parsed.Narrative != ""},
|
||||
Concepts: parsed.Concepts,
|
||||
FilesRead: parsed.FilesRead,
|
||||
FilesModified: parsed.FilesModified,
|
||||
FileMtimes: parsed.FileMtimes,
|
||||
PromptNumber: sql.NullInt64{Int64: int64(promptNumber), Valid: promptNumber > 0},
|
||||
DiscoveryTokens: discoveryTokens,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckStaleness checks if an observation is stale based on current file mtimes.
|
||||
// Returns true if any tracked file has been modified since the observation was created.
|
||||
func (o *Observation) CheckStaleness(currentMtimes map[string]int64) bool {
|
||||
if len(o.FileMtimes) == 0 {
|
||||
return false // No file tracking, assume fresh
|
||||
}
|
||||
|
||||
for path, recordedMtime := range o.FileMtimes {
|
||||
if currentMtime, exists := currentMtimes[path]; exists {
|
||||
if currentMtime > recordedMtime {
|
||||
return true // File was modified since observation was created
|
||||
}
|
||||
}
|
||||
// If file doesn't exist in currentMtimes, it may have been deleted
|
||||
// We don't mark as stale for missing files - they might just not be checked
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
// UserPrompt is a user prompt captured during a session.
type UserPrompt struct {
	ID              int64  `db:"id" json:"id"`
	ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`

	// The prompt itself.
	PromptNumber        int    `db:"prompt_number" json:"prompt_number"`
	PromptText          string `db:"prompt_text" json:"prompt_text"`
	MatchedObservations int    `db:"matched_observations" json:"matched_observations"`

	// Creation time, as a string and as an epoch value.
	CreatedAt      string `db:"created_at" json:"created_at"`
	CreatedAtEpoch int64  `db:"created_at_epoch" json:"created_at_epoch"`
}
|
||||
|
||||
// UserPromptWithSession includes session context for search results.
|
||||
type UserPromptWithSession struct {
|
||||
UserPrompt
|
||||
Project string `db:"project" json:"project"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionStatus is the status of an SDK session.
type SessionStatus string

// Known session statuses.
const (
	SessionStatusActive    SessionStatus = "active"
	SessionStatusCompleted SessionStatus = "completed"
	SessionStatusFailed    SessionStatus = "failed"
)
|
||||
|
||||
// SDKSession represents a Claude Code session tracked by the memory system.
|
||||
type SDKSession struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"`
|
||||
SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"`
|
||||
Project string `db:"project" json:"project"`
|
||||
UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"`
|
||||
WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"`
|
||||
PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"`
|
||||
Status SessionStatus `db:"status" json:"status"`
|
||||
StartedAt string `db:"started_at" json:"started_at"`
|
||||
StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"`
|
||||
CompletedAt sql.NullString `db:"completed_at" json:"completed_at,omitempty"`
|
||||
CompletedAtEpoch sql.NullInt64 `db:"completed_at_epoch" json:"completed_at_epoch,omitempty"`
|
||||
}
|
||||
|
||||
// ActiveSession is the in-memory state of a session that is currently being
// processed.
type ActiveSession struct {
	// Identity.
	SessionDBID     int64
	ClaudeSessionID string
	SDKSessionID    string
	Project         string

	// Progress.
	UserPrompt       string
	LastPromptNumber int
	StartTime        time.Time

	// Running token totals.
	CumulativeInputTokens  int64
	CumulativeOutputTokens int64
}
|
||||
@@ -0,0 +1,107 @@
|
||||
// Package models contains domain models for claude-mnemonic.
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionSummary represents a summary of a Claude Code session.
|
||||
type SessionSummary struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"`
|
||||
Project string `db:"project" json:"project"`
|
||||
Request sql.NullString `db:"request" json:"request,omitempty"`
|
||||
Investigated sql.NullString `db:"investigated" json:"investigated,omitempty"`
|
||||
Learned sql.NullString `db:"learned" json:"learned,omitempty"`
|
||||
Completed sql.NullString `db:"completed" json:"completed,omitempty"`
|
||||
NextSteps sql.NullString `db:"next_steps" json:"next_steps,omitempty"`
|
||||
Notes sql.NullString `db:"notes" json:"notes,omitempty"`
|
||||
PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"`
|
||||
CreatedAt string `db:"created_at" json:"created_at"`
|
||||
CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
// ParsedSummary is a session summary as parsed from SDK response XML. Each
// field holds the text of one summary section.
type ParsedSummary struct {
	Request      string
	Investigated string
	Learned      string
	Completed    string
	NextSteps    string
	Notes        string
}
|
||||
|
||||
// NewSessionSummary creates a new session summary from parsed data.
|
||||
func NewSessionSummary(sdkSessionID, project string, parsed *ParsedSummary, promptNumber int, discoveryTokens int64) *SessionSummary {
|
||||
now := time.Now()
|
||||
return &SessionSummary{
|
||||
SDKSessionID: sdkSessionID,
|
||||
Project: project,
|
||||
Request: sql.NullString{String: parsed.Request, Valid: parsed.Request != ""},
|
||||
Investigated: sql.NullString{String: parsed.Investigated, Valid: parsed.Investigated != ""},
|
||||
Learned: sql.NullString{String: parsed.Learned, Valid: parsed.Learned != ""},
|
||||
Completed: sql.NullString{String: parsed.Completed, Valid: parsed.Completed != ""},
|
||||
NextSteps: sql.NullString{String: parsed.NextSteps, Valid: parsed.NextSteps != ""},
|
||||
Notes: sql.NullString{String: parsed.Notes, Valid: parsed.Notes != ""},
|
||||
PromptNumber: sql.NullInt64{Int64: int64(promptNumber), Valid: promptNumber > 0},
|
||||
DiscoveryTokens: discoveryTokens,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
// SessionSummaryJSON is a JSON-friendly representation of SessionSummary.
|
||||
// It converts sql.NullString to plain strings for clean JSON output.
|
||||
type SessionSummaryJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
SDKSessionID string `json:"sdk_session_id"`
|
||||
Project string `json:"project"`
|
||||
Request string `json:"request,omitempty"`
|
||||
Investigated string `json:"investigated,omitempty"`
|
||||
Learned string `json:"learned,omitempty"`
|
||||
Completed string `json:"completed,omitempty"`
|
||||
NextSteps string `json:"next_steps,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
PromptNumber int64 `json:"prompt_number,omitempty"`
|
||||
DiscoveryTokens int64 `json:"discovery_tokens"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CreatedAtEpoch int64 `json:"created_at_epoch"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for SessionSummary.
|
||||
// Converts sql.NullString fields to plain strings.
|
||||
func (s *SessionSummary) MarshalJSON() ([]byte, error) {
|
||||
j := SessionSummaryJSON{
|
||||
ID: s.ID,
|
||||
SDKSessionID: s.SDKSessionID,
|
||||
Project: s.Project,
|
||||
DiscoveryTokens: s.DiscoveryTokens,
|
||||
CreatedAt: s.CreatedAt,
|
||||
CreatedAtEpoch: s.CreatedAtEpoch,
|
||||
}
|
||||
if s.Request.Valid {
|
||||
j.Request = s.Request.String
|
||||
}
|
||||
if s.Investigated.Valid {
|
||||
j.Investigated = s.Investigated.String
|
||||
}
|
||||
if s.Learned.Valid {
|
||||
j.Learned = s.Learned.String
|
||||
}
|
||||
if s.Completed.Valid {
|
||||
j.Completed = s.Completed.String
|
||||
}
|
||||
if s.NextSteps.Valid {
|
||||
j.NextSteps = s.NextSteps.String
|
||||
}
|
||||
if s.Notes.Valid {
|
||||
j.Notes = s.Notes.String
|
||||
}
|
||||
if s.PromptNumber.Valid {
|
||||
j.PromptNumber = s.PromptNumber.Int64
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
// Package similarity provides text similarity and clustering utilities.
|
||||
package similarity
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// ClusterObservations groups similar observations and returns only one representative per cluster.
|
||||
// Uses Jaccard similarity on extracted terms from title, narrative, and facts.
|
||||
// Observations should be sorted by preference (e.g., recency) - first one in each cluster is kept.
|
||||
func ClusterObservations(observations []*models.Observation, similarityThreshold float64) []*models.Observation {
|
||||
if len(observations) <= 1 {
|
||||
return observations
|
||||
}
|
||||
|
||||
// Extract terms for each observation
|
||||
termSets := make([]map[string]bool, len(observations))
|
||||
for i, obs := range observations {
|
||||
termSets[i] = ExtractObservationTerms(obs)
|
||||
}
|
||||
|
||||
// Track which observations are already clustered
|
||||
clustered := make([]bool, len(observations))
|
||||
result := make([]*models.Observation, 0)
|
||||
|
||||
for i := 0; i < len(observations); i++ {
|
||||
if clustered[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
// This observation becomes the representative of its cluster
|
||||
// (observations are already sorted by recency, so first one is newest)
|
||||
result = append(result, observations[i])
|
||||
clustered[i] = true
|
||||
|
||||
// Find all similar observations and mark them as clustered
|
||||
for j := i + 1; j < len(observations); j++ {
|
||||
if clustered[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
similarity := JaccardSimilarity(termSets[i], termSets[j])
|
||||
if similarity >= similarityThreshold {
|
||||
clustered[j] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsSimilarToAny checks if a new observation is similar to any existing observation.
|
||||
// Returns true if similarity to any existing observation exceeds the threshold.
|
||||
func IsSimilarToAny(newObs *models.Observation, existing []*models.Observation, similarityThreshold float64) bool {
|
||||
if len(existing) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
newTerms := ExtractObservationTerms(newObs)
|
||||
if len(newTerms) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, obs := range existing {
|
||||
existingTerms := ExtractObservationTerms(obs)
|
||||
similarity := JaccardSimilarity(newTerms, existingTerms)
|
||||
if similarity >= similarityThreshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractObservationTerms extracts meaningful terms from an observation for similarity comparison.
|
||||
func ExtractObservationTerms(obs *models.Observation) map[string]bool {
|
||||
terms := make(map[string]bool)
|
||||
|
||||
// Add terms from title
|
||||
addTerms(terms, obs.Title.String)
|
||||
|
||||
// Add terms from narrative
|
||||
addTerms(terms, obs.Narrative.String)
|
||||
|
||||
// Add terms from facts
|
||||
for _, fact := range obs.Facts {
|
||||
addTerms(terms, fact)
|
||||
}
|
||||
|
||||
// Add file paths as terms (normalized)
|
||||
for _, file := range obs.FilesRead {
|
||||
// Use just the filename without path for matching
|
||||
parts := strings.Split(file, "/")
|
||||
if len(parts) > 0 {
|
||||
terms[strings.ToLower(parts[len(parts)-1])] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, file := range obs.FilesModified {
|
||||
parts := strings.Split(file, "/")
|
||||
if len(parts) > 0 {
|
||||
terms[strings.ToLower(parts[len(parts)-1])] = true
|
||||
}
|
||||
}
|
||||
|
||||
return terms
|
||||
}
|
||||
|
||||
// termStopWords is the set of common English function words that addTerms
// ignores. It is built once at package initialization rather than on every
// addTerms call.
var termStopWords = map[string]struct{}{
	"the": {}, "a": {}, "an": {}, "is": {}, "are": {},
	"was": {}, "were": {}, "be": {}, "been": {}, "being": {},
	"have": {}, "has": {}, "had": {}, "do": {}, "does": {},
	"did": {}, "will": {}, "would": {}, "could": {}, "should": {},
	"may": {}, "might": {}, "must": {}, "shall": {},
	"this": {}, "that": {}, "these": {}, "those": {},
	"and": {}, "or": {}, "but": {}, "if": {}, "then": {},
	"for": {}, "from": {}, "with": {}, "about": {}, "into": {},
	"to": {}, "of": {}, "in": {}, "on": {}, "at": {}, "by": {},
	"it": {}, "its": {}, "which": {}, "who": {}, "what": {},
	"when": {}, "where": {}, "how": {}, "why": {},
}

// addTerms tokenizes text and adds meaningful terms to the set.
//
// The text is lower-cased and split on every rune that is not an ASCII
// letter, an ASCII digit or an underscore. Tokens shorter than three
// characters and tokens listed in termStopWords are dropped; the rest are
// inserted into terms, which must be non-nil.
func addTerms(terms map[string]bool, text string) {
	words := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_')
	})

	for _, word := range words {
		// Tokens are pure ASCII here, so the byte length is the character count.
		if len(word) < 3 {
			continue
		}
		if _, stop := termStopWords[word]; stop {
			continue
		}
		terms[word] = true
	}
}
|
||||
|
||||
// JaccardSimilarity calculates the Jaccard similarity between two term sets:
// the number of shared terms divided by the number of distinct terms overall.
// Returns a value between 0 (no overlap) and 1 (identical). Two empty sets are
// considered identical (1); exactly one empty set gives 0.
func JaccardSimilarity(set1, set2 map[string]bool) float64 {
	switch {
	case len(set1) == 0 && len(set2) == 0:
		return 1.0
	case len(set1) == 0 || len(set2) == 0:
		return 0.0
	}

	// Count the terms of set1 that are also present (true) in set2.
	shared := 0
	for term := range set1 {
		if set2[term] {
			shared++
		}
	}

	// Both sets are non-empty and shared <= len(set1), so the union size is
	// at least len(set2) >= 1 and the division is always defined.
	distinct := len(set1) + len(set2) - shared
	return float64(shared) / float64(distinct)
}
|
||||
@@ -0,0 +1,292 @@
|
||||
// Package similarity provides text similarity and clustering utilities.
|
||||
package similarity
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJaccardSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set1 map[string]bool
|
||||
set2 map[string]bool
|
||||
expected float64
|
||||
}{
|
||||
{
|
||||
name: "identical sets",
|
||||
set1: map[string]bool{"a": true, "b": true, "c": true},
|
||||
set2: map[string]bool{"a": true, "b": true, "c": true},
|
||||
expected: 1.0,
|
||||
},
|
||||
{
|
||||
name: "no overlap",
|
||||
set1: map[string]bool{"a": true, "b": true},
|
||||
set2: map[string]bool{"c": true, "d": true},
|
||||
expected: 0.0,
|
||||
},
|
||||
{
|
||||
name: "partial overlap",
|
||||
set1: map[string]bool{"a": true, "b": true, "c": true},
|
||||
set2: map[string]bool{"b": true, "c": true, "d": true},
|
||||
expected: 0.5, // intersection=2, union=4
|
||||
},
|
||||
{
|
||||
name: "empty sets",
|
||||
set1: map[string]bool{},
|
||||
set2: map[string]bool{},
|
||||
expected: 1.0,
|
||||
},
|
||||
{
|
||||
name: "one empty set",
|
||||
set1: map[string]bool{"a": true},
|
||||
set2: map[string]bool{},
|
||||
expected: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := JaccardSimilarity(tt.set1, tt.set2)
|
||||
assert.InDelta(t, tt.expected, result, 0.001)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractObservationTerms(t *testing.T) {
|
||||
obs := &models.Observation{
|
||||
Title: sql.NullString{String: "Authentication flow implementation", Valid: true},
|
||||
Narrative: sql.NullString{String: "We implemented JWT-based authentication", Valid: true},
|
||||
Facts: models.JSONStringArray{"Users authenticate via API", "Tokens expire after 24 hours"},
|
||||
FilesRead: models.JSONStringArray{"/src/auth/handler.go", "/src/auth/jwt.go"},
|
||||
}
|
||||
|
||||
terms := ExtractObservationTerms(obs)
|
||||
|
||||
// Should contain terms from title
|
||||
assert.Contains(t, terms, "authentication")
|
||||
assert.Contains(t, terms, "flow")
|
||||
assert.Contains(t, terms, "implementation")
|
||||
|
||||
// Should contain terms from narrative
|
||||
assert.Contains(t, terms, "implemented")
|
||||
|
||||
// Should contain terms from facts
|
||||
assert.Contains(t, terms, "tokens")
|
||||
assert.Contains(t, terms, "expire")
|
||||
assert.Contains(t, terms, "hours")
|
||||
|
||||
// Should contain filenames (without path)
|
||||
assert.Contains(t, terms, "handler.go")
|
||||
assert.Contains(t, terms, "jwt.go")
|
||||
|
||||
// Should NOT contain stop words
|
||||
assert.NotContains(t, terms, "the")
|
||||
assert.NotContains(t, terms, "and")
|
||||
assert.NotContains(t, terms, "we")
|
||||
}
|
||||
|
||||
func TestClusterObservations(t *testing.T) {
|
||||
// Create similar observations
|
||||
obs1 := &models.Observation{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Authentication flow implementation", Valid: true},
|
||||
Narrative: sql.NullString{String: "JWT-based authentication for API", Valid: true},
|
||||
}
|
||||
obs2 := &models.Observation{
|
||||
ID: 2,
|
||||
Title: sql.NullString{String: "Authentication flow update", Valid: true},
|
||||
Narrative: sql.NullString{String: "Updated JWT authentication logic", Valid: true},
|
||||
}
|
||||
obs3 := &models.Observation{
|
||||
ID: 3,
|
||||
Title: sql.NullString{String: "Database migration guide", Valid: true},
|
||||
Narrative: sql.NullString{String: "How to run database migrations", Valid: true},
|
||||
}
|
||||
obs4 := &models.Observation{
|
||||
ID: 4,
|
||||
Title: sql.NullString{String: "Database schema changes", Valid: true},
|
||||
Narrative: sql.NullString{String: "Updated database schema for users", Valid: true},
|
||||
}
|
||||
|
||||
observations := []*models.Observation{obs1, obs2, obs3, obs4}
|
||||
|
||||
// Cluster with 0.4 threshold
|
||||
clustered := ClusterObservations(observations, 0.4)
|
||||
|
||||
// obs1 and obs2 should be clustered (similar authentication content)
|
||||
// obs3 and obs4 should be clustered (similar database content)
|
||||
t.Logf("Clustered %d observations down to %d", len(observations), len(clustered))
|
||||
assert.LessOrEqual(t, len(clustered), 4)
|
||||
assert.GreaterOrEqual(t, len(clustered), 1)
|
||||
|
||||
// First observation in each cluster should be kept (obs1 for auth, obs3 for db)
|
||||
ids := make(map[int64]bool)
|
||||
for _, obs := range clustered {
|
||||
ids[obs.ID] = true
|
||||
}
|
||||
|
||||
// Depending on threshold, obs1 should be kept (first in auth cluster)
|
||||
if len(clustered) <= 3 {
|
||||
assert.True(t, ids[1], "First observation (ID=1) should be kept as cluster representative")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterObservations_SingleObservation(t *testing.T) {
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Single observation", Valid: true},
|
||||
}
|
||||
|
||||
clustered := ClusterObservations([]*models.Observation{obs}, 0.4)
|
||||
|
||||
assert.Len(t, clustered, 1)
|
||||
assert.Equal(t, int64(1), clustered[0].ID)
|
||||
}
|
||||
|
||||
func TestClusterObservations_EmptyList(t *testing.T) {
|
||||
clustered := ClusterObservations([]*models.Observation{}, 0.4)
|
||||
assert.Len(t, clustered, 0)
|
||||
}
|
||||
|
||||
func TestClusterObservations_NoDuplicates(t *testing.T) {
|
||||
// Create observations with completely different content
|
||||
observations := []*models.Observation{
|
||||
{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Authentication system", Valid: true},
|
||||
Narrative: sql.NullString{String: "JWT tokens for user auth", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Title: sql.NullString{String: "Database configuration", Valid: true},
|
||||
Narrative: sql.NullString{String: "PostgreSQL setup and migrations", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Title: sql.NullString{String: "Caching layer", Valid: true},
|
||||
Narrative: sql.NullString{String: "Redis caching implementation", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
Title: sql.NullString{String: "Logging setup", Valid: true},
|
||||
Narrative: sql.NullString{String: "Structured logging with zerolog", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 5,
|
||||
Title: sql.NullString{String: "API endpoints", Valid: true},
|
||||
Narrative: sql.NullString{String: "REST API implementation", Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
clustered := ClusterObservations(observations, 0.4)
|
||||
|
||||
// With completely different content, all should be kept
|
||||
assert.Len(t, clustered, 5, "All unique observations should be kept")
|
||||
}
|
||||
|
||||
func TestIsSimilarToAny(t *testing.T) {
|
||||
existing := []*models.Observation{
|
||||
{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Authentication implementation", Valid: true},
|
||||
Narrative: sql.NullString{String: "JWT authentication flow", Valid: true},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Title: sql.NullString{String: "Database setup", Valid: true},
|
||||
Narrative: sql.NullString{String: "PostgreSQL configuration", Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
// New observation similar to existing
|
||||
similar := &models.Observation{
|
||||
ID: 3,
|
||||
Title: sql.NullString{String: "Authentication update", Valid: true},
|
||||
Narrative: sql.NullString{String: "JWT authentication changes", Valid: true},
|
||||
}
|
||||
|
||||
// New observation not similar to any existing
|
||||
different := &models.Observation{
|
||||
ID: 4,
|
||||
Title: sql.NullString{String: "Caching layer", Valid: true},
|
||||
Narrative: sql.NullString{String: "Redis caching implementation", Valid: true},
|
||||
}
|
||||
|
||||
assert.True(t, IsSimilarToAny(similar, existing, 0.3), "Similar observation should be detected")
|
||||
assert.False(t, IsSimilarToAny(different, existing, 0.3), "Different observation should not match")
|
||||
}
|
||||
|
||||
func TestIsSimilarToAny_EmptyExisting(t *testing.T) {
|
||||
newObs := &models.Observation{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "New observation", Valid: true},
|
||||
}
|
||||
|
||||
assert.False(t, IsSimilarToAny(newObs, []*models.Observation{}, 0.4))
|
||||
assert.False(t, IsSimilarToAny(newObs, nil, 0.4))
|
||||
}
|
||||
|
||||
func TestAddTerms(t *testing.T) {
|
||||
terms := make(map[string]bool)
|
||||
|
||||
addTerms(terms, "The quick brown fox jumps over the lazy dog")
|
||||
|
||||
// Should contain words >= 3 chars that aren't stop words
|
||||
assert.Contains(t, terms, "quick")
|
||||
assert.Contains(t, terms, "brown")
|
||||
assert.Contains(t, terms, "fox")
|
||||
assert.Contains(t, terms, "jumps")
|
||||
assert.Contains(t, terms, "over")
|
||||
assert.Contains(t, terms, "lazy")
|
||||
assert.Contains(t, terms, "dog")
|
||||
|
||||
// Should NOT contain stop words
|
||||
assert.NotContains(t, terms, "the")
|
||||
|
||||
// Should NOT contain short words
|
||||
// (all words in the sentence are >= 3 chars after stop word removal)
|
||||
}
|
||||
|
||||
func TestClusterObservations_MoreThanOldLimit(t *testing.T) {
|
||||
// This test verifies that we can now return more than 5 observations
|
||||
// after removing the hardcoded limit
|
||||
|
||||
// Create 10 completely unique observations with very different content
|
||||
observations := []*models.Observation{
|
||||
{ID: 1, Title: sql.NullString{String: "JWT tokens expire daily", Valid: true}},
|
||||
{ID: 2, Title: sql.NullString{String: "PostgreSQL indexes optimize", Valid: true}},
|
||||
{ID: 3, Title: sql.NullString{String: "Redis caching TTL values", Valid: true}},
|
||||
{ID: 4, Title: sql.NullString{String: "Zerolog structured logging", Valid: true}},
|
||||
{ID: 5, Title: sql.NullString{String: "Pytest fixtures setup", Valid: true}},
|
||||
{ID: 6, Title: sql.NullString{String: "Docker containers orchestration", Valid: true}},
|
||||
{ID: 7, Title: sql.NullString{String: "Prometheus metrics collection", Valid: true}},
|
||||
{ID: 8, Title: sql.NullString{String: "OWASP vulnerability scanning", Valid: true}},
|
||||
{ID: 9, Title: sql.NullString{String: "Goroutines parallel execution", Valid: true}},
|
||||
{ID: 10, Title: sql.NullString{String: "Kubernetes horizontal scaling", Valid: true}},
|
||||
}
|
||||
|
||||
clustered := ClusterObservations(observations, 0.4)
|
||||
|
||||
// With unique content, all 10 should be kept (previously would have been capped at 5)
|
||||
assert.Len(t, clustered, 10, "Should return all 10 unique observations, not limited to 5")
|
||||
}
|
||||
|
||||
func TestClusterObservations_PreservesOrder(t *testing.T) {
|
||||
// The first observation in each cluster should be kept
|
||||
observations := []*models.Observation{
|
||||
{ID: 1, Title: sql.NullString{String: "First auth observation", Valid: true}},
|
||||
{ID: 2, Title: sql.NullString{String: "Second auth observation", Valid: true}},
|
||||
{ID: 3, Title: sql.NullString{String: "Database observation", Valid: true}},
|
||||
}
|
||||
|
||||
clustered := ClusterObservations(observations, 0.4)
|
||||
|
||||
// First observation should always be first in result
|
||||
require.NotEmpty(t, clustered)
|
||||
assert.Equal(t, int64(1), clustered[0].ID, "First observation should be kept as first result")
|
||||
}
|
||||
Reference in New Issue
Block a user