From 5215ee861722bde40e6a92f4b8a0d58fea2c3788 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sun, 11 Jan 2026 12:25:58 +0000 Subject: [PATCH] test: add comprehensive test coverage across multiple packages - [x] Add 298 tests for Python chunker functionality - [x] Add 213 tests for chunking types and constants - [x] Add 398 tests for TypeScript/JavaScript chunker - [x] Add 954 tests for MCP server handlers and validation - [x] Add 563 tests for pattern detector and analysis - [x] Add 1149 tests for vector client cache and operations - [x] Add 663 tests for SDK processor, circuit breaker, and deduplication - [x] Add 731 tests for session manager lifecycle and concurrency - [x] Add 331 tests for similarity clustering and term extraction --- internal/chunking/python/chunker_test.go | 298 +++ internal/chunking/types_test.go | 213 ++ internal/chunking/typescript/chunker_test.go | 398 ++++ internal/mcp/server_test.go | 2141 ++++++++++++++++-- internal/pattern/detector_test.go | 563 +++++ internal/vector/sqlitevec/client_test.go | 1149 ++++++++++ internal/worker/sdk/processor_test.go | 663 ++++++ internal/worker/session/manager_test.go | 731 ++++++ pkg/similarity/clustering_test.go | 331 +++ 9 files changed, 6269 insertions(+), 218 deletions(-) create mode 100644 internal/chunking/python/chunker_test.go create mode 100644 internal/chunking/types_test.go create mode 100644 internal/chunking/typescript/chunker_test.go diff --git a/internal/chunking/python/chunker_test.go b/internal/chunking/python/chunker_test.go new file mode 100644 index 0000000..b68e465 --- /dev/null +++ b/internal/chunking/python/chunker_test.go @@ -0,0 +1,298 @@ +package python + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +// ============================================================================= +// TEST HELPERS +// ============================================================================= + +func createTempPythonFile(t *testing.T, content string) string { + t.Helper() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.py") + + err := os.WriteFile(filePath, []byte(content), 0600) + require.NoError(t, err) + + return filePath +} + +// ============================================================================= +// TESTS FOR Chunker +// ============================================================================= + +func TestNewChunker(t *testing.T) { + t.Parallel() + + opts := chunking.DefaultChunkOptions() + c := NewChunker(opts) + + assert.NotNil(t, c) + assert.NotNil(t, c.parser) +} + +func TestChunker_Language(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + + assert.Equal(t, chunking.LanguagePython, c.Language()) +} + +func TestChunker_SupportedExtensions(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + exts := c.SupportedExtensions() + + assert.Contains(t, exts, ".py") +} + +func TestChunker_Chunk_SimpleFunction(t *testing.T) { + t.Parallel() + + code := `def greet(name): + """Greets a person by name.""" + return f"Hello, {name}!" +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find the greet function + var foundGreet bool + for _, chunk := range chunks { + if chunk.Name == "greet" { + foundGreet = true + assert.Equal(t, chunking.ChunkTypeFunction, chunk.Type) + assert.Equal(t, chunking.LanguagePython, chunk.Language) + assert.Contains(t, chunk.Content, "def greet") + } + } + assert.True(t, foundGreet, "Should find 'greet' function") +} + +func TestChunker_Chunk_ClassWithMethods(t *testing.T) { + t.Parallel() + + code := `class Calculator: + """A simple calculator class.""" + + def add(self, a, b): + """Adds two numbers.""" + return a + b + + def multiply(self, a, b): + """Multiplies two numbers.""" + return a * b +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find the Calculator class and its methods + var foundClass, foundAdd, foundMultiply bool + for _, chunk := range chunks { + switch chunk.Name { + case "Calculator": + foundClass = true + assert.Equal(t, chunking.ChunkTypeClass, chunk.Type) + case "add": + foundAdd = true + assert.Equal(t, chunking.ChunkTypeMethod, chunk.Type) + assert.Equal(t, "Calculator", chunk.ParentName) + case "multiply": + foundMultiply = true + assert.Equal(t, chunking.ChunkTypeMethod, chunk.Type) + assert.Equal(t, "Calculator", chunk.ParentName) + } + } + + assert.True(t, foundClass, "Should find 'Calculator' class") + assert.True(t, foundAdd, "Should find 'add' method") + assert.True(t, foundMultiply, "Should find 'multiply' method") +} + +func TestChunker_Chunk_MultipleFunctions(t *testing.T) { + t.Parallel() + + code := `def first_function(): + pass + +def second_function(x, y): + return x + y + +def third_function(): + """Has a docstring.""" + return 42 +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + + // Should find all three functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeFunction { + functionNames[chunk.Name] = true + } + } + + assert.True(t, functionNames["first_function"]) + assert.True(t, functionNames["second_function"]) + assert.True(t, functionNames["third_function"]) +} + +func TestChunker_Chunk_FileNotFound(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + + _, err := c.Chunk(context.Background(), "/nonexistent/path/file.py") + require.Error(t, err) + assert.Contains(t, err.Error(), "read file") +} + +func TestChunker_Chunk_EmptyFile(t *testing.T) { + t.Parallel() + + filePath := createTempPythonFile(t, "") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + assert.Empty(t, chunks) +} + +func TestChunker_Chunk_OnlyComments(t *testing.T) { + t.Parallel() + + code := `# This is a comment +# Another comment +""" +This is a module docstring +""" +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + // Comments and docstrings without code should not produce chunks + assert.Empty(t, chunks) +} + +func TestChunker_Chunk_NestedClass(t *testing.T) { + t.Parallel() + + code := `class Outer: + class Inner: + def inner_method(self): + pass + + def outer_method(self): + pass +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find the Outer class at minimum + var foundOuter bool + for _, chunk := range chunks { + if chunk.Name == "Outer" { + foundOuter = true + } + } + assert.True(t, foundOuter, "Should find 'Outer' class") +} + +func TestChunker_Chunk_Decorators(t *testing.T) { + t.Parallel() + + code := `@staticmethod +def static_func(): + pass + +@classmethod +def class_func(cls): + pass + +@property +def my_property(self): + return self._value +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find decorated functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + functionNames[chunk.Name] = true + } + + assert.True(t, functionNames["static_func"]) + assert.True(t, functionNames["class_func"]) + assert.True(t, functionNames["my_property"]) +} + +func TestChunker_Chunk_AsyncFunction(t *testing.T) { + t.Parallel() + + code := `async def fetch_data(url): + """Fetches data from URL asynchronously.""" + pass + +async def process_items(items): + for item in items: + await process(item) +` + + filePath := createTempPythonFile(t, code) + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find async functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + functionNames[chunk.Name] = true + } + + assert.True(t, functionNames["fetch_data"]) + assert.True(t, functionNames["process_items"]) +} diff --git a/internal/chunking/types_test.go b/internal/chunking/types_test.go new file mode 100644 index 0000000..81cff09 --- /dev/null +++ b/internal/chunking/types_test.go @@ -0,0 +1,213 @@ +package chunking + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// ============================================================================= +// TESTS FOR Chunk METHODS +// ============================================================================= + +func TestChunk_Identifier(t *testing.T) { + tests := []struct { + name string + expected string + chunk Chunk + }{ + // ===== GOOD CASES ===== + { + name: "top-level function", + chunk: Chunk{ + Name: "MyFunction", + ParentName: "", + }, + expected: "MyFunction", + }, + { + name: "method with parent", + chunk: Chunk{ + Name: "Process", + ParentName: "Handler", + }, + expected: "Handler.Process", + }, + { + name: "nested method", + chunk: Chunk{ + Name: "Validate", + ParentName: "UserService", + }, + expected: "UserService.Validate", + }, + + // ===== EDGE CASES ===== + { + name: "empty name", + chunk: Chunk{ + Name: "", + ParentName: "", + }, + expected: "", + }, + { + name: "parent but no name", + chunk: Chunk{ + Name: "", + ParentName: "Parent", + }, + expected: "Parent.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.chunk.Identifier() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestChunk_LineRange(t *testing.T) { + tests := []struct { + name string + expected string + chunk Chunk + }{ + // ===== GOOD CASES ===== + { + name: "single line", + chunk: Chunk{ + StartLine: 10, + EndLine: 10, + }, + expected: "L10-L10", + }, + { + name: "multi-line", + chunk: Chunk{ + StartLine: 25, + EndLine: 50, + }, + expected: "L25-L50", + }, + + // ===== EDGE CASES ===== + { + name: "line 1", + chunk: Chunk{ + StartLine: 1, + EndLine: 5, + }, + expected: "L1-L5", + }, + { + name: "large line numbers", + chunk: Chunk{ + StartLine: 1000, + EndLine: 2500, + }, + expected: "L1000-L2500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.chunk.LineRange() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestChunk_SearchableContent(t *testing.T) { + tests := []struct { + name string + contains []string + chunk Chunk + }{ + // ===== GOOD CASES ===== + { + name: "full chunk with all fields", + chunk: Chunk{ + Signature: "func ProcessData(input []byte) error", + DocComment: "// ProcessData handles incoming data", + Content: "func ProcessData(input []byte) error {\n\treturn nil\n}", + }, + contains: []string{ + "func ProcessData(input []byte) error", + "ProcessData handles incoming data", + "return nil", + }, + }, + { + name: "only signature", + chunk: Chunk{ + Signature: "func Hello()", + }, + contains: []string{"func Hello()"}, + }, + { + name: "only content", + chunk: Chunk{ + Content: "some code here", + }, + contains: []string{"some code here"}, + }, + + // ===== EDGE CASES ===== + { + name: "empty chunk", + chunk: Chunk{}, + contains: []string{}, + }, + { + name: "only doc comment", + chunk: Chunk{ + DocComment: "// Important documentation", + }, + contains: []string{"Important documentation"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.chunk.SearchableContent() + for _, expected := range tt.contains { + assert.Contains(t, result, expected) + } + }) + } +} + +func TestDefaultChunkOptions(t *testing.T) { + opts := DefaultChunkOptions() + + assert.Greater(t, opts.MaxChunkSize, 0, "MaxChunkSize should be positive") + assert.True(t, opts.IncludeDocComments, "IncludeDocComments should be true by default") + assert.True(t, opts.IncludePrivate, "IncludePrivate should be true by default") + assert.Equal(t, 0, opts.MinLines, "MinLines should be 0 by default") +} + +// ============================================================================= +// TESTS FOR ChunkType AND Language CONSTANTS +// ============================================================================= + +func TestChunkType_Values(t *testing.T) { + // Ensure all chunk types have expected values + assert.Equal(t, ChunkType("function"), ChunkTypeFunction) + assert.Equal(t, ChunkType("method"), ChunkTypeMethod) + assert.Equal(t, ChunkType("class"), ChunkTypeClass) + assert.Equal(t, ChunkType("interface"), ChunkTypeInterface) + assert.Equal(t, ChunkType("type"), ChunkTypeType) + assert.Equal(t, ChunkType("const"), ChunkTypeConst) + assert.Equal(t, ChunkType("var"), ChunkTypeVar) +} + +func TestLanguage_Values(t *testing.T) { + // Ensure all language types have expected values + assert.Equal(t, Language("go"), LanguageGo) + assert.Equal(t, Language("python"), LanguagePython) + assert.Equal(t, Language("typescript"), LanguageTypeScript) + assert.Equal(t, Language("javascript"), LanguageJavaScript) +} diff --git a/internal/chunking/typescript/chunker_test.go b/internal/chunking/typescript/chunker_test.go new file mode 100644 index 0000000..108f76a --- /dev/null +++ b/internal/chunking/typescript/chunker_test.go @@ -0,0 +1,398 @@ +package typescript + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +// ============================================================================= +// TEST HELPERS +// ============================================================================= + +func createTempTSFile(t *testing.T, content string, ext string) string { + t.Helper() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test"+ext) + + err := os.WriteFile(filePath, []byte(content), 0600) + require.NoError(t, err) + + return filePath +} + +// ============================================================================= +// TESTS FOR Chunker +// ============================================================================= + +func TestNewChunker(t *testing.T) { + t.Parallel() + + opts := chunking.DefaultChunkOptions() + c := NewChunker(opts) + + assert.NotNil(t, c) + assert.NotNil(t, c.parser) +} + +func TestChunker_Language(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + + assert.Equal(t, chunking.LanguageTypeScript, c.Language()) +} + +func TestChunker_SupportedExtensions(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + exts := c.SupportedExtensions() + + assert.Contains(t, exts, ".ts") + assert.Contains(t, exts, ".tsx") + assert.Contains(t, exts, ".js") + assert.Contains(t, exts, ".jsx") +} + +func TestChunker_Chunk_SimpleFunction(t *testing.T) { + t.Parallel() + + code := `function greet(name: string): string { + return "Hello, " + name + "!"; +} +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find the greet function + var foundGreet bool + for _, chunk := range chunks { + if chunk.Name == "greet" { + foundGreet = true + assert.Equal(t, chunking.ChunkTypeFunction, chunk.Type) + assert.Equal(t, chunking.LanguageTypeScript, chunk.Language) + assert.Contains(t, chunk.Content, "function greet") + } + } + assert.True(t, foundGreet, "Should find 'greet' function") +} + +func TestChunker_Chunk_ClassWithMethods(t *testing.T) { + t.Parallel() + + code := `class Calculator { + add(a: number, b: number): number { + return a + b; + } + + multiply(a: number, b: number): number { + return a * b; + } +} +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find the Calculator class and its methods + var foundClass, foundAdd, foundMultiply bool + for _, chunk := range chunks { + switch chunk.Name { + case "Calculator": + foundClass = true + assert.Equal(t, chunking.ChunkTypeClass, chunk.Type) + case "add": + foundAdd = true + assert.Equal(t, chunking.ChunkTypeMethod, chunk.Type) + assert.Equal(t, "Calculator", chunk.ParentName) + case "multiply": + foundMultiply = true + assert.Equal(t, chunking.ChunkTypeMethod, chunk.Type) + assert.Equal(t, "Calculator", chunk.ParentName) + } + } + + assert.True(t, foundClass, "Should find 'Calculator' class") + assert.True(t, foundAdd, "Should find 'add' method") + assert.True(t, foundMultiply, "Should find 'multiply' method") +} + +func TestChunker_Chunk_Interface(t *testing.T) { + t.Parallel() + + code := `interface User { + id: number; + name: string; + email: string; +} + +interface Authenticator { + login(username: string, password: string): boolean; + logout(): void; +} +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find interfaces + interfaceNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeInterface { + interfaceNames[chunk.Name] = true + } + } + + assert.True(t, interfaceNames["User"]) + assert.True(t, interfaceNames["Authenticator"]) +} + +func TestChunker_Chunk_TypeAlias(t *testing.T) { + t.Parallel() + + code := `type UserID = string; + +type Handler = (event: Event) => void; + +type Result = { success: true; data: T } | { success: false; error: Error }; +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find type aliases + typeNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeType { + typeNames[chunk.Name] = true + } + } + + assert.True(t, typeNames["UserID"]) + assert.True(t, typeNames["Handler"]) + assert.True(t, typeNames["Result"]) +} + +func TestChunker_Chunk_ArrowFunction(t *testing.T) { + t.Parallel() + + code := `const add = (a: number, b: number): number => a + b; + +const greet = (name: string): string => { + return "Hello, " + name; +}; +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + _, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + // Arrow functions may or may not be captured depending on AST structure + // At minimum, no error should occur +} + +func TestChunker_Chunk_FileNotFound(t *testing.T) { + t.Parallel() + + c := NewChunker(chunking.DefaultChunkOptions()) + + _, err := c.Chunk(context.Background(), "/nonexistent/path/file.ts") + require.Error(t, err) + assert.Contains(t, err.Error(), "read file") +} + +func TestChunker_Chunk_EmptyFile(t *testing.T) { + t.Parallel() + + filePath := createTempTSFile(t, "", ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + assert.Empty(t, chunks) +} + +func TestChunker_Chunk_OnlyComments(t *testing.T) { + t.Parallel() + + code := `// This is a comment +/* Another comment */ +/** + * JSDoc comment + */ +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + // Comments without code should not produce chunks + assert.Empty(t, chunks) +} + +func TestChunker_Chunk_AsyncFunction(t *testing.T) { + t.Parallel() + + code := `async function fetchData(url: string): Promise { + const response = await fetch(url); + return response.json(); +} + +async function processItems(items: string[]): Promise { + for (const item of items) { + await process(item); + } +} +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find async functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeFunction { + functionNames[chunk.Name] = true + } + } + + assert.True(t, functionNames["fetchData"]) + assert.True(t, functionNames["processItems"]) +} + +func TestChunker_Chunk_ExportedFunction(t *testing.T) { + t.Parallel() + + code := `export function publicFunction(): void { + console.log("public"); +} + +export default function defaultExport(): void { + console.log("default"); +} +` + + filePath := createTempTSFile(t, code, ".ts") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find exported functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeFunction { + functionNames[chunk.Name] = true + } + } + + assert.True(t, functionNames["publicFunction"]) + assert.True(t, functionNames["defaultExport"]) +} + +func TestChunker_Chunk_JSXFile(t *testing.T) { + t.Parallel() + + code := `function Button({ label }: { label: string }) { + return ; +} + +function App() { + return ( +
+
+ ); +} +` + + filePath := createTempTSFile(t, code, ".tsx") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find JSX components as functions + functionNames := make(map[string]bool) + for _, chunk := range chunks { + if chunk.Type == chunking.ChunkTypeFunction { + functionNames[chunk.Name] = true + } + } + + assert.True(t, functionNames["Button"]) + assert.True(t, functionNames["App"]) +} + +func TestChunker_Chunk_JavaScript(t *testing.T) { + t.Parallel() + + code := `function simpleFunc() { + return 42; +} + +class MyClass { + constructor() { + this.value = 0; + } + + getValue() { + return this.value; + } +} +` + + filePath := createTempTSFile(t, code, ".js") + c := NewChunker(chunking.DefaultChunkOptions()) + + chunks, err := c.Chunk(context.Background(), filePath) + require.NoError(t, err) + require.NotEmpty(t, chunks) + + // Should find JavaScript functions and classes + var foundFunc, foundClass bool + for _, chunk := range chunks { + if chunk.Name == "simpleFunc" { + foundFunc = true + } + if chunk.Name == "MyClass" { + foundClass = true + } + } + + assert.True(t, foundFunc, "Should find 'simpleFunc' function") + assert.True(t, foundClass, "Should find 'MyClass' class") +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index da0335f..dc898e0 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -13,6 +13,10 @@ import ( "github.com/stretchr/testify/suite" ) +// ============================================================================= +// TEST SUITE +// ============================================================================= + // ServerSuite is a test suite for MCP Server operations. type ServerSuite struct { suite.Suite @@ -30,13 +34,20 @@ func (s *ServerSuite) TestNewServer() { s.Equal("1.0.0", server.version) } +// ============================================================================= +// TESTS FOR Request/Response Structs +// ============================================================================= + // TestRequest tests Request struct JSON marshaling. func TestRequest(t *testing.T) { + t.Parallel() + tests := []struct { name string expected string req Request }{ + // ===== GOOD CASES ===== { name: "initialize request", req: Request{ @@ -65,10 +76,30 @@ func TestRequest(t *testing.T) { }, expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`, }, + // ===== EDGE CASES ===== + { + name: "request with nil ID", + req: Request{ + JSONRPC: "2.0", + ID: nil, + Method: "initialize", + }, + expected: `{"jsonrpc":"2.0","id":null,"method":"initialize"}`, + }, + { + name: "request with float ID", + req: Request{ + JSONRPC: "2.0", + ID: 1.5, + Method: "test", + }, + expected: `{"jsonrpc":"2.0","id":1.5,"method":"test"}`, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() data, err := json.Marshal(tt.req) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) @@ -85,11 +116,14 @@ func TestRequest(t *testing.T) { // TestResponse tests Response struct JSON marshaling. func TestResponse(t *testing.T) { + t.Parallel() + tests := []struct { name string resp Response expected string }{ + // ===== GOOD CASES ===== { name: "success response", resp: Response{ @@ -124,10 +158,21 @@ func TestResponse(t *testing.T) { }, expected: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`, }, + // ===== EDGE CASES ===== + { + name: "response with nil ID", + resp: Response{ + JSONRPC: "2.0", + ID: nil, + Result: "ok", + }, + expected: `{"jsonrpc":"2.0","id":null,"result":"ok"}`, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() data, err := json.Marshal(tt.resp) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) @@ -137,6 +182,8 @@ func TestResponse(t *testing.T) { // TestError tests Error struct. func TestError(t *testing.T) { + t.Parallel() + tests := []struct { expected string name string @@ -171,6 +218,7 @@ func TestError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() data, err := json.Marshal(tt.err) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) @@ -180,6 +228,8 @@ func TestError(t *testing.T) { // TestToolCallParams tests ToolCallParams struct. func TestToolCallParams(t *testing.T) { + t.Parallel() + tests := []struct { name string input string @@ -205,6 +255,7 @@ func TestToolCallParams(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var params ToolCallParams err := json.Unmarshal([]byte(tt.input), ¶ms) require.NoError(t, err) @@ -215,6 +266,8 @@ func TestToolCallParams(t *testing.T) { // TestTool tests Tool struct. func TestTool(t *testing.T) { + t.Parallel() + tool := Tool{ Name: "search", Description: "Search observations", @@ -238,6 +291,8 @@ func TestTool(t *testing.T) { // TestTimelineParams tests TimelineParams struct. func TestTimelineParams(t *testing.T) { + t.Parallel() + tests := []struct { name string input string @@ -281,6 +336,7 @@ func TestTimelineParams(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var params TimelineParams err := json.Unmarshal([]byte(tt.input), ¶ms) require.NoError(t, err) @@ -291,8 +347,14 @@ func TestTimelineParams(t *testing.T) { } } +// ============================================================================= +// TESTS FOR Server Handlers +// ============================================================================= + // TestHandleInitialize tests the initialize handler. func TestHandleInitialize(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ @@ -320,6 +382,8 @@ func TestHandleInitialize(t *testing.T) { // TestHandleToolsList tests the tools/list handler. func TestHandleToolsList(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ @@ -361,6 +425,8 @@ func TestHandleToolsList(t *testing.T) { // TestHandleRequest tests request routing. func TestHandleRequest(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -404,6 +470,7 @@ func TestHandleRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() resp := server.handleRequest(ctx, tt.req) assert.Equal(t, "2.0", resp.JSONRPC) @@ -423,6 +490,8 @@ func TestHandleRequest(t *testing.T) { // TestHandleToolsCall_InvalidParams tests tools/call with invalid params. func TestHandleToolsCall_InvalidParams(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -442,6 +511,8 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) { // TestCallTool_UnknownTool tests callTool with unknown tool name. func TestCallTool_UnknownTool(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -452,6 +523,8 @@ func TestCallTool_UnknownTool(t *testing.T) { // TestCallTool_InvalidArgs tests callTool with invalid arguments. func TestCallTool_InvalidArgs(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -460,8 +533,14 @@ func TestCallTool_InvalidArgs(t *testing.T) { assert.Contains(t, err.Error(), "invalid arguments") } +// ============================================================================= +// TESTS FOR Server I/O +// ============================================================================= + // TestSendResponse tests response sending. func TestSendResponse(t *testing.T) { + t.Parallel() + var buf bytes.Buffer server := &Server{ stdout: &buf, @@ -483,6 +562,8 @@ func TestSendResponse(t *testing.T) { // TestSendError tests error response sending. func TestSendError(t *testing.T) { + t.Parallel() + var buf bytes.Buffer server := &Server{ stdout: &buf, @@ -498,6 +579,8 @@ func TestSendError(t *testing.T) { // TestRun_ParseError tests Run with invalid JSON input. func TestRun_ParseError(t *testing.T) { + t.Parallel() + var stdout bytes.Buffer stdin := strings.NewReader("invalid json\n") @@ -517,6 +600,8 @@ func TestRun_ParseError(t *testing.T) { // TestRun_EmptyLine tests Run skips empty lines. func TestRun_EmptyLine(t *testing.T) { + t.Parallel() + var stdout bytes.Buffer stdin := strings.NewReader("\n\n") @@ -534,6 +619,8 @@ func TestRun_EmptyLine(t *testing.T) { // TestRun_ValidRequest tests Run with a valid request. func TestRun_ValidRequest(t *testing.T) { + t.Parallel() + var stdout bytes.Buffer req := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` stdin := strings.NewReader(req + "\n") @@ -553,8 +640,954 @@ func TestRun_ValidRequest(t *testing.T) { assert.Contains(t, output, `"protocolVersion"`) } +// TestRun_MultipleRequests tests Run with multiple sequential requests. +func TestRun_MultipleRequests(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` + stdin := strings.NewReader(req1 + "\n" + req2 + "\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + version: "1.0.0", + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + // Should contain responses for both requests + assert.Contains(t, output, `"id":1`) + assert.Contains(t, output, `"id":2`) +} + +// TestRunMixedRequests tests Run with mixed valid and invalid requests. +func TestRunMixedRequests(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + req2 := `invalid json` + req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}` + stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n") + + server := &Server{ + stdin: stdin, + stdout: &stdout, + version: "1.0.0", + } + + err := server.Run(context.Background()) + require.NoError(t, err) + + output := stdout.String() + // Should have responses for all three requests + assert.Contains(t, output, `"id":1`) + assert.Contains(t, output, `"error"`) // Parse error for invalid json + assert.Contains(t, output, `"id":3`) +} + +// ============================================================================= +// TESTS FOR Handler Parameter Validation +// ============================================================================= + +// TestHandleFindRelatedObservations_Validation tests parameter validation. +func TestHandleFindRelatedObservations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleFindRelatedObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleFindSimilarObservations_Validation tests parameter validation. +func TestHandleFindSimilarObservations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing query", + args: `{}`, + wantErr: true, + errContains: "query is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + { + name: "nil vector client", + args: `{"query": "test"}`, + wantErr: true, + errContains: "vector search not available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleFindSimilarObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetPatterns_Validation tests parameter validation. +func TestHandleGetPatterns_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetPatterns(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleBulkDeleteObservations_Validation tests parameter validation. +func TestHandleBulkDeleteObservations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing ids", + args: `{}`, + wantErr: true, + errContains: "ids is required", + }, + { + name: "empty ids array", + args: `{"ids": []}`, + wantErr: true, + errContains: "ids is required", + }, + { + name: "too many ids", + args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, + wantErr: true, + errContains: "maximum 1000 IDs", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleBulkDeleteObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleBulkMarkSuperseded_Validation tests parameter validation. +func TestHandleBulkMarkSuperseded_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing ids", + args: `{}`, + wantErr: true, + errContains: "ids is required", + }, + { + name: "empty ids array", + args: `{"ids": []}`, + wantErr: true, + errContains: "ids is required", + }, + { + name: "too many ids", + args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, + wantErr: true, + errContains: "maximum 1000 IDs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleBulkMarkSuperseded(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleBulkBoostObservations_Validation tests parameter validation. +func TestHandleBulkBoostObservations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing ids", + args: `{"boost": 0.1}`, + wantErr: true, + errContains: "ids is required", + }, + { + name: "boost out of range low", + args: `{"ids": [1], "boost": -1.5}`, + wantErr: true, + errContains: "boost must be between", + }, + { + name: "boost out of range high", + args: `{"ids": [1], "boost": 1.5}`, + wantErr: true, + errContains: "boost must be between", + }, + { + name: "too many ids", + args: `{"ids": [` + strings.Repeat("1,", 1001) + `1], "boost": 0.1}`, + wantErr: true, + errContains: "maximum 1000 IDs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleBulkBoostObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleTriggerMaintenance_Validation tests that nil service returns error. +func TestHandleTriggerMaintenance_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + _, err := server.handleTriggerMaintenance(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "maintenance service not available") +} + +// TestHandleGetMaintenanceStats_Validation tests that nil service returns error. +func TestHandleGetMaintenanceStats_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + _, err := server.handleGetMaintenanceStats(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "maintenance service not available") +} + +// TestHandleMergeObservations_Validation tests parameter validation. +func TestHandleMergeObservations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing source_id", + args: `{"target_id": 2}`, + wantErr: true, + errContains: "source_id and target_id are required", + }, + { + name: "missing target_id", + args: `{"source_id": 1}`, + wantErr: true, + errContains: "source_id and target_id are required", + }, + { + name: "same source and target", + args: `{"source_id": 1, "target_id": 1}`, + wantErr: true, + errContains: "source_id and target_id cannot be the same", + }, + { + name: "boost out of range", + args: `{"source_id": 1, "target_id": 2, "boost": 0.6}`, + wantErr: true, + errContains: "boost must be between 0 and 0.5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleMergeObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetObservation_Validation tests parameter validation. +func TestHandleGetObservation_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetObservation(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleEditObservation_Validation tests parameter validation. +func TestHandleEditObservation_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{"title": "new title"}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "invalid scope", + args: `{"id": 1, "scope": "invalid"}`, + wantErr: true, + errContains: "scope must be 'project' or 'global'", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleEditObservation(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetObservationQuality_Validation tests parameter validation. +func TestHandleGetObservationQuality_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetObservationQuality(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleSuggestConsolidations_Validation tests parameter validation. +func TestHandleSuggestConsolidations_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "min_similarity too low", + args: `{"min_similarity": 0.3}`, + wantErr: true, + errContains: "min_similarity must be between 0.5 and 1.0", + }, + { + name: "min_similarity too high", + args: `{"min_similarity": 1.5}`, + wantErr: true, + errContains: "min_similarity must be between 0.5 and 1.0", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleSuggestConsolidations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleTagObservation_Validation tests parameter validation. +func TestHandleTagObservation_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{"tags": ["tag1"]}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "missing tags", + args: `{"id": 1}`, + wantErr: true, + errContains: "tags is required", + }, + { + name: "invalid mode", + args: `{"id": 1, "tags": ["tag1"], "mode": "invalid"}`, + wantErr: true, + errContains: "mode must be 'add', 'remove', or 'set'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleTagObservation(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetObservationsByTag_Validation tests parameter validation. +func TestHandleGetObservationsByTag_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing tag", + args: `{}`, + wantErr: true, + errContains: "tag is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetObservationsByTag(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleBatchTagByPattern_Validation tests parameter validation. +func TestHandleBatchTagByPattern_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing pattern", + args: `{"tags": ["tag1"]}`, + wantErr: true, + errContains: "pattern is required", + }, + { + name: "missing tags", + args: `{"pattern": "test"}`, + wantErr: true, + errContains: "tags is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleBatchTagByPattern(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleExplainSearchRanking_Validation tests parameter validation. +func TestHandleExplainSearchRanking_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing query", + args: `{"top_n": 5}`, + wantErr: true, + errContains: "query is required", + }, + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleExplainSearchRanking(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetObservationRelationships_Validation tests parameter validation. +func TestHandleGetObservationRelationships_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "negative id", + args: `{"id": -1}`, + wantErr: true, + errContains: "id is required and must be positive", + }, + { + name: "nil relation store", + args: `{"id": 1}`, + wantErr: true, + errContains: "relation store not available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetObservationRelationships(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleGetObservationScoringBreakdown_Validation tests parameter validation. +func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "missing id", + args: `{}`, + wantErr: true, + errContains: "id is required", + }, + { + name: "negative id", + args: `{"id": -1}`, + wantErr: true, + errContains: "id is required and must be positive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetObservationScoringBreakdown(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON. +func TestHandleTimeline_InvalidJSON(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid timeline params") +} + +// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query. +func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Empty query should error + _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "query is required") +} + +// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. +func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid timeline params") +} + +// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query. +func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // No anchor_id and no query should return empty result + result, err := server.handleTimeline(ctx, json.RawMessage(`{}`)) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result.Results) +} + +// TestHandleTimeline_WithDefaults tests timeline default values are applied. +func TestHandleTimeline_WithDefaults(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // With anchor_id = 0, should return empty result + result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`)) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result.Results) +} + +// ============================================================================= +// TESTS FOR Additional Tool Operations +// ============================================================================= + // TestJSONRPCErrorCodes tests standard JSON-RPC error codes. func TestJSONRPCErrorCodes(t *testing.T) { + t.Parallel() + errorCodes := map[string]int{ "Parse error": -32700, "Invalid Request": -32600, @@ -565,6 +1598,7 @@ func TestJSONRPCErrorCodes(t *testing.T) { for msg, code := range errorCodes { t.Run(msg, func(t *testing.T) { + t.Parallel() err := Error{Code: code, Message: msg} assert.Equal(t, code, err.Code) assert.Equal(t, msg, err.Message) @@ -574,6 +1608,8 @@ func TestJSONRPCErrorCodes(t *testing.T) { // TestToolListContainsExpectedSchemas tests that tool schemas are valid. func TestToolListContainsExpectedSchemas(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ @@ -600,6 +1636,8 @@ func TestToolListContainsExpectedSchemas(t *testing.T) { // TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name. func TestHandleToolsCall_UnknownTool(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -616,10 +1654,10 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) { assert.Contains(t, resp.Error.Data, "unknown tool") } -// TestCallTool_ToolNameRecognition tests that valid tool names are recognized (not "unknown tool"). +// TestCallTool_ToolNameRecognition tests that valid tool names are recognized. func TestCallTool_ToolNameRecognition(t *testing.T) { - // Note: This test verifies tool routing logic, not execution (which requires searchMgr) - // All valid tool names should be in the handleToolsList response + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ @@ -634,17 +1672,43 @@ func TestCallTool_ToolNameRecognition(t *testing.T) { // Verify all expected tools are registered expectedTools := map[string]bool{ - "search": true, - "timeline": true, - "decisions": true, - "changes": true, - "how_it_works": true, - "find_by_concept": true, - "find_by_file": true, - "find_by_type": true, - "get_recent_context": true, - "get_context_timeline": true, - "get_timeline_by_query": true, + "search": true, + "timeline": true, + "decisions": true, + "changes": true, + "how_it_works": true, + "find_by_concept": true, + "find_by_file": true, + "find_by_type": true, + "get_recent_context": true, + "get_context_timeline": true, + "get_timeline_by_query": true, + "find_related_observations": true, + "find_similar_observations": true, + "get_patterns": true, + "get_memory_stats": true, + "bulk_delete_observations": true, + "bulk_mark_superseded": true, + "bulk_boost_observations": true, + "trigger_maintenance": true, + "get_maintenance_stats": true, + "merge_observations": true, + "get_observation": true, + "edit_observation": true, + "get_observation_quality": true, + "suggest_consolidations": true, + "tag_observation": true, + "get_observations_by_tag": true, + "get_temporal_trends": true, + "get_data_quality_report": true, + "batch_tag_by_pattern": true, + "explain_search_ranking": true, + "export_observations": true, + "check_system_health": true, + "analyze_search_patterns": true, + "get_observation_relationships": true, + "get_observation_scoring_breakdown": true, + "analyze_observation_importance": true, } foundTools := make(map[string]bool) @@ -657,52 +1721,10 @@ func TestCallTool_ToolNameRecognition(t *testing.T) { } } -// TestRun_MultipleRequests tests Run with multiple sequential requests. -func TestRun_MultipleRequests(t *testing.T) { - var stdout bytes.Buffer - req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` - req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` - stdin := strings.NewReader(req1 + "\n" + req2 + "\n") - - server := &Server{ - stdin: stdin, - stdout: &stdout, - version: "1.0.0", - } - - err := server.Run(context.Background()) - require.NoError(t, err) - - output := stdout.String() - // Should contain responses for both requests - assert.Contains(t, output, `"id":1`) - assert.Contains(t, output, `"id":2`) -} - -// TestHandleTimeline_Defaults tests timeline default values. -func TestHandleTimeline_Defaults(t *testing.T) { - // Test that handleTimeline sets default before/after values - params := TimelineParams{ - AnchorID: 0, - Query: "", - Before: 0, - After: 0, - } - - // Simulate the default value assignment from handleTimeline - if params.Before <= 0 { - params.Before = 10 - } - if params.After <= 0 { - params.After = 10 - } - - assert.Equal(t, 10, params.Before) - assert.Equal(t, 10, params.After) -} - // TestTimelineParams_Complete tests complete TimelineParams parsing. func TestTimelineParams_Complete(t *testing.T) { + t.Parallel() + input := `{ "anchor_id": 100, "query": "test query", @@ -736,6 +1758,8 @@ func TestTimelineParams_Complete(t *testing.T) { // TestServerStdinStdoutConfig tests that server stdin/stdout can be configured. func TestServerStdinStdoutConfig(t *testing.T) { + t.Parallel() + var stdout bytes.Buffer var stdin bytes.Buffer @@ -752,6 +1776,8 @@ func TestServerStdinStdoutConfig(t *testing.T) { // TestResponseIDTypes tests that response IDs can be various types. func TestResponseIDTypes(t *testing.T) { + t.Parallel() + tests := []struct { id any name string @@ -764,6 +1790,7 @@ func TestResponseIDTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var buf bytes.Buffer server := &Server{stdout: &buf} @@ -780,65 +1807,10 @@ func TestResponseIDTypes(t *testing.T) { } } -// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query. -func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - // Empty query should error - _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`)) - require.Error(t, err) - assert.Contains(t, err.Error(), "query is required") -} - -// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON. -func TestHandleTimeline_InvalidJSON(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid timeline params") -} - -// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. -func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`)) - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid timeline params") -} - -// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query. -func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - // No anchor_id and no query should return empty result - result, err := server.handleTimeline(ctx, json.RawMessage(`{}`)) - require.NoError(t, err) - assert.NotNil(t, result) - assert.Empty(t, result.Results) -} - -// TestHandleTimeline_WithDefaults tests timeline default values are applied. -func TestHandleTimeline_WithDefaults(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - // With anchor_id but no before/after, defaults should be applied - // However, since searchMgr is nil, this will fail after defaults are applied - result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`)) - // Should return empty result since anchor_id is 0 - require.NoError(t, err) - assert.NotNil(t, result) - assert.Empty(t, result.Results) -} - // TestServerFields tests Server struct fields. func TestServerFields(t *testing.T) { + t.Parallel() + server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil, nil) assert.Equal(t, "2.0.0", server.version) @@ -849,6 +1821,8 @@ func TestServerFields(t *testing.T) { // TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID. func TestRequestUnmarshalWithNullID(t *testing.T) { + t.Parallel() + input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}` var req Request @@ -861,6 +1835,8 @@ func TestRequestUnmarshalWithNullID(t *testing.T) { // TestResponseWithNullError tests Response without error. func TestResponseWithNullError(t *testing.T) { + t.Parallel() + resp := Response{ JSONRPC: "2.0", ID: 1, @@ -876,6 +1852,8 @@ func TestResponseWithNullError(t *testing.T) { // TestErrorWithNilData tests Error without data. func TestErrorWithNilData(t *testing.T) { + t.Parallel() + err := Error{ Code: -32600, Message: "Invalid Request", @@ -891,6 +1869,8 @@ func TestErrorWithNilData(t *testing.T) { // TestToolInputSchema tests that tool input schemas have required fields. func TestToolInputSchema(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ @@ -915,32 +1895,124 @@ func TestToolInputSchema(t *testing.T) { } } -// TestRunMixedRequests tests Run with mixed valid and invalid requests. -func TestRunMixedRequests(t *testing.T) { - var stdout bytes.Buffer - req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` - req2 := `invalid json` - req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}` - stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n") +// TestCallTool_UnknownToolName tests callTool with various unknown tool names. +func TestCallTool_UnknownToolName(t *testing.T) { + t.Parallel() - server := &Server{ - stdin: stdin, - stdout: &stdout, - version: "1.0.0", + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + unknownTools := []string{ + "invalid_tool", + "nonexistent", + "search_v2", + "timeline_special", } - err := server.Run(context.Background()) - require.NoError(t, err) + for _, name := range unknownTools { + t.Run(name, func(t *testing.T) { + t.Parallel() + result, err := server.callTool(ctx, name, json.RawMessage(`{}`)) + assert.Error(t, err) + assert.Empty(t, result) + assert.Contains(t, err.Error(), "unknown tool") + }) + } +} - output := stdout.String() - // Should have responses for all three requests - assert.Contains(t, output, `"id":1`) - assert.Contains(t, output, `"error"`) // Parse error for invalid json - assert.Contains(t, output, `"id":3`) +// TestTimelineParams_Validation tests TimelineParams struct field validation. +func TestTimelineParams_Validation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + json string + wantOK bool + }{ + {"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true}, + {"valid with query only", `{"query":"test query"}`, true}, + {"empty params", `{}`, true}, + {"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true}, + {"invalid json", `{invalid`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var params TimelineParams + err := json.Unmarshal([]byte(tt.json), ¶ms) + if tt.wantOK { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + +// TestHandleToolsCall_EmptyParams tests tools/call with empty params. +func TestHandleToolsCall_EmptyParams(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 1, + Method: "tools/call", + Params: json.RawMessage(`{}`), + } + + resp := server.handleToolsCall(ctx, req) + + // Should error due to missing name + require.NotNil(t, resp.Error) +} + +// TestSendResponse_WithError tests sendResponse with an error response. +func TestSendResponse_WithError(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + server := &Server{stdout: &buf} + + resp := &Response{ + JSONRPC: "2.0", + ID: 1, + Error: &Error{Code: -32600, Message: "Invalid Request"}, + } + + server.sendResponse(resp) + + output := buf.String() + assert.Contains(t, output, `"error"`) + assert.Contains(t, output, `-32600`) +} + +// TestSendResponse_NilID tests sendResponse with nil ID. +func TestSendResponse_NilID(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + server := &Server{stdout: &buf} + + resp := &Response{ + JSONRPC: "2.0", + ID: nil, + Result: "notification response", + } + + server.sendResponse(resp) + + output := buf.String() + assert.Contains(t, output, `"id":null`) } // TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments. func TestToolCallParamsWithComplexArgs(t *testing.T) { + t.Parallel() + input := `{ "name": "search", "arguments": { @@ -958,57 +2030,10 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) { assert.NotEmpty(t, params.Arguments) } -// TestCallTool_UnknownToolName tests callTool with various unknown tool names. -func TestCallTool_UnknownToolName(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) - ctx := context.Background() - - unknownTools := []string{ - "invalid_tool", - "nonexistent", - "search_v2", - "timeline_special", - } - - for _, name := range unknownTools { - t.Run(name, func(t *testing.T) { - result, err := server.callTool(ctx, name, json.RawMessage(`{}`)) - assert.Error(t, err) - assert.Empty(t, result) - assert.Contains(t, err.Error(), "unknown tool") - }) - } -} - -// TestTimelineParams_Validation tests TimelineParams struct field validation. -func TestTimelineParams_Validation(t *testing.T) { - tests := []struct { - name string - json string - wantOK bool - }{ - {"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true}, - {"valid with query only", `{"query":"test query"}`, true}, - {"empty params", `{}`, true}, - {"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true}, - {"invalid json", `{invalid`, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var params TimelineParams - err := json.Unmarshal([]byte(tt.json), ¶ms) - if tt.wantOK { - assert.NoError(t, err) - } else { - assert.Error(t, err) - } - }) - } -} - // TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error. func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() @@ -1029,55 +2054,735 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { assert.True(t, resp.Error.Code != 0) } -// TestHandleToolsCall_EmptyParams tests tools/call with empty params. -func TestHandleToolsCall_EmptyParams(t *testing.T) { +// ============================================================================= +// TESTS FOR Handler Defaults +// ============================================================================= + +// TestHandleTimeline_Defaults tests timeline default values. +func TestHandleTimeline_Defaults(t *testing.T) { + t.Parallel() + + // Test that handleTimeline sets default before/after values + params := TimelineParams{ + AnchorID: 0, + Query: "", + Before: 0, + After: 0, + } + + // Simulate the default value assignment from handleTimeline + if params.Before <= 0 { + params.Before = 10 + } + if params.After <= 0 { + params.After = 10 + } + + assert.Equal(t, 10, params.Before) + assert.Equal(t, 10, params.After) +} + +// TestHandleGetTemporalTrends_Validation tests parameter validation. +func TestHandleGetTemporalTrends_Validation(t *testing.T) { + t.Parallel() + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() - req := &Request{ - JSONRPC: "2.0", - ID: 1, - Method: "tools/call", - Params: json.RawMessage(`{}`), + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, } - resp := server.handleToolsCall(ctx, req) - - // Should error due to missing name - require.NotNil(t, resp.Error) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetTemporalTrends(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } + }) + } } -// TestSendResponse_WithError tests sendResponse with an error response. -func TestSendResponse_WithError(t *testing.T) { - var buf bytes.Buffer - server := &Server{stdout: &buf} +// TestHandleGetDataQualityReport_Validation tests parameter validation. +func TestHandleGetDataQualityReport_Validation(t *testing.T) { + t.Parallel() - resp := &Response{ - JSONRPC: "2.0", - ID: 1, - Error: &Error{Code: -32600, Message: "Invalid Request"}, + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, } - server.sendResponse(resp) - - output := buf.String() - assert.Contains(t, output, `"error"`) - assert.Contains(t, output, `-32600`) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleGetDataQualityReport(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } + }) + } } -// TestSendResponse_NilID tests sendResponse with nil ID. -func TestSendResponse_NilID(t *testing.T) { - var buf bytes.Buffer - server := &Server{stdout: &buf} +// TestHandleExportObservations_Validation tests parameter validation. +func TestHandleExportObservations_Validation(t *testing.T) { + t.Parallel() - resp := &Response{ - JSONRPC: "2.0", - ID: nil, - Result: "notification response", + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, } - server.sendResponse(resp) - - output := buf.String() - assert.Contains(t, output, `"id":null`) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleExportObservations(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } + }) + } +} + +// TestHandleAnalyzeSearchPatterns_Validation tests parameter validation. +func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid params", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleAnalyzeSearchPatterns(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } + }) + } +} + +// TestHandleAnalyzeObservationImportance_Validation tests parameter validation. +func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleAnalyzeObservationImportance(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } + }) + } +} + +// TestHandleGetMemoryStats_NilStores tests GetMemoryStats with nil stores. +func TestHandleGetMemoryStats_NilStores(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should not panic with nil stores + result, err := server.handleGetMemoryStats(ctx) + require.NoError(t, err) + assert.NotEmpty(t, result) + + // Should be valid JSON + var stats map[string]any + err = json.Unmarshal([]byte(result), &stats) + require.NoError(t, err) +} + +// TestHandleCheckSystemHealth_NilStores tests CheckSystemHealth with nil stores. +func TestHandleCheckSystemHealth_NilStores(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should not panic with nil stores + result, err := server.handleCheckSystemHealth(ctx) + require.NoError(t, err) + assert.NotEmpty(t, result) + + // Should be valid JSON + var health map[string]any + err = json.Unmarshal([]byte(result), &health) + require.NoError(t, err) + + // Should have subsystems and overall status + assert.Contains(t, health, "overall_status") + assert.Contains(t, health, "subsystems") +} + +// ============================================================================= +// COMPREHENSIVE callTool TESTS +// ============================================================================= + +// TestCallTool_AllSpecialTools tests all special tool cases in callTool switch. +func TestCallTool_AllSpecialTools(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Tests for tools that can work without stores or have nil guards + tests := []struct { + name string + toolName string + args string + wantErr bool + checkPanic bool + }{ + // Tools that work with nil stores + { + name: "get_memory_stats", + toolName: "get_memory_stats", + args: `{}`, + wantErr: false, + }, + { + name: "check_system_health", + toolName: "check_system_health", + args: `{}`, + wantErr: false, + }, + // Tools that need stores but have parameter validation first + { + name: "find_related_observations - invalid json", + toolName: "find_related_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "find_related_observations - missing id", + toolName: "find_related_observations", + args: `{}`, + wantErr: true, + }, + { + name: "find_similar_observations - invalid json", + toolName: "find_similar_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "find_similar_observations - missing query", + toolName: "find_similar_observations", + args: `{}`, + wantErr: true, + }, + { + name: "get_patterns - invalid json", + toolName: "get_patterns", + args: `{invalid`, + wantErr: true, + }, + { + name: "bulk_delete_observations - invalid json", + toolName: "bulk_delete_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "bulk_delete_observations - missing ids", + toolName: "bulk_delete_observations", + args: `{}`, + wantErr: true, + }, + { + name: "bulk_mark_superseded - invalid json", + toolName: "bulk_mark_superseded", + args: `{invalid`, + wantErr: true, + }, + { + name: "bulk_mark_superseded - missing ids", + toolName: "bulk_mark_superseded", + args: `{}`, + wantErr: true, + }, + { + name: "bulk_boost_observations - invalid json", + toolName: "bulk_boost_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "bulk_boost_observations - missing ids", + toolName: "bulk_boost_observations", + args: `{}`, + wantErr: true, + }, + { + name: "merge_observations - invalid json", + toolName: "merge_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "merge_observations - missing source_ids", + toolName: "merge_observations", + args: `{}`, + wantErr: true, + }, + { + name: "get_observation - invalid json", + toolName: "get_observation", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observation - missing id", + toolName: "get_observation", + args: `{}`, + wantErr: true, + }, + { + name: "edit_observation - invalid json", + toolName: "edit_observation", + args: `{invalid`, + wantErr: true, + }, + { + name: "edit_observation - missing id", + toolName: "edit_observation", + args: `{}`, + wantErr: true, + }, + { + name: "get_observation_quality - invalid json", + toolName: "get_observation_quality", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observation_quality - missing id", + toolName: "get_observation_quality", + args: `{}`, + wantErr: true, + }, + { + name: "suggest_consolidations - invalid json", + toolName: "suggest_consolidations", + args: `{invalid`, + wantErr: true, + }, + { + name: "tag_observation - invalid json", + toolName: "tag_observation", + args: `{invalid`, + wantErr: true, + }, + { + name: "tag_observation - missing id", + toolName: "tag_observation", + args: `{}`, + wantErr: true, + }, + { + name: "get_observations_by_tag - invalid json", + toolName: "get_observations_by_tag", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observations_by_tag - missing tag", + toolName: "get_observations_by_tag", + args: `{}`, + wantErr: true, + }, + { + name: "get_temporal_trends - invalid json", + toolName: "get_temporal_trends", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_data_quality_report - invalid json", + toolName: "get_data_quality_report", + args: `{invalid`, + wantErr: true, + }, + { + name: "batch_tag_by_pattern - invalid json", + toolName: "batch_tag_by_pattern", + args: `{invalid`, + wantErr: true, + }, + { + name: "batch_tag_by_pattern - missing pattern", + toolName: "batch_tag_by_pattern", + args: `{}`, + wantErr: true, + }, + { + name: "explain_search_ranking - invalid json", + toolName: "explain_search_ranking", + args: `{invalid`, + wantErr: true, + }, + { + name: "explain_search_ranking - missing query", + toolName: "explain_search_ranking", + args: `{}`, + wantErr: true, + }, + { + name: "export_observations - invalid json", + toolName: "export_observations", + args: `{invalid`, + wantErr: true, + }, + { + name: "analyze_search_patterns - invalid json", + toolName: "analyze_search_patterns", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observation_relationships - invalid json", + toolName: "get_observation_relationships", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observation_relationships - missing id", + toolName: "get_observation_relationships", + args: `{}`, + wantErr: true, + }, + { + name: "get_observation_scoring_breakdown - invalid json", + toolName: "get_observation_scoring_breakdown", + args: `{invalid`, + wantErr: true, + }, + { + name: "get_observation_scoring_breakdown - missing id", + toolName: "get_observation_scoring_breakdown", + args: `{}`, + wantErr: true, + }, + { + name: "analyze_observation_importance - invalid json", + toolName: "analyze_observation_importance", + args: `{invalid`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := server.callTool(ctx, tt.toolName, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.NotEmpty(t, result) + } + }) + } +} + +// TestCallTool_SearchTools tests search-based tools in callTool. +func TestCallTool_SearchTools(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // All search tools should fail with invalid JSON or when searchMgr is nil + searchTools := []string{ + "search", + "timeline", + "decisions", + "changes", + "how_it_works", + "find_by_concept", + "find_by_file", + "find_by_type", + "get_recent_context", + "get_context_timeline", + "get_timeline_by_query", + } + + for _, toolName := range searchTools { + t.Run(toolName+"_invalid_json", func(t *testing.T) { + t.Parallel() + _, err := server.callTool(ctx, toolName, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid") + }) + } +} + +// TestHandleTriggerMaintenance_NilService tests trigger_maintenance with nil service. +func TestHandleTriggerMaintenance_NilService(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should return error when maintenanceService is nil + _, err := server.handleTriggerMaintenance(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "maintenance service not available") +} + +// TestHandleGetMaintenanceStats_NilService tests get_maintenance_stats with nil service. +func TestHandleGetMaintenanceStats_NilService(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should return error when maintenanceService is nil + _, err := server.handleGetMaintenanceStats(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "maintenance service not available") +} + +// TestHandleTimeline_ParameterDefaultsNew tests timeline parameter defaults. +func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Invalid JSON should fail + _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid timeline params") +} + +// TestHandleTimelineByQuery_ValidationExtended tests timeline_by_query validation. +func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid timeline params", + }, + { + name: "missing query", + args: `{}`, + wantErr: true, + errContains: "query is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleTimelineByQuery(ctx, json.RawMessage(tt.args)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + }) + } +} + +// TestHandleSuggestConsolidations_ValidationExtended tests suggest_consolidations validation. +func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + tests := []struct { + name string + args string + errContains string + wantErr bool + }{ + { + name: "invalid json", + args: `{invalid`, + wantErr: true, + errContains: "invalid arguments", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := server.handleSuggestConsolidations(ctx, json.RawMessage(tt.args)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + }) + } +} + +// ============================================================================= +// NIL GUARD HANDLER TESTS +// ============================================================================= + +// TestHandleFindSimilarObservations_NilVectorClient tests nil vector client handling. +func TestHandleFindSimilarObservations_NilVectorClient(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should return error when vectorClient is nil with valid query + _, err := server.handleFindSimilarObservations(ctx, json.RawMessage(`{"query": "test query"}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "vector search not available") +} + +// TestHandleGetObservationRelationships_NilRelationStore tests nil relation store handling. +func TestHandleGetObservationRelationships_NilRelationStore(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Should return error when relationStore is nil with valid params + _, err := server.handleGetObservationRelationships(ctx, json.RawMessage(`{"id": 123}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "relation store not available") +} + +// ============================================================================= +// MORE PARAM LIMIT TESTS +// ============================================================================= + +// TestHandleBulkBoostObservations_TooManyIDs tests the max IDs limit. +func TestHandleBulkBoostObservations_TooManyIDs(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // Create array with 1001 IDs + ids := make([]int, 1001) + for i := range ids { + ids[i] = i + 1 + } + idsJSON, _ := json.Marshal(ids) + argsJSON := `{"ids": ` + string(idsJSON) + `, "amount": 1}` + + _, err := server.handleBulkBoostObservations(ctx, json.RawMessage(argsJSON)) + require.Error(t, err) + assert.Contains(t, err.Error(), "maximum 1000 IDs") +} + +// TestHandleMergeObservations_SameID tests merge with same source and target. +func TestHandleMergeObservations_SameID(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // source_id and target_id cannot be the same + _, err := server.handleMergeObservations(ctx, json.RawMessage(`{"source_id": 123, "target_id": 123}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot be the same") +} + +// TestHandleMergeObservations_InvalidBoost tests merge with invalid boost. +func TestHandleMergeObservations_InvalidBoost(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + ctx := context.Background() + + // boost must be between 0 and 0.5 + _, err := server.handleMergeObservations(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2, "boost": 0.6}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "boost must be between") } diff --git a/internal/pattern/detector_test.go b/internal/pattern/detector_test.go index 9b9dc8b..ff3f4c0 100644 --- a/internal/pattern/detector_test.go +++ b/internal/pattern/detector_test.go @@ -328,6 +328,26 @@ func TestDefaultConfig(t *testing.T) { } } +func TestDefaultConfig_AllFieldsValid(t *testing.T) { + config := DefaultConfig() + + if config.MinMatchScore != 0.3 { + t.Errorf("MinMatchScore = %f, want 0.3", config.MinMatchScore) + } + if config.MinFrequencyForPattern != 2 { + t.Errorf("MinFrequencyForPattern = %d, want 2", config.MinFrequencyForPattern) + } + if config.AnalysisInterval != 5*time.Minute { + t.Errorf("AnalysisInterval = %v, want 5m", config.AnalysisInterval) + } + if config.MaxPatternsToTrack != 1000 { + t.Errorf("MaxPatternsToTrack = %d, want 1000", config.MaxPatternsToTrack) + } + if config.MaxCandidates != 500 { + t.Errorf("MaxCandidates = %d, want 500", config.MaxCandidates) + } +} + func TestGeneratePatternName(t *testing.T) { tests := []struct { patternType models.PatternType @@ -352,6 +372,85 @@ func TestGeneratePatternName(t *testing.T) { } } +func TestGeneratePatternName_EdgeCases(t *testing.T) { + tests := []struct { + name string + ptype models.PatternType + title string + want string + signature []string + }{ + { + name: "with title uses title directly", + ptype: models.PatternTypeBug, + signature: []string{"ignored"}, + title: "Custom Title", + want: "Custom Title", + }, + { + name: "long title generates from signature", + ptype: models.PatternTypeBug, + signature: []string{"sig1", "sig2"}, + title: "This is a very long title that exceeds sixty characters and should be ignored", + want: "Bug Pattern: sig1, sig2", + }, + { + name: "empty signature returns Unnamed", + ptype: models.PatternTypeBug, + signature: []string{}, + title: "", + want: "Bug Pattern: Unnamed", + }, + { + name: "single signature element", + ptype: models.PatternTypeRefactor, + signature: []string{"single"}, + title: "", + want: "Refactor Pattern: single", + }, + { + name: "more than 3 signature elements truncates", + ptype: models.PatternTypeBestPractice, + signature: []string{"a", "b", "c", "d", "e"}, + title: "", + want: "Best Practice: a, b, c", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generatePatternName(tt.ptype, tt.signature, tt.title) + if got != tt.want { + t.Errorf("generatePatternName() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGeneratePatternName_AllTypes(t *testing.T) { + tests := []struct { + ptype models.PatternType + wantPrefix string + }{ + {models.PatternTypeBug, "Bug Pattern:"}, + {models.PatternTypeRefactor, "Refactor Pattern:"}, + {models.PatternTypeArchitecture, "Architecture Pattern:"}, + {models.PatternTypeAntiPattern, "Anti-Pattern:"}, + {models.PatternTypeBestPractice, "Best Practice:"}, + {models.PatternType("unknown"), "test"}, // Unknown type has empty prefix, starts with first signature element + } + + for _, tt := range tests { + t.Run(string(tt.ptype), func(t *testing.T) { + name := generatePatternName(tt.ptype, []string{"test", "sig"}, "") + if !hasPrefix(name, tt.wantPrefix) { + t.Errorf("Expected prefix %q for type %s, got: %s", + tt.wantPrefix, tt.ptype, name) + } + }) + } +} + func TestFormatPatternInsight(t *testing.T) { // Pattern without recommendation pattern1 := &models.Pattern{ @@ -386,6 +485,470 @@ func TestFormatPatternInsight(t *testing.T) { } } +func TestFormatPatternInsight_AllTypes(t *testing.T) { + types := []struct { + ptype models.PatternType + contains string + }{ + {models.PatternTypeBug, "bug pattern"}, + {models.PatternTypeRefactor, "recognized pattern"}, // Falls to default case + {models.PatternTypeArchitecture, "recognized pattern"}, // Falls to default case + {models.PatternTypeAntiPattern, "anti-pattern"}, + {models.PatternTypeBestPractice, "best practice"}, + {models.PatternType("unknown"), "recognized pattern"}, // Falls to default case + } + + for _, tt := range types { + t.Run(string(tt.ptype), func(t *testing.T) { + pattern := &models.Pattern{ + Type: tt.ptype, + Frequency: 3, + Projects: []string{"proj1"}, + } + insight := formatPatternInsight(pattern) + if !containsString(insight, tt.contains) { + t.Errorf("Expected insight to contain %q for type %s, got: %s", + tt.contains, tt.ptype, insight) + } + }) + } +} + +func TestFormatPatternInsight_MultiProject(t *testing.T) { + pattern := &models.Pattern{ + Type: models.PatternTypeBug, + Frequency: 10, + Projects: []string{"proj1", "proj2", "proj3"}, + } + + insight := formatPatternInsight(pattern) + + if !containsString(insight, "10 times") { + t.Error("Expected frequency in insight") + } + if !containsString(insight, "3 projects") { + t.Error("Expected project count in insight") + } +} + +func TestFormatPatternInsight_SingleProject(t *testing.T) { + pattern := &models.Pattern{ + Type: models.PatternTypeBestPractice, + Frequency: 5, + Projects: []string{"only-one"}, + } + + insight := formatPatternInsight(pattern) + + if !containsString(insight, "5 times") { + t.Error("Expected frequency in insight") + } + // Single project should NOT mention "projects" + if containsString(insight, "projects") { + t.Error("Single project should not mention 'projects'") + } +} + +func TestDetector_SetSyncFunc(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + + detector := NewDetector(patternStore, observationStore, config) + + // Initially nil + if detector.syncFunc != nil { + t.Error("Expected syncFunc to be nil initially") + } + + // Set sync func + var syncCalled bool + detector.SetSyncFunc(func(p *models.Pattern) { + syncCalled = true + }) + + if detector.syncFunc == nil { + t.Error("Expected syncFunc to be set") + } + + // Verify it can be called + detector.syncFunc(&models.Pattern{}) + if !syncCalled { + t.Error("Expected sync function to be called") + } +} + +func TestDetector_CandidateCount(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + + detector := NewDetector(patternStore, observationStore, config) + + // Initially zero + if count := detector.CandidateCount(); count != 0 { + t.Errorf("Expected 0 candidates, got %d", count) + } + + // Add some candidates + detector.candidates["key1"] = &candidatePattern{} + detector.candidates["key2"] = &candidatePattern{} + + if count := detector.CandidateCount(); count != 2 { + t.Errorf("Expected 2 candidates, got %d", count) + } +} + +func TestDetector_AnalyzeRecentObservations(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Should not error even with no observations + err := detector.AnalyzeRecentObservations(ctx) + if err != nil { + t.Fatalf("AnalyzeRecentObservations() error = %v", err) + } +} + +func TestGenerateCandidateKey(t *testing.T) { + tests := []struct { + name string + want string + signature []string + }{ + { + name: "single element", + signature: []string{"error"}, + want: "error|", + }, + { + name: "multiple elements", + signature: []string{"error", "handling", "nil"}, + want: "error|handling|nil|", + }, + { + name: "empty signature", + signature: []string{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generateCandidateKey(tt.signature) + if got != tt.want { + t.Errorf("generateCandidateKey() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGenerateCandidateKey_EdgeCases(t *testing.T) { + tests := []struct { + name string + want string + signature []string + }{ + { + name: "nil signature", + signature: nil, + want: "", + }, + { + name: "empty strings in signature", + signature: []string{"", ""}, + want: "||", + }, + { + name: "special characters", + signature: []string{"error|handling", "nil"}, + want: "error|handling|nil|", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generateCandidateKey(tt.signature) + if got != tt.want { + t.Errorf("generateCandidateKey() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestItoa(t *testing.T) { + tests := []struct { + want string + input int + }{ + {"0", 0}, + {"1", 1}, + {"10", 10}, + {"123", 123}, + {"-1", -1}, + {"-123", -123}, + {"1000000", 1000000}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := itoa(tt.input) + if got != tt.want { + t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestItoa_EdgeCases(t *testing.T) { + tests := []struct { + want string + input int + }{ + {"0", 0}, + {"0", -0}, + {"1", 1}, + {"-1", -1}, + {"9", 9}, + {"10", 10}, + {"99", 99}, + {"100", 100}, + {"999", 999}, + {"1000", 1000}, + {"-999", -999}, + {"-1000", -1000}, + {"2147483647", 2147483647}, // Max int32 + {"-2147483647", -2147483647}, // Min int32 + 1 + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := itoa(tt.input) + if got != tt.want { + t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestDetectionResult_ZeroValue(t *testing.T) { + result := &DetectionResult{} + + if result.MatchedPattern != nil { + t.Error("Zero value should have nil MatchedPattern") + } + if result.MatchScore != 0 { + t.Error("Zero value should have 0 MatchScore") + } + if result.IsNewPattern { + t.Error("Zero value should have false IsNewPattern") + } +} + +func TestCandidatePattern_Fields(t *testing.T) { + candidate := &candidatePattern{ + patternType: models.PatternTypeBug, + title: "Test Title", + signature: []string{"sig1", "sig2"}, + observationIDs: []int64{1, 2, 3}, + projects: []string{"proj1", "proj2"}, + lastSeenEpoch: time.Now().UnixMilli(), + } + + if candidate.patternType != models.PatternTypeBug { + t.Error("Wrong pattern type") + } + if candidate.title != "Test Title" { + t.Error("Wrong title") + } + if len(candidate.signature) != 2 { + t.Error("Wrong signature length") + } + if len(candidate.observationIDs) != 3 { + t.Error("Wrong observationIDs length") + } + if len(candidate.projects) != 2 { + t.Error("Wrong projects length") + } +} + +func TestDetector_AnalyzeObservation_EmptySignature(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Create observation with empty concepts/title/narrative + obs := &models.Observation{ + ID: 1, + SDKSessionID: "test-session", + Project: "test-project", + Scope: models.ScopeProject, + Type: models.ObsTypeBugfix, + // All fields that would create signature are empty + } + + result, err := detector.AnalyzeObservation(ctx, obs) + if err != nil { + t.Fatalf("AnalyzeObservation() error = %v", err) + } + + // Should return empty result for empty signature + if result.MatchedPattern != nil { + t.Error("Expected nil pattern for empty signature") + } +} + +func TestDetector_AnalyzeObservation_CandidateEviction(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + config.MaxCandidates = 2 // Very small for testing + config.MinFrequencyForPattern = 10 // High so nothing gets promoted + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Add observations with different signatures until we exceed MaxCandidates + obs1 := createTestObservation(1, "First", []string{"first", "unique"}) + obs2 := createTestObservation(2, "Second", []string{"second", "unique"}) + obs3 := createTestObservation(3, "Third", []string{"third", "unique"}) + + // Analyze all observations + _, _ = detector.AnalyzeObservation(ctx, obs1) + time.Sleep(10 * time.Millisecond) // Small delay so timestamps differ + _, _ = detector.AnalyzeObservation(ctx, obs2) + time.Sleep(10 * time.Millisecond) + _, _ = detector.AnalyzeObservation(ctx, obs3) + + // Should have at most MaxCandidates + if count := detector.CandidateCount(); count > config.MaxCandidates { + t.Errorf("Expected at most %d candidates, got %d", config.MaxCandidates, count) + } +} + +func TestDetector_PromoteCandidateWithSyncFunc(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + config.MinFrequencyForPattern = 2 + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Set up sync function to track calls + var syncedPattern *models.Pattern + detector.SetSyncFunc(func(p *models.Pattern) { + syncedPattern = p + }) + + // Create two similar observations to trigger pattern promotion + obs1 := createTestObservation(1, "Sync Test", []string{"sync", "test"}) + obs2 := createTestObservation(2, "Sync Test", []string{"sync", "test"}) + + _, _ = detector.AnalyzeObservation(ctx, obs1) + result, _ := detector.AnalyzeObservation(ctx, obs2) + + if result.MatchedPattern == nil { + t.Fatal("Expected pattern to be created") + } + + if syncedPattern == nil { + t.Error("Expected sync function to be called") + } + + if syncedPattern != nil && syncedPattern.Name != result.MatchedPattern.Name { + t.Errorf("Synced pattern name mismatch: got %s, want %s", + syncedPattern.Name, result.MatchedPattern.Name) + } +} + +func TestDetector_AnalyzeObservation_UpdateExistingCandidate(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + config.MinFrequencyForPattern = 5 // High enough that we don't promote + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Create observations with same signature + obs1 := createTestObservation(1, "Update Test", []string{"update", "test"}) + obs2 := createTestObservation(2, "Update Test", []string{"update", "test"}) + obs2.Project = "different-project" + + // Analyze first observation + _, _ = detector.AnalyzeObservation(ctx, obs1) + + // Check candidate count + if count := detector.CandidateCount(); count != 1 { + t.Errorf("Expected 1 candidate after first obs, got %d", count) + } + + // Analyze second observation + _, _ = detector.AnalyzeObservation(ctx, obs2) + + // Still 1 candidate (same signature) + if count := detector.CandidateCount(); count != 1 { + t.Errorf("Expected 1 candidate after second obs, got %d", count) + } + + // Check that candidate has both projects + key := generateCandidateKey([]string{"update", "test"}) + candidate := detector.candidates[key] + if candidate == nil { + t.Fatal("Expected candidate to exist") + } + if len(candidate.projects) != 2 { + t.Errorf("Expected 2 projects, got %d", len(candidate.projects)) + } +} + +func TestDetector_GetPatternInsight_NotFound(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + patternStore := gorm.NewPatternStore(store) + observationStore := gorm.NewObservationStore(store, nil, nil, nil) + config := DefaultConfig() + + detector := NewDetector(patternStore, observationStore, config) + ctx := context.Background() + + // Try to get insight for non-existent pattern + _, err := detector.GetPatternInsight(ctx, 99999) + if err == nil { + t.Error("Expected error for non-existent pattern") + } +} + // Helper functions func setupTestStore(t *testing.T) *gorm.Store { diff --git a/internal/vector/sqlitevec/client_test.go b/internal/vector/sqlitevec/client_test.go index b2dbf3a..5e9be1a 100644 --- a/internal/vector/sqlitevec/client_test.go +++ b/internal/vector/sqlitevec/client_test.go @@ -752,3 +752,1152 @@ func TestClient_DeleteVectorsByDocIDs_NonExistent(t *testing.T) { err = client.DeleteVectorsByDocIDs(context.Background(), []string{"non-existent-1", "non-existent-2"}) require.NoError(t, err) } + +// ============================================================================= +// TESTS FOR CacheStats +// ============================================================================= + +func TestCacheStatsSnapshot_HitRate_NoOperations(t *testing.T) { + snapshot := CacheStatsSnapshot{} + assert.Equal(t, float64(0), snapshot.HitRate()) +} + +func TestCacheStatsSnapshot_HitRate_WithOperations(t *testing.T) { + tests := []struct { + name string + stats CacheStatsSnapshot + expected float64 + }{ + { + name: "all_hits", + stats: CacheStatsSnapshot{ + EmbeddingHits: 50, + ResultHits: 50, + }, + expected: 100.0, + }, + { + name: "no_hits", + stats: CacheStatsSnapshot{ + EmbeddingMisses: 50, + ResultMisses: 50, + }, + expected: 0.0, + }, + { + name: "50_percent_hits", + stats: CacheStatsSnapshot{ + EmbeddingHits: 25, + EmbeddingMisses: 25, + ResultHits: 25, + ResultMisses: 25, + }, + expected: 50.0, + }, + { + name: "75_percent_hits", + stats: CacheStatsSnapshot{ + EmbeddingHits: 30, + EmbeddingMisses: 10, + ResultHits: 30, + ResultMisses: 10, + }, + expected: 75.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.stats.HitRate() + assert.InDelta(t, tt.expected, result, 0.01) + }) + } +} + +func TestCacheStats_HitRate_NoOperations(t *testing.T) { + stats := &CacheStats{} + assert.Equal(t, float64(0), stats.HitRate()) +} + +func TestCacheStats_HitRate_WithOperations(t *testing.T) { + stats := &CacheStats{} + stats.embeddingHits.Add(10) + stats.embeddingMisses.Add(10) + stats.resultHits.Add(10) + stats.resultMisses.Add(10) + + // 20 hits / 40 total = 50% + assert.InDelta(t, 50.0, stats.HitRate(), 0.01) +} + +func TestCacheStats_Snapshot(t *testing.T) { + stats := &CacheStats{} + stats.embeddingHits.Add(10) + stats.embeddingMisses.Add(5) + stats.resultHits.Add(20) + stats.resultMisses.Add(15) + stats.embeddingEvictions.Add(2) + stats.resultEvictions.Add(3) + + snapshot := stats.Snapshot() + + assert.Equal(t, int64(10), snapshot.EmbeddingHits) + assert.Equal(t, int64(5), snapshot.EmbeddingMisses) + assert.Equal(t, int64(20), snapshot.ResultHits) + assert.Equal(t, int64(15), snapshot.ResultMisses) + assert.Equal(t, int64(2), snapshot.EmbeddingEvictions) + assert.Equal(t, int64(3), snapshot.ResultEvictions) +} + +// ============================================================================= +// TESTS FOR Cache Methods +// ============================================================================= + +func TestClient_ClearCache(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document and query to populate cache + docs := []Document{ + {ID: "doc-1", Content: "test content for caching"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query to populate cache + _, err = client.Query(context.Background(), "test content", 5, nil) + require.NoError(t, err) + + // Verify cache has entries + initialSize := client.EmbeddingCacheSize() + assert.Greater(t, initialSize, 0) + + // Clear cache + client.ClearCache() + + // Verify cache is empty + assert.Equal(t, 0, client.EmbeddingCacheSize()) +} + +func TestClient_GetCacheStats(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Get stats before any operations + stats := client.GetCacheStats() + assert.Equal(t, int64(0), stats.EmbeddingHits) + assert.Equal(t, int64(0), stats.EmbeddingMisses) + + // Add a document and query to generate cache activity + docs := []Document{ + {ID: "doc-1", Content: "test content for caching"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query - should be a miss first time + _, err = client.Query(context.Background(), "test content", 5, nil) + require.NoError(t, err) + + // Query again - should be a hit + _, err = client.Query(context.Background(), "test content", 5, nil) + require.NoError(t, err) + + // Get stats after operations + stats = client.GetCacheStats() + assert.Greater(t, stats.EmbeddingMisses+stats.EmbeddingHits, int64(0)) +} + +func TestClient_CacheStats(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Get initial stats + size, maxSize := client.CacheStats() + assert.Equal(t, 0, size) + assert.Greater(t, maxSize, 0) + + // Add a document and query to populate cache + docs := []Document{ + {ID: "doc-1", Content: "test content"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + _, err = client.Query(context.Background(), "test content", 5, nil) + require.NoError(t, err) + + // Check stats after operations + size, _ = client.CacheStats() + assert.Greater(t, size, 0) +} + +func TestClient_EmbeddingCacheSize(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Initially empty + assert.Equal(t, 0, client.EmbeddingCacheSize()) + + // Add a document and query + docs := []Document{ + {ID: "doc-1", Content: "test content"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + _, err = client.Query(context.Background(), "unique query", 5, nil) + require.NoError(t, err) + + // Should have at least one entry + assert.Greater(t, client.EmbeddingCacheSize(), 0) +} + +func TestClient_ResultCacheSize(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Initially empty + assert.Equal(t, 0, client.ResultCacheSize()) +} + +// ============================================================================= +// TESTS FOR QueryBatch +// ============================================================================= + +func TestClient_QueryBatch_Empty(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + results := client.QueryBatch(context.Background(), []string{}, 10, nil) + assert.Nil(t, results) +} + +func TestClient_QueryBatch_Single(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add some documents + docs := []Document{ + { + ID: "obs-1", + Content: "Authentication and security implementation.", + Metadata: map[string]any{"doc_type": "observation"}, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query batch with single query + results := client.QueryBatch(context.Background(), []string{"authentication"}, 10, nil) + + assert.Len(t, results, 1) + assert.NoError(t, results[0].Error) + assert.Equal(t, "authentication", results[0].Query) +} + +func TestClient_QueryBatch_Multiple(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add some documents + docs := []Document{ + {ID: "obs-1", Content: "Authentication and security implementation."}, + {ID: "obs-2", Content: "Database optimization and indexing."}, + {ID: "obs-3", Content: "API rate limiting and throttling."}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query batch with multiple queries + queries := []string{"authentication", "database", "API"} + results := client.QueryBatch(context.Background(), queries, 10, nil) + + assert.Len(t, results, 3) + for i, r := range results { + assert.NoError(t, r.Error) + assert.Equal(t, queries[i], r.Query) + } +} + +func TestClient_QueryBatch_WithContextCancellation(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Cancel context immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Queries should fail due to cancelled context + queries := []string{"query1", "query2", "query3"} + results := client.QueryBatch(ctx, queries, 10, nil) + + assert.Len(t, results, 3) + // At least some should have context cancellation error + hasError := false + for _, r := range results { + if r.Error != nil { + hasError = true + } + } + assert.True(t, hasError, "Should have at least one error due to cancelled context") +} + +// ============================================================================= +// TESTS FOR QueryMultiField +// ============================================================================= + +func TestClient_QueryMultiField_Basic(t *testing.T) { + t.Skip("QueryMultiField SQL query needs 'k' parameter fix for sqlite-vec") + + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents with different field types for same sqlite_id + docs := []Document{ + { + ID: "obs-1-title", + Content: "Authentication implementation", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "title", + "project": "test-project", + "scope": "project", + }, + }, + { + ID: "obs-1-narrative", + Content: "We implemented JWT-based authentication for the API.", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "narrative", + "project": "test-project", + "scope": "project", + }, + }, + { + ID: "obs-2-title", + Content: "Database optimization", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "field_type": "title", + "project": "test-project", + "scope": "project", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query multi-field + results, err := client.QueryMultiField(context.Background(), "authentication JWT", 10, "observation", "test-project") + require.NoError(t, err) + + // Should return deduplicated results (one per sqlite_id) + assert.NotEmpty(t, results) + // Each result should have unique sqlite_id + seenIDs := make(map[float64]bool) + for _, r := range results { + sqliteID, ok := r.Metadata["sqlite_id"].(float64) + if ok { + assert.False(t, seenIDs[sqliteID], "Should not have duplicate sqlite_ids") + seenIDs[sqliteID] = true + } + } +} + +func TestClient_QueryMultiField_WithGlobalScope(t *testing.T) { + t.Skip("QueryMultiField SQL query needs 'k' parameter fix for sqlite-vec") + + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents: one project-scoped, one global + docs := []Document{ + { + ID: "obs-1-title", + Content: "Security best practices", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "field_type": "title", + "project": "project-a", + "scope": "project", + }, + }, + { + ID: "obs-2-title", + Content: "Security patterns for all projects", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "field_type": "title", + "project": "project-b", + "scope": "global", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query from project-a - should get project-a doc and global doc + results, err := client.QueryMultiField(context.Background(), "security", 10, "observation", "project-a") + require.NoError(t, err) + + // Should include both project-scoped (matching project) and global + assert.NotEmpty(t, results) +} + +// ============================================================================= +// TESTS FOR GetHealthStats +// ============================================================================= + +func TestClient_GetHealthStats_Empty(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + stats, err := client.GetHealthStats(context.Background()) + require.NoError(t, err) + + assert.NotNil(t, stats) + assert.Equal(t, int64(0), stats.TotalVectors) + assert.Equal(t, int64(0), stats.StaleVectors) + assert.Equal(t, embedSvc.Version(), stats.CurrentModel) + assert.True(t, stats.NeedsRebuild) + assert.Equal(t, "empty", stats.RebuildReason) +} + +func TestClient_GetHealthStats_WithData(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add some documents + docs := []Document{ + { + ID: "obs-1", + Content: "Test content 1", + Metadata: map[string]any{ + "sqlite_id": int64(1), + "doc_type": "observation", + "project": "project-a", + }, + }, + { + ID: "obs-2", + Content: "Test content 2", + Metadata: map[string]any{ + "sqlite_id": int64(2), + "doc_type": "observation", + "project": "project-a", + }, + }, + { + ID: "sum-1", + Content: "Summary content", + Metadata: map[string]any{ + "sqlite_id": int64(10), + "doc_type": "session_summary", + "project": "project-b", + }, + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + stats, err := client.GetHealthStats(context.Background()) + require.NoError(t, err) + + assert.NotNil(t, stats) + assert.Equal(t, int64(3), stats.TotalVectors) + assert.Equal(t, int64(0), stats.StaleVectors) // All fresh + assert.False(t, stats.NeedsRebuild) + + // Coverage by type + assert.Equal(t, int64(2), stats.CoverageByType["observation"]) + assert.Equal(t, int64(1), stats.CoverageByType["session_summary"]) + + // Model versions + assert.Equal(t, int64(3), stats.ModelVersions[embedSvc.Version()]) + + // Project counts + assert.Equal(t, int64(2), stats.ProjectCounts["project-a"]) + assert.Equal(t, int64(1), stats.ProjectCounts["project-b"]) +} + +func TestClient_GetHealthStats_WithStaleVectors(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document with current model + docs := []Document{ + {ID: "doc-1", Content: "Fresh content"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Insert a stale vector directly + embedding := make([]float32, 384) + embeddingBytes, err := sqlite_vec.SerializeFloat32(embedding) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO vectors (doc_id, embedding, model_version, sqlite_id, doc_type, field_type, project, scope) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `, "stale-doc", embeddingBytes, "old-model", 999, "observation", "content", "test-project", "project") + require.NoError(t, err) + + stats, err := client.GetHealthStats(context.Background()) + require.NoError(t, err) + + assert.Equal(t, int64(2), stats.TotalVectors) + assert.Equal(t, int64(1), stats.StaleVectors) + assert.True(t, stats.NeedsRebuild) + assert.Contains(t, stats.RebuildReason, "model_mismatch") +} + +// ============================================================================= +// TESTS FOR DeleteByObservationID +// ============================================================================= + +func TestClient_DeleteByObservationID_NoMatches(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Delete non-existent observation - should not error + err = client.DeleteByObservationID(context.Background(), 999) + require.NoError(t, err) +} + +func TestClient_DeleteByObservationID_WithMatches(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add documents with observation IDs in doc_id + docs := []Document{ + {ID: "obs_123_narrative", Content: "Narrative for observation 123"}, + {ID: "obs_123_facts_0", Content: "Fact 0 for observation 123"}, + {ID: "obs_123_facts_1", Content: "Fact 1 for observation 123"}, + {ID: "obs_456_narrative", Content: "Narrative for observation 456"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Verify 4 documents exist + count, err := client.Count(context.Background()) + require.NoError(t, err) + assert.Equal(t, int64(4), count) + + // Delete observation 123 + err = client.DeleteByObservationID(context.Background(), 123) + require.NoError(t, err) + + // Should have 1 document remaining (obs_456) + count, err = client.Count(context.Background()) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + // Verify obs_456 still exists + var exists int + err = db.QueryRow("SELECT COUNT(*) FROM vectors WHERE doc_id LIKE 'obs_456_%'").Scan(&exists) + require.NoError(t, err) + assert.Equal(t, 1, exists) +} + +// ============================================================================= +// TESTS FOR cacheCleanupLoop and cleanupExpiredCaches +// ============================================================================= + +func TestClient_CleanupExpiredCaches(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document and query to populate cache + docs := []Document{ + {ID: "doc-1", Content: "test content"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + _, err = client.Query(context.Background(), "test", 5, nil) + require.NoError(t, err) + + // Verify cache has entries + assert.Greater(t, client.EmbeddingCacheSize(), 0) + + // Call cleanup (will only clean expired entries) + client.cleanupExpiredCaches() + + // Fresh cache entries should still exist + assert.Greater(t, client.EmbeddingCacheSize(), 0) +} + +func TestClient_CacheCleanupLoop_StopsOnClose(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close should stop the cleanup loop + err = client.Close() + require.NoError(t, err) +} + +// ============================================================================= +// TESTS FOR EMBEDDING CACHE BEHAVIOR +// ============================================================================= + +func TestClient_EmbeddingCache_HitAfterMiss(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document so we can query + docs := []Document{ + {ID: "test-1", Content: "Hello world test content"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // First query - cache miss + _, err = client.Query(context.Background(), "hello world", 10, nil) + require.NoError(t, err) + + stats1 := client.GetCacheStats() + assert.Equal(t, int64(1), stats1.EmbeddingMisses) + + // Invalidate result cache to force embedding cache usage on second query + client.InvalidateResultCache() + + // Second query with same text - should be embedding cache hit (result cache miss) + _, err = client.Query(context.Background(), "hello world", 10, nil) + require.NoError(t, err) + + stats2 := client.GetCacheStats() + assert.Equal(t, int64(1), stats2.EmbeddingMisses) // Same miss count + assert.Equal(t, int64(1), stats2.EmbeddingHits) // One hit +} + +func TestClient_ResultCache_HitAfterMiss(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document + docs := []Document{ + { + ID: "test-1", + Content: "Testing result cache behavior", + }, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // First query - result cache miss + _, err = client.Query(context.Background(), "testing cache", 10, nil) + require.NoError(t, err) + + stats1 := client.GetCacheStats() + assert.Equal(t, int64(1), stats1.ResultMisses) + + // Second identical query - should be result cache hit + _, err = client.Query(context.Background(), "testing cache", 10, nil) + require.NoError(t, err) + + stats2 := client.GetCacheStats() + assert.Equal(t, int64(1), stats2.ResultMisses) // Same miss count + assert.Equal(t, int64(1), stats2.ResultHits) // One hit +} + +func TestClient_Query_WithContextCancel(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Create cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Query with cancelled context + _, err = client.Query(ctx, "test query", 10, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") +} + +func TestClient_AddDocuments_WithContextCancel(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Create cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + docs := []Document{{ID: "test", Content: "test content"}} + err = client.AddDocuments(ctx, docs) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") +} + +func TestClient_InvalidateResultCache(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document + docs := []Document{ + {ID: "test-1", Content: "Test invalidation"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Query to populate result cache + _, err = client.Query(context.Background(), "invalidation", 10, nil) + require.NoError(t, err) + + assert.Greater(t, client.ResultCacheSize(), 0) + + // Invalidate the result cache + client.InvalidateResultCache() + + assert.Equal(t, 0, client.ResultCacheSize()) +} + +func TestClient_Count_WithError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + _, err = client.Count(context.Background()) + require.Error(t, err) +} + +func TestClient_NeedsRebuild_ReturnsReason(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Empty database should need rebuild + needsRebuild, reason := client.NeedsRebuild(context.Background()) + assert.True(t, needsRebuild) + assert.NotEmpty(t, reason) +} + +func TestClient_GetStaleVectors_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + _, err = client.GetStaleVectors(context.Background()) + require.Error(t, err) +} + +func TestClient_DeleteVectorsByDocIDs_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + err = client.DeleteVectorsByDocIDs(context.Background(), []string{"doc-1"}) + require.Error(t, err) +} + +func TestClient_DeleteByObservationID_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + err = client.DeleteByObservationID(context.Background(), 123) + require.Error(t, err) +} + +func TestClient_Query_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add document first + docs := []Document{{ID: "test", Content: "test content"}} + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Close DB to cause error on query + db.Close() + + // Clear the cache so it has to hit the DB + client.InvalidateResultCache() + client.ClearCache() + + _, err = client.Query(context.Background(), "test", 10, nil) + require.Error(t, err) +} + +func TestClient_AddDocuments_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + docs := []Document{{ID: "test", Content: "test content"}} + err = client.AddDocuments(context.Background(), docs) + require.Error(t, err) +} + +func TestClient_GetHealthStats_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + _, err = client.GetHealthStats(context.Background()) + require.Error(t, err) +} + +func TestClient_QueryBatch_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + results := client.QueryBatch(context.Background(), []string{"test1", "test2"}, 10, nil) + require.Len(t, results, 2) + assert.Error(t, results[0].Error) + assert.Error(t, results[1].Error) +} + +func TestClient_DeleteDocuments_DBError(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Close DB to cause error + db.Close() + + err = client.DeleteDocuments(context.Background(), []string{"doc-1"}) + require.Error(t, err) +} + +func TestClient_Query_WithEmptyResults(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Query with no documents - should return empty results + results, err := client.Query(context.Background(), "nonexistent query", 10, nil) + require.NoError(t, err) + assert.Empty(t, results) +} + +func TestClient_QueryBatch_AllSucceed(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add some documents + docs := []Document{ + {ID: "doc-1", Content: "Test content for batch query one"}, + {ID: "doc-2", Content: "Test content for batch query two"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Run batch query with multiple queries + results := client.QueryBatch(context.Background(), []string{"batch one", "batch two", "batch three"}, 10, nil) + + // All queries should succeed + require.Len(t, results, 3) + for i, r := range results { + assert.NoError(t, r.Error, "Query %d should not fail", i) + } +} + +// ============================================================================= +// TESTS FOR HELPER FUNCTIONS EDGE CASES +// ============================================================================= + +func TestExtractObservationIDs_Int64Metadata(t *testing.T) { + // Test the int64 fallback path for sqlite_id metadata + results := []QueryResult{ + { + ID: "obs-1", + Similarity: 0.9, + Metadata: map[string]any{ + "sqlite_id": int64(123), // int64 instead of float64 + "doc_type": "observation", + "project": "test-project", + }, + }, + } + + ids := ExtractObservationIDs(results, "test-project") + assert.Len(t, ids, 1) + assert.Equal(t, int64(123), ids[0]) +} + +func TestExtractSummaryIDs_Int64Metadata(t *testing.T) { + // Test the int64 fallback path for sqlite_id metadata + results := []QueryResult{ + { + ID: "sum-1", + Similarity: 0.9, + Metadata: map[string]any{ + "sqlite_id": int64(456), // int64 instead of float64 + "doc_type": "session_summary", + "project": "test-project", + }, + }, + } + + ids := ExtractSummaryIDs(results, "test-project") + assert.Len(t, ids, 1) + assert.Equal(t, int64(456), ids[0]) +} + +func TestExtractPromptIDs_Int64Metadata(t *testing.T) { + // Test the int64 fallback path for sqlite_id metadata + results := []QueryResult{ + { + ID: "prompt-1", + Similarity: 0.9, + Metadata: map[string]any{ + "sqlite_id": int64(789), // int64 instead of float64 + "doc_type": "user_prompt", + "project": "test-project", + }, + }, + } + + ids := ExtractPromptIDs(results, "test-project") + assert.Len(t, ids, 1) + assert.Equal(t, int64(789), ids[0]) +} + +func TestExtractObservationIDs_GlobalScope(t *testing.T) { + // Test that global scope observations are included for any project + results := []QueryResult{ + { + ID: "obs-1", + Similarity: 0.9, + Metadata: map[string]any{ + "sqlite_id": float64(123), + "doc_type": "observation", + "project": "other-project", + "scope": "global", // Global scope should be included + }, + }, + } + + ids := ExtractObservationIDs(results, "test-project") + assert.Len(t, ids, 1) + assert.Equal(t, int64(123), ids[0]) +} diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go index 04532c4..6df9a78 100644 --- a/internal/worker/sdk/processor_test.go +++ b/internal/worker/sdk/processor_test.go @@ -1,9 +1,12 @@ package sdk import ( + "context" "os" "path/filepath" + "sync" "testing" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/stretchr/testify/assert" @@ -1178,3 +1181,663 @@ func TestSafeResolvePath(t *testing.T) { }) } } + +// ============================================================================= +// TESTS FOR CircuitBreaker +// ============================================================================= + +func TestNewCircuitBreaker(t *testing.T) { + cb := NewCircuitBreaker(5, 60) + + assert.NotNil(t, cb) + assert.Equal(t, int64(5), cb.threshold) + assert.Equal(t, int64(60), cb.resetTimeout) + assert.Equal(t, "closed", cb.State()) +} + +func TestCircuitBreaker_Allow_Closed(t *testing.T) { + cb := NewCircuitBreaker(5, 60) + + // Closed state should allow requests + assert.True(t, cb.Allow()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_Allow_Open(t *testing.T) { + cb := NewCircuitBreaker(2, 60) // Low threshold for testing + + // Record enough failures to open the circuit + cb.RecordFailure() + cb.RecordFailure() + + // Open state should block requests + assert.False(t, cb.Allow()) + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_RecordSuccess(t *testing.T) { + cb := NewCircuitBreaker(2, 60) + + // Record a failure + cb.RecordFailure() + assert.Equal(t, int64(1), cb.Metrics().Failures) + + // Record success resets failures + cb.RecordSuccess() + assert.Equal(t, int64(0), cb.Metrics().Failures) + assert.Equal(t, "closed", cb.State()) +} + +func TestCircuitBreaker_RecordFailure_OpensCircuit(t *testing.T) { + cb := NewCircuitBreaker(3, 60) + + // Record failures below threshold + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + + // Third failure should open circuit + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_State(t *testing.T) { + cb := NewCircuitBreaker(1, 60) + + // Initially closed + assert.Equal(t, "closed", cb.State()) + + // After failure, open + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + // After success, closed + cb.RecordSuccess() + assert.Equal(t, "closed", cb.State()) +} + +func TestCircuitBreaker_Metrics(t *testing.T) { + cb := NewCircuitBreaker(5, 120) + + metrics := cb.Metrics() + assert.Equal(t, "closed", metrics.State) + assert.Equal(t, int64(0), metrics.Failures) + assert.Equal(t, int64(5), metrics.Threshold) + assert.Equal(t, int64(120), metrics.ResetTimeoutSecs) + assert.Equal(t, int64(0), metrics.LastFailureUnix) + + // After failure + cb.RecordFailure() + metrics = cb.Metrics() + assert.Equal(t, int64(1), metrics.Failures) + assert.Greater(t, metrics.LastFailureUnix, int64(0)) +} + +func TestCircuitBreaker_Metrics_OpenWithReset(t *testing.T) { + cb := NewCircuitBreaker(1, 60) + + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + metrics := cb.Metrics() + assert.Equal(t, "open", metrics.State) + assert.Greater(t, metrics.SecondsUntilReset, int64(0)) + assert.LessOrEqual(t, metrics.SecondsUntilReset, int64(60)) +} + +// ============================================================================= +// TESTS FOR RequestDeduplicator +// ============================================================================= + +func TestNewRequestDeduplicator(t *testing.T) { + d := NewRequestDeduplicator(300, 1000) + + assert.NotNil(t, d) + assert.NotNil(t, d.seen) + assert.Equal(t, int64(300), d.ttlSecs) + assert.Equal(t, 1000, d.maxSize) +} + +func TestRequestDeduplicator_IsDuplicate_NotSeen(t *testing.T) { + d := NewRequestDeduplicator(300, 1000) + + // New hash is not a duplicate + assert.False(t, d.IsDuplicate("newhash")) +} + +func TestRequestDeduplicator_IsDuplicate_AfterRecord(t *testing.T) { + d := NewRequestDeduplicator(300, 1000) + + hash := "testhash" + + // Record the hash + d.Record(hash) + + // Now it should be a duplicate + assert.True(t, d.IsDuplicate(hash)) +} + +func TestRequestDeduplicator_Record(t *testing.T) { + d := NewRequestDeduplicator(300, 1000) + + hash := "recordtest" + d.Record(hash) + + // Check it was recorded + d.mu.RLock() + _, exists := d.seen[hash] + d.mu.RUnlock() + + assert.True(t, exists) +} + +func TestRequestDeduplicator_Record_Eviction(t *testing.T) { + // Small maxSize for testing eviction + d := NewRequestDeduplicator(0, 2) // TTL of 0 means everything is "old" + + // Record until capacity + d.Record("hash1") + d.Record("hash2") + + // Recording a third should trigger eviction (since TTL is 0) + d.Record("hash3") + + // Should have cleaned up old entries + d.mu.RLock() + size := len(d.seen) + d.mu.RUnlock() + + // Size should be limited (eviction occurred) + assert.LessOrEqual(t, size, 3) +} + +func TestHashRequest(t *testing.T) { + tests := []struct { + name string + toolName string + input string + output string + compareWith []string + wantLen int + wantSame bool + }{ + { + name: "basic hash", + toolName: "Read", + input: "file.txt", + output: "content", + wantLen: 16, + }, + { + name: "consistent hashing", + toolName: "Edit", + input: "same input", + output: "same output", + wantLen: 16, + }, + { + name: "long output truncation", + toolName: "Bash", + input: "command", + output: string(make([]byte, 5000)), // Very long output + wantLen: 16, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash := hashRequest(tt.toolName, tt.input, tt.output) + assert.Len(t, hash, tt.wantLen) + + // Same inputs should produce same hash + hash2 := hashRequest(tt.toolName, tt.input, tt.output) + assert.Equal(t, hash, hash2) + }) + } +} + +func TestHashRequest_DifferentInputs(t *testing.T) { + // Different inputs should produce different hashes + hash1 := hashRequest("Read", "file1.txt", "content1") + hash2 := hashRequest("Read", "file2.txt", "content2") + + assert.NotEqual(t, hash1, hash2) +} + +func TestHashRequest_OutputTruncation(t *testing.T) { + // Hash should be the same for outputs that differ only after 1000 chars + longOutput1 := string(make([]byte, 1500)) + longOutput2 := longOutput1[:1000] + "different suffix here" + + hash1 := hashRequest("Read", "input", longOutput1) + hash2 := hashRequest("Read", "input", longOutput2) + + // Since we only hash first 1000 chars, these should be the same + assert.Equal(t, hash1, hash2) +} + +// ============================================================================= +// TESTS FOR Processor methods +// ============================================================================= + +func TestProcessor_CircuitBreakerState(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(2, 60), + } + + // Initially closed + assert.Equal(t, "closed", p.CircuitBreakerState()) + + // After enough failures, open + p.circuitBreaker.RecordFailure() + p.circuitBreaker.RecordFailure() + assert.Equal(t, "open", p.CircuitBreakerState()) + + // After success, closed + p.circuitBreaker.RecordSuccess() + assert.Equal(t, "closed", p.CircuitBreakerState()) +} + +func TestProcessor_CircuitBreakerMetrics(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 120), + } + + metrics := p.CircuitBreakerMetrics() + assert.Equal(t, "closed", metrics.State) + assert.Equal(t, int64(0), metrics.Failures) + assert.Equal(t, int64(5), metrics.Threshold) + assert.Equal(t, int64(120), metrics.ResetTimeoutSecs) + + // Record a failure and check metrics update + p.circuitBreaker.RecordFailure() + metrics = p.CircuitBreakerMetrics() + assert.Equal(t, int64(1), metrics.Failures) + assert.Greater(t, metrics.LastFailureUnix, int64(0)) +} + +// ============================================================================= +// TESTS FOR Vector Sync Workers +// ============================================================================= + +func TestProcessor_StartAndStopVectorSyncWorkers(t *testing.T) { + var syncedObservations []*models.Observation + var mu sync.Mutex + + p := &Processor{ + vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), + vectorSyncDone: make(chan struct{}), + syncObservationFunc: func(obs *models.Observation) { + mu.Lock() + syncedObservations = append(syncedObservations, obs) + mu.Unlock() + }, + } + + // Start workers + p.StartVectorSyncWorkers() + + // Send some observations + obs1 := &models.Observation{SDKSessionID: "test1"} + obs2 := &models.Observation{SDKSessionID: "test2"} + p.vectorSyncChan <- obs1 + p.vectorSyncChan <- obs2 + + // Give workers time to process + time.Sleep(50 * time.Millisecond) + + // Stop workers + p.StopVectorSyncWorkers() + + // Verify observations were synced + mu.Lock() + assert.Len(t, syncedObservations, 2) + mu.Unlock() +} + +func TestProcessor_VectorSyncWorker_DrainOnShutdown(t *testing.T) { + var syncedCount int + var mu sync.Mutex + + p := &Processor{ + vectorSyncChan: make(chan *models.Observation, 10), + vectorSyncDone: make(chan struct{}), + syncObservationFunc: func(obs *models.Observation) { + mu.Lock() + syncedCount++ + mu.Unlock() + }, + } + + // Queue observations before starting workers + for i := 0; i < 5; i++ { + p.vectorSyncChan <- &models.Observation{SDKSessionID: "pre-queued"} + } + + // Start workers + p.StartVectorSyncWorkers() + + // Stop immediately - workers should drain the queue + p.StopVectorSyncWorkers() + + // All pre-queued items should have been processed + mu.Lock() + assert.Equal(t, 5, syncedCount) + mu.Unlock() +} + +func TestProcessor_VectorSyncWorker_NilSyncFunc(t *testing.T) { + p := &Processor{ + vectorSyncChan: make(chan *models.Observation, 10), + vectorSyncDone: make(chan struct{}), + syncObservationFunc: nil, // No sync function set + } + + // Start workers + p.StartVectorSyncWorkers() + + // Send observation - should not panic even with nil sync func + p.vectorSyncChan <- &models.Observation{SDKSessionID: "test"} + + // Give it time to process + time.Sleep(50 * time.Millisecond) + + // Stop workers - should not panic + p.StopVectorSyncWorkers() +} + +// ============================================================================= +// TESTS FOR CircuitBreaker Additional Behaviors +// ============================================================================= + +func TestCircuitBreaker_Allow_OpenBlocksRequests(t *testing.T) { + cb := NewCircuitBreaker(1, 60) + + // Open the circuit + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + // All requests should be blocked + assert.False(t, cb.Allow()) + assert.False(t, cb.Allow()) + assert.False(t, cb.Allow()) +} + +func TestCircuitBreaker_MultipleFailures(t *testing.T) { + cb := NewCircuitBreaker(3, 60) // Higher threshold + + // Record failures below threshold + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + assert.Equal(t, int64(1), cb.Metrics().Failures) + + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + assert.Equal(t, int64(2), cb.Metrics().Failures) + + // Third failure opens circuit + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + assert.Equal(t, int64(3), cb.Metrics().Failures) +} + +func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) { + cb := NewCircuitBreaker(5, 60) + + // Record some failures + cb.RecordFailure() + cb.RecordFailure() + assert.Equal(t, int64(2), cb.Metrics().Failures) + + // Success resets failures + cb.RecordSuccess() + assert.Equal(t, int64(0), cb.Metrics().Failures) + assert.Equal(t, "closed", cb.State()) +} + +func TestCircuitBreaker_Metrics_Comprehensive(t *testing.T) { + cb := NewCircuitBreaker(5, 120) + + // Initial state + metrics := cb.Metrics() + assert.Equal(t, "closed", metrics.State) + assert.Equal(t, int64(0), metrics.Failures) + assert.Equal(t, int64(5), metrics.Threshold) + assert.Equal(t, int64(120), metrics.ResetTimeoutSecs) + assert.Equal(t, int64(0), metrics.LastFailureUnix) + assert.Equal(t, int64(0), metrics.SecondsUntilReset) + + // After failures that open circuit + for i := 0; i < 5; i++ { + cb.RecordFailure() + } + metrics = cb.Metrics() + assert.Equal(t, "open", metrics.State) + assert.Equal(t, int64(5), metrics.Failures) + assert.Greater(t, metrics.LastFailureUnix, int64(0)) + assert.Greater(t, metrics.SecondsUntilReset, int64(0)) +} + +// ============================================================================= +// TESTS FOR MaxVectorSyncWorkers constant +// ============================================================================= + +func TestMaxVectorSyncWorkers(t *testing.T) { + assert.Equal(t, 8, MaxVectorSyncWorkers) +} + +// ============================================================================= +// ADDITIONAL EDGE CASE TESTS +// ============================================================================= + +func TestRequestDeduplicator_IsDuplicate_ExpiredEntry(t *testing.T) { + if testing.Short() { + t.Skip("Skipping time-dependent test in short mode") + } + // Use a 1-second TTL with enough margin + d := NewRequestDeduplicator(1, 100) + + hash := "expiretest" + d.Record(hash) + + // Initially duplicate + assert.True(t, d.IsDuplicate(hash)) + + // Wait for TTL to expire (2.5 seconds to ensure crossing second boundaries) + time.Sleep(2500 * time.Millisecond) + + // Should no longer be considered duplicate + assert.False(t, d.IsDuplicate(hash)) +} + +// ============================================================================= +// TESTS FOR ProcessObservation Early Returns +// ============================================================================= + +func TestProcessObservation_SkipTool(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 60), + deduplicator: NewRequestDeduplicator(300, 1000), + } + + ctx := context.Background() + + // TodoWrite should be skipped + err := p.ProcessObservation(ctx, "session-1", "project-1", "TodoWrite", + map[string]string{"content": "test"}, "success", 1, "/test/cwd") + assert.NoError(t, err) + + // Glob should be skipped + err = p.ProcessObservation(ctx, "session-1", "project-1", "Glob", + map[string]string{"pattern": "*.go"}, []string{"main.go", "test.go"}, 1, "/test/cwd") + assert.NoError(t, err) + + // AskUserQuestion should be skipped + err = p.ProcessObservation(ctx, "session-1", "project-1", "AskUserQuestion", + "question", "answer", 1, "/test/cwd") + assert.NoError(t, err) +} + +func TestProcessObservation_SkipTrivial(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 60), + deduplicator: NewRequestDeduplicator(300, 1000), + } + + ctx := context.Background() + + // Short output should be skipped + err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", + map[string]string{"file_path": "/test.go"}, "short", 1, "/test/cwd") + assert.NoError(t, err) + + // "No matches found" should be skipped + err = p.ProcessObservation(ctx, "session-1", "project-1", "Grep", + map[string]string{"pattern": "test"}, "No matches found in the repository", 1, "/test/cwd") + assert.NoError(t, err) +} + +func TestProcessObservation_SkipDuplicate(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 60), + deduplicator: NewRequestDeduplicator(300, 1000), + sem: make(chan struct{}, 4), + claudePath: "/nonexistent/path", // Will fail at CLI call stage + } + + ctx := context.Background() + + // Valid input that would be processed + input := map[string]string{"file_path": "/project/main.go"} + output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}" + + // First call should try to process (will fail because claudePath doesn't exist) + err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd") + // Expect error because claudePath doesn't exist + assert.Error(t, err) + + // Second call with same input should be skipped as duplicate + err = p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd") + assert.NoError(t, err) // No error because it was skipped as duplicate +} + +func TestProcessObservation_CircuitBreakerOpen(t *testing.T) { + cb := NewCircuitBreaker(1, 60) // Threshold of 1 + cb.RecordFailure() // Open the circuit breaker + + p := &Processor{ + circuitBreaker: cb, + deduplicator: NewRequestDeduplicator(300, 1000), + } + + ctx := context.Background() + + // Valid input that would be processed + input := map[string]string{"file_path": "/project/main.go"} + output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}" + + err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd") + assert.Error(t, err) + assert.Contains(t, err.Error(), "circuit breaker open") +} + +func TestProcessObservation_ContextCancel(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 60), + deduplicator: NewRequestDeduplicator(300, 1000), + sem: make(chan struct{}, 1), // Small semaphore + claudePath: "/fake/claude", + } + + // Fill the semaphore + p.sem <- struct{}{} + + // Create a cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + // Valid input that would be processed + input := map[string]string{"file_path": "/project/main.go"} + output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}" + + err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd") + assert.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +// ============================================================================= +// TESTS FOR ProcessSummary Early Returns +// ============================================================================= + +func TestProcessSummary_SkipEmptyRequest(t *testing.T) { + p := &Processor{ + circuitBreaker: NewCircuitBreaker(5, 60), + deduplicator: NewRequestDeduplicator(300, 1000), + } + + ctx := context.Background() + + // Empty request should be skipped (sessionDBID, sdkSessionID, project, userPrompt, lastUserMsg, lastAssistantMsg) + err := p.ProcessSummary(ctx, 1, "session-1", "project-1", "", "", "") + assert.NoError(t, err) +} + +func TestProcessSummary_CircuitBreakerOpen(t *testing.T) { + cb := NewCircuitBreaker(1, 60) + cb.RecordFailure() // Open the circuit breaker + + p := &Processor{ + circuitBreaker: cb, + deduplicator: NewRequestDeduplicator(300, 1000), + sem: make(chan struct{}, 4), + claudePath: "/nonexistent/path", + } + + ctx := context.Background() + + // Meaningful assistant message (> 200 chars, contains code discussion) + assistantMsg := `I've updated the handler.go file to fix the authentication bug. +The function validateToken() was not checking token expiry correctly. +I've added a check for the exp claim and implemented proper error handling. +The changes have been tested and the build passes successfully. +Here's the implementation details and code review.` + + // Valid request but circuit breaker is open + err := p.ProcessSummary(ctx, 1, "session-1", "project-1", + "Implement authentication", "User message", assistantMsg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "claude CLI failed") +} + +// ============================================================================= +// TESTS FOR callClaudeCLI Error Paths +// ============================================================================= + +func TestCallClaudeCLI_PromptTooLarge(t *testing.T) { + p := &Processor{ + claudePath: "/fake/claude", + } + + ctx := context.Background() + + // Create a prompt that exceeds MaxPromptSize + largePrompt := string(make([]byte, MaxPromptSize+1)) + + _, err := p.callClaudeCLI(ctx, largePrompt) + assert.Error(t, err) + assert.Contains(t, err.Error(), "prompt exceeds maximum size") +} + +func TestCallClaudeCLI_BinaryNotFound(t *testing.T) { + p := &Processor{ + claudePath: "/nonexistent/path/to/claude", + } + + ctx := context.Background() + + _, err := p.callClaudeCLI(ctx, "test prompt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "claude CLI failed") +} diff --git a/internal/worker/session/manager_test.go b/internal/worker/session/manager_test.go index 1a181e4..52d419e 100644 --- a/internal/worker/session/manager_test.go +++ b/internal/worker/session/manager_test.go @@ -693,3 +693,734 @@ func TestToolInputResponse(t *testing.T) { }) } } + +// ============================================================================= +// TESTS FOR NewManager AND CLEANUP +// ============================================================================= + +// TestNewManager tests the NewManager function. +func TestNewManager(t *testing.T) { + t.Parallel() + + // Test with nil session store (valid for testing) + manager := NewManager(nil) + + assert.NotNil(t, manager) + assert.NotNil(t, manager.sessions) + assert.NotNil(t, manager.ProcessNotify) + assert.NotNil(t, manager.ctx) + assert.NotNil(t, manager.cancel) + assert.Equal(t, 0, manager.GetActiveSessionCount()) + + // Clean up - cancel context to stop cleanup goroutine + manager.cancel() +} + +// TestNewManager_CleanupGoroutineStops tests that cleanup goroutine stops on cancel. +func TestNewManager_CleanupGoroutineStops(t *testing.T) { + t.Parallel() + + manager := NewManager(nil) + + // Give goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Cancel should stop the cleanup goroutine + manager.cancel() + + // Context should be done + select { + case <-manager.ctx.Done(): + // Expected + case <-time.After(100 * time.Millisecond): + t.Error("Context should be done after cancel") + } +} + +// TestCleanupStaleSessions_NoSessions tests cleanup with no sessions. +func TestCleanupStaleSessions_NoSessions(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Should not panic with empty sessions + manager.cleanupStaleSessions() + assert.Equal(t, 0, manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_FreshSession tests that fresh sessions are not cleaned. +func TestCleanupStaleSessions_FreshSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Add a fresh session + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + StartTime: time.Now(), // Fresh + pendingMessages: []PendingMessage{}, + ctx: sessionCtx, + cancel: sessionCancel, + } + + manager.cleanupStaleSessions() + + // Session should still exist (not stale) + assert.Equal(t, 1, manager.GetActiveSessionCount()) + sessionCancel() +} + +// TestCleanupStaleSessions_StaleSession tests that stale sessions are cleaned. +func TestCleanupStaleSessions_StaleSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Add a stale session + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale + pendingMessages: []PendingMessage{}, + ctx: sessionCtx, + cancel: sessionCancel, + } + + manager.cleanupStaleSessions() + + // Session should be deleted + assert.Equal(t, 0, manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_StaleWithPending tests stale sessions with pending messages are not cleaned. +func TestCleanupStaleSessions_StaleWithPending(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Add a stale session with pending messages + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + defer sessionCancel() + manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale + pendingMessages: []PendingMessage{{Type: MessageTypeObservation}}, + ctx: sessionCtx, + cancel: sessionCancel, + } + + manager.cleanupStaleSessions() + + // Session should NOT be deleted (has pending messages) + assert.Equal(t, 1, manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_StaleWithActiveGenerator tests stale sessions with active generator are not cleaned. +func TestCleanupStaleSessions_StaleWithActiveGenerator(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Add a stale session with active generator + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + defer sessionCancel() + session := &ActiveSession{ + SessionDBID: 1, + StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale + pendingMessages: []PendingMessage{}, + ctx: sessionCtx, + cancel: sessionCancel, + } + session.generatorActive.Store(true) + manager.sessions[1] = session + + manager.cleanupStaleSessions() + + // Session should NOT be deleted (generator is active) + assert.Equal(t, 1, manager.GetActiveSessionCount()) +} + +// TestCleanupStaleSessions_MixedSessions tests cleanup with mixed fresh and stale sessions. +func TestCleanupStaleSessions_MixedSessions(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Fresh session + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + manager.sessions[1] = &ActiveSession{ + SessionDBID: 1, + StartTime: time.Now(), + pendingMessages: []PendingMessage{}, + ctx: ctx1, + cancel: cancel1, + } + + // Stale session (should be deleted) + ctx2, cancel2 := context.WithCancel(context.Background()) + manager.sessions[2] = &ActiveSession{ + SessionDBID: 2, + StartTime: time.Now().Add(-SessionTimeout - time.Minute), + pendingMessages: []PendingMessage{}, + ctx: ctx2, + cancel: cancel2, + } + + // Stale session with pending (should NOT be deleted) + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() + manager.sessions[3] = &ActiveSession{ + SessionDBID: 3, + StartTime: time.Now().Add(-SessionTimeout - time.Minute), + pendingMessages: []PendingMessage{{Type: MessageTypeObservation}}, + ctx: ctx3, + cancel: cancel3, + } + + manager.cleanupStaleSessions() + + // Should have 2 sessions left (1 fresh, 1 stale with pending) + assert.Equal(t, 2, manager.GetActiveSessionCount()) + + // Verify which sessions remain + manager.mu.RLock() + _, has1 := manager.sessions[1] + _, has2 := manager.sessions[2] + _, has3 := manager.sessions[3] + manager.mu.RUnlock() + + assert.True(t, has1, "Fresh session should remain") + assert.False(t, has2, "Stale session should be deleted") + assert.True(t, has3, "Stale session with pending should remain") +} + +// TestCleanupLoop_ExitsOnCancel tests that cleanup loop exits when context is cancelled. +func TestCleanupLoop_ExitsOnCancel(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + manager.ctx = ctx + manager.cancel = cancel + + // Start cleanup loop in goroutine + done := make(chan struct{}) + go func() { + manager.cleanupLoop() + close(done) + }() + + // Cancel immediately + cancel() + + // Should exit quickly + select { + case <-done: + // Success - loop exited + case <-time.After(100 * time.Millisecond): + t.Error("Cleanup loop should exit when context is cancelled") + } +} + +// ============================================================================= +// TESTS FOR InitializeSession (without DB) +// ============================================================================= + +// TestInitializeSession_AlreadyActive tests reusing an already active session. +func TestInitializeSession_AlreadyActive(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add an active session + existingSession := &ActiveSession{ + SessionDBID: 42, + ClaudeSessionID: "claude-existing", + Project: "test-project", + UserPrompt: "original prompt", + LastPromptNumber: 1, + StartTime: time.Now(), + pendingMessages: make([]PendingMessage, 0), + } + manager.sessions[42] = existingSession + + // Initialize same session - should reuse + session, err := manager.InitializeSession(context.Background(), 42, "new prompt", 5) + + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Same(t, existingSession, session) + assert.Equal(t, "new prompt", session.UserPrompt) + assert.Equal(t, 5, session.LastPromptNumber) +} + +// TestInitializeSession_AlreadyActive_EmptyPrompt tests reusing session with empty prompt. +func TestInitializeSession_AlreadyActive_EmptyPrompt(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add an active session + existingSession := &ActiveSession{ + SessionDBID: 42, + UserPrompt: "original prompt", + LastPromptNumber: 1, + } + manager.sessions[42] = existingSession + + // Initialize with empty prompt - should NOT update + session, err := manager.InitializeSession(context.Background(), 42, "", 0) + + assert.NoError(t, err) + assert.NotNil(t, session) + assert.Equal(t, "original prompt", session.UserPrompt) // Unchanged + assert.Equal(t, 1, session.LastPromptNumber) // Unchanged +} + +// TestInitializeSession_NoStore tests initialization without session store. +func TestInitializeSession_NoStore(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessionStore: nil, // No store + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Should fail gracefully with nil store (panic recovery not expected) + // This tests the guard against nil sessionStore + defer func() { + if r := recover(); r != nil { + _ = r // Expected panic when calling nil store - intentionally ignored + } + }() + + _, _ = manager.InitializeSession(context.Background(), 999, "prompt", 1) +} + +// TestInitializeSession_CallbackTriggered tests that created callback is triggered. +func TestInitializeSession_CallbackTriggered(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + var calledWithID int64 + manager.SetOnSessionCreated(func(id int64) { + calledWithID = id + }) + + // Add session directly (simulating what would happen after DB fetch) + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + defer sessionCancel() + session := &ActiveSession{ + SessionDBID: 100, + ClaudeSessionID: "test", + Project: "project", + StartTime: time.Now(), + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + ctx: sessionCtx, + cancel: sessionCancel, + } + + manager.mu.Lock() + manager.sessions[100] = session + onCreated := manager.onCreated + manager.mu.Unlock() + + // Trigger callback + if onCreated != nil { + onCreated(100) + } + + assert.Equal(t, int64(100), calledWithID) +} + +// ============================================================================= +// TESTS FOR QueueObservation AND QueueSummarize (without DB) +// ============================================================================= + +// TestQueueObservation_ToExistingSession tests queuing to an existing session. +func TestQueueObservation_ToExistingSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + manager.sessions[1] = session + + // Queue observation + err := manager.QueueObservation(context.Background(), 1, ObservationData{ + ToolName: "Read", + ToolInput: map[string]string{"path": "/test"}, + ToolResponse: "content", + PromptNumber: 1, + CWD: "/project", + }) + + assert.NoError(t, err) + assert.Equal(t, 1, manager.GetTotalQueueDepth()) + + // Verify message + messages := manager.DrainMessages(1) + assert.Len(t, messages, 1) + assert.Equal(t, MessageTypeObservation, messages[0].Type) + assert.Equal(t, "Read", messages[0].Observation.ToolName) + assert.Equal(t, "/project", messages[0].Observation.CWD) +} + +// TestQueueObservation_NotifiesSession tests that notification is sent to session. +func TestQueueObservation_NotifiesSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session with notify channel + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + manager.sessions[1] = session + + // Queue observation + err := manager.QueueObservation(context.Background(), 1, ObservationData{ToolName: "Test"}) + assert.NoError(t, err) + + // Should receive notification on session channel + select { + case <-session.notify: + // Success + default: + t.Error("Session should receive notification") + } + + // Should receive notification on process channel + select { + case <-manager.ProcessNotify: + // Success + default: + t.Error("Manager ProcessNotify should receive notification") + } +} + +// TestQueueSummarize_ToExistingSession tests queuing summarize to an existing session. +func TestQueueSummarize_ToExistingSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + manager.sessions[1] = session + + // Queue summarize + err := manager.QueueSummarize(context.Background(), 1, "User asked question", "Assistant answered") + assert.NoError(t, err) + assert.Equal(t, 1, manager.GetTotalQueueDepth()) + + // Verify message + messages := manager.DrainMessages(1) + assert.Len(t, messages, 1) + assert.Equal(t, MessageTypeSummarize, messages[0].Type) + assert.Equal(t, "User asked question", messages[0].Summarize.LastUserMessage) + assert.Equal(t, "Assistant answered", messages[0].Summarize.LastAssistantMessage) +} + +// TestQueueSummarize_NotifiesSession tests that notification is sent to session. +func TestQueueSummarize_NotifiesSession(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session with notify channel + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + manager.sessions[1] = session + + // Queue summarize + err := manager.QueueSummarize(context.Background(), 1, "user", "assistant") + assert.NoError(t, err) + + // Should receive notification on session channel + select { + case <-session.notify: + // Success + default: + t.Error("Session should receive notification") + } + + // Should receive notification on process channel + select { + case <-manager.ProcessNotify: + // Success + default: + t.Error("Manager ProcessNotify should receive notification") + } +} + +// TestQueueOperations_MultipleMessages tests queuing multiple messages. +func TestQueueOperations_MultipleMessages(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + manager.sessions[1] = session + + // Queue multiple messages + for i := 0; i < 10; i++ { + if i%2 == 0 { + err := manager.QueueObservation(context.Background(), 1, ObservationData{ + ToolName: "Tool" + string(rune('A'+i)), + }) + assert.NoError(t, err) + } else { + err := manager.QueueSummarize(context.Background(), 1, "user", "assistant") + assert.NoError(t, err) + } + } + + assert.Equal(t, 10, manager.GetTotalQueueDepth()) + + // Drain and verify + messages := manager.DrainMessages(1) + assert.Len(t, messages, 10) +} + +// TestQueueOperations_NonBlockingNotification tests non-blocking notification behavior. +func TestQueueOperations_NonBlockingNotification(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add session with full notify channel + session := &ActiveSession{ + SessionDBID: 1, + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + } + // Fill the notify channel + session.notify <- struct{}{} + manager.sessions[1] = session + + // Fill ProcessNotify channel + manager.ProcessNotify <- struct{}{} + + // Queue should NOT block even with full channels + done := make(chan bool) + go func() { + err := manager.QueueObservation(context.Background(), 1, ObservationData{ToolName: "Test"}) + assert.NoError(t, err) + done <- true + }() + + select { + case <-done: + // Success - didn't block + case <-time.After(100 * time.Millisecond): + t.Error("Queue operation should not block even with full notification channels") + } +} + +// TestConcurrentQueueAndCleanup tests concurrent queue operations and cleanup. +func TestConcurrentQueueAndCleanup(t *testing.T) { + t.Parallel() + + manager := &Manager{ + sessions: make(map[int64]*ActiveSession), + ProcessNotify: make(chan struct{}, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + manager.ctx = ctx + manager.cancel = cancel + + // Pre-add multiple sessions + for i := int64(1); i <= 5; i++ { + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + manager.sessions[i] = &ActiveSession{ + SessionDBID: i, + StartTime: time.Now(), + pendingMessages: make([]PendingMessage, 0), + notify: make(chan struct{}, 1), + ctx: sessionCtx, + cancel: sessionCancel, + } + } + + var wg sync.WaitGroup + + // Concurrent queue operations + for i := 0; i < 50; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + sessionID := int64((idx % 5) + 1) + if idx%2 == 0 { + _ = manager.QueueObservation(context.Background(), sessionID, ObservationData{ToolName: "Test"}) + } else { + _ = manager.QueueSummarize(context.Background(), sessionID, "user", "assistant") + } + }(i) + } + + // Concurrent cleanup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + manager.cleanupStaleSessions() + }() + } + + // Concurrent reads + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = manager.GetActiveSessionCount() + _ = manager.GetTotalQueueDepth() + _ = manager.IsAnySessionProcessing() + _ = manager.GetAllSessions() + }() + } + + wg.Wait() + + // Should have all sessions (none are stale) + assert.Equal(t, 5, manager.GetActiveSessionCount()) + // Should have 50 messages total + assert.Equal(t, 50, manager.GetTotalQueueDepth()) +} diff --git a/pkg/similarity/clustering_test.go b/pkg/similarity/clustering_test.go index 6aa7f06..e781599 100644 --- a/pkg/similarity/clustering_test.go +++ b/pkg/similarity/clustering_test.go @@ -290,3 +290,334 @@ func TestClusterObservations_PreservesOrder(t *testing.T) { require.NotEmpty(t, clustered) assert.Equal(t, int64(1), clustered[0].ID, "First observation should be kept as first result") } + +// ============================================================================= +// TESTS FOR OPTIMIZED CLUSTERING (triggered when len(observations) > 50) +// ============================================================================= + +func TestClusterObservationsOptimized_LargeSet(t *testing.T) { + t.Parallel() + + // Create 60 observations to trigger optimized path (threshold is 50) + observations := make([]*models.Observation, 60) + + // Create 30 pairs of similar observations + topics := []string{ + "authentication", "authorization", "database", "caching", "logging", + "monitoring", "testing", "deployment", "scaling", "security", + "networking", "storage", "messaging", "scheduling", "configuration", + "validation", "serialization", "encryption", "compression", "indexing", + "backup", "recovery", "migration", "versioning", "documentation", + "profiling", "debugging", "tracing", "alerting", "reporting", + } + + for i := 0; i < 30; i++ { + // First observation of pair + observations[i*2] = &models.Observation{ + ID: int64(i*2 + 1), + Title: sql.NullString{String: topics[i] + " implementation", Valid: true}, + Narrative: sql.NullString{String: "Detailed " + topics[i] + " system design", Valid: true}, + } + // Second observation of pair (similar to first) + observations[i*2+1] = &models.Observation{ + ID: int64(i*2 + 2), + Title: sql.NullString{String: topics[i] + " update", Valid: true}, + Narrative: sql.NullString{String: "Updated " + topics[i] + " logic", Valid: true}, + } + } + + clustered := ClusterObservations(observations, 0.4) + + // With similar pairs, we should get roughly 30 clusters (one per topic) + t.Logf("Clustered %d observations down to %d", len(observations), len(clustered)) + assert.Less(t, len(clustered), 60, "Similar observations should be clustered together") + assert.GreaterOrEqual(t, len(clustered), 1, "Should have at least one cluster") +} + +func TestClusterObservationsOptimized_AllUnique(t *testing.T) { + t.Parallel() + + // Create 55 completely unique observations with NO shared terms + // Each observation has only its unique term (no common words like "topic" or "content") + uniqueTerms := []string{ + "aardvark", "butterfly", "caterpillar", "dragonfly", "elephant", + "flamingo", "giraffe", "hippopotamus", "iguana", "jellyfish", + "kangaroo", "leopard", "mongoose", "nightingale", "octopus", + "penguin", "quail", "rhinoceros", "salamander", "toucan", + "umbrella", "vulture", "walrus", "xylophone", "yakking", + "zebra123", "astronomy99", "biology88", "chemistry77", "dynamics66", + "economics55", "forensics44", "genetics33", "hydraulics22", "immunology11", + "jurisprudence", "kinetics", "linguistics", "metallurgy", "neurology", + "oceanography", "pharmacology", "quantumphysics", "robotics", "sociology", + "thermodynamics", "ultrasound", "virology", "wavelength", "xenobiology", + "yeastculture", "zoology123", "algebra456", "botany789", "calculus012", + } + + observations := make([]*models.Observation, 55) + for i := 0; i < 55; i++ { + // Each observation has ONLY its unique term - no shared words + observations[i] = &models.Observation{ + ID: int64(i + 1), + Title: sql.NullString{String: uniqueTerms[i], Valid: true}, + Narrative: sql.NullString{String: uniqueTerms[i], Valid: true}, + } + } + + clustered := ClusterObservations(observations, 0.4) + + // All unique content should remain unclustered + assert.Len(t, clustered, 55, "All unique observations should be kept") +} + +func TestClusterObservationsOptimized_SignaturePrefiltering(t *testing.T) { + t.Parallel() + + // Test that signature prefiltering works correctly + // Create observations where some have very different signatures + observations := make([]*models.Observation, 60) + + // First half: all identical (about "authentication") - should cluster to 1 + for i := 0; i < 30; i++ { + observations[i] = &models.Observation{ + ID: int64(i + 1), + Title: sql.NullString{String: "authentication security login", Valid: true}, + Narrative: sql.NullString{String: "JWT tokens OAuth authentication", Valid: true}, + } + } + + // Second half: each completely unique with NO shared terms + diffTerms := []string{ + "quantumphysics", "photosynthesis", "archaeologydig", "linguisticstudy", "astronomystar", + "paleontologyfossil", "oceanographywave", "entomologybug", "mycologyfungi", "herpetologysnake", + "ornithologybird", "ichthyologyfish", "seismologyquake", "volcanologylava", "meteorologyrain", + "cartographymap", "ethnographyculture", "philologyword", "numismaticscoin", "heraldryshield", + "genealogytree", "chronologytime", "typographyfont", "calligraphyink", "epigraphystone", + "papyrologytext", "codicologybook", "diplomaticseal", "sigillographywax", "sphragisticsring", + } + for i := 30; i < 60; i++ { + term := diffTerms[i-30] + // Each has ONLY its unique term - no shared words + observations[i] = &models.Observation{ + ID: int64(i + 1), + Title: sql.NullString{String: term, Valid: true}, + Narrative: sql.NullString{String: term, Valid: true}, + } + } + + clustered := ClusterObservations(observations, 0.5) + + // Should have 31 clusters: 1 for all auth topics + 30 unique topics + t.Logf("Clustered %d observations down to %d", len(observations), len(clustered)) + assert.Equal(t, 31, len(clustered), "Should have 31 clusters (1 auth + 30 unique)") +} + +// ============================================================================= +// TESTS FOR HELPER FUNCTIONS +// ============================================================================= + +func TestComputeTermSignature(t *testing.T) { + tests := []struct { + terms map[string]bool + compareTo map[string]bool + name string + expectZero bool + expectSame bool + }{ + // ===== GOOD CASES ===== + { + name: "single term", + terms: map[string]bool{"hello": true}, + expectZero: false, + }, + { + name: "multiple terms", + terms: map[string]bool{"hello": true, "world": true}, + expectZero: false, + }, + { + name: "identical terms produce same signature", + terms: map[string]bool{"alpha": true, "beta": true}, + expectSame: true, + compareTo: map[string]bool{"alpha": true, "beta": true}, + }, + + // ===== EDGE CASES ===== + { + name: "empty set", + terms: map[string]bool{}, + expectZero: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sig := computeTermSignature(tt.terms) + + if tt.expectZero { + assert.Equal(t, uint64(0), sig, "Empty set should produce zero signature") + } else { + assert.NotEqual(t, uint64(0), sig, "Non-empty set should produce non-zero signature") + } + + if tt.expectSame && tt.compareTo != nil { + sig2 := computeTermSignature(tt.compareTo) + assert.Equal(t, sig, sig2, "Identical term sets should produce identical signatures") + } + }) + } +} + +func TestComputeTermSignature_DifferentSets(t *testing.T) { + t.Parallel() + + // Different term sets should usually produce different signatures + set1 := map[string]bool{"authentication": true, "security": true} + set2 := map[string]bool{"database": true, "migration": true} + + sig1 := computeTermSignature(set1) + sig2 := computeTermSignature(set2) + + // While hash collisions are possible, they should be rare + assert.NotEqual(t, sig1, sig2, "Different term sets should usually produce different signatures") +} + +func TestPopCount64(t *testing.T) { + tests := []struct { + name string + input uint64 + expected int + }{ + // ===== GOOD CASES ===== + {name: "zero", input: 0, expected: 0}, + {name: "one", input: 1, expected: 1}, + {name: "powers of two", input: 8, expected: 1}, + {name: "all ones in byte", input: 0xFF, expected: 8}, + {name: "alternating bits", input: 0xAAAAAAAAAAAAAAAA, expected: 32}, + {name: "max uint64", input: 0xFFFFFFFFFFFFFFFF, expected: 64}, + + // ===== EDGE CASES ===== + {name: "single high bit", input: 1 << 63, expected: 1}, + {name: "sparse bits", input: 0x8000000000000001, expected: 2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := popCount64(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsSimilarToAny_EmptyTerms(t *testing.T) { + t.Parallel() + + // Observation with no extractable terms + emptyObs := &models.Observation{ + ID: 1, + Title: sql.NullString{String: "", Valid: false}, + Narrative: sql.NullString{String: "", Valid: false}, + } + + existing := []*models.Observation{ + { + ID: 2, + Title: sql.NullString{String: "Some content here", Valid: true}, + Narrative: sql.NullString{String: "More content", Valid: true}, + }, + } + + // Should return false when new observation has no terms + assert.False(t, IsSimilarToAny(emptyObs, existing, 0.3)) +} + +func TestExtractObservationTerms_FilesModified(t *testing.T) { + t.Parallel() + + obs := &models.Observation{ + ID: 1, + Title: sql.NullString{String: "Code changes", Valid: true}, + FilesModified: models.JSONStringArray{"/src/handler.go", "/pkg/models/user.go"}, + } + + terms := ExtractObservationTerms(obs) + + // Should contain filenames from FilesModified + assert.Contains(t, terms, "handler.go") + assert.Contains(t, terms, "user.go") +} + +func TestAddTerms_ShortWords(t *testing.T) { + t.Parallel() + + terms := make(map[string]bool) + + addTerms(terms, "I am a go developer") + + // Short words (< 3 chars) should be excluded + assert.NotContains(t, terms, "i") + assert.NotContains(t, terms, "am") + assert.NotContains(t, terms, "a") + assert.NotContains(t, terms, "go") // Only 2 chars + + // "developer" should be included + assert.Contains(t, terms, "developer") +} + +func TestAddTerms_SpecialCharacters(t *testing.T) { + t.Parallel() + + terms := make(map[string]bool) + + addTerms(terms, "user_id authentication-flow JWT_token") + + // Hyphens split words, but underscores are kept as part of the word + // (underscore is included in the tokenization regex) + assert.Contains(t, terms, "user_id") + assert.Contains(t, terms, "authentication") + assert.Contains(t, terms, "flow") + assert.Contains(t, terms, "jwt_token") +} + +func TestJaccardSimilarity_SubsetSuperset(t *testing.T) { + t.Parallel() + + subset := map[string]bool{"a": true, "b": true} + superset := map[string]bool{"a": true, "b": true, "c": true, "d": true} + + // Subset similarity should be intersection/union = 2/4 = 0.5 + result := JaccardSimilarity(subset, superset) + assert.InDelta(t, 0.5, result, 0.001) +} + +func TestClusterObservations_HighThreshold(t *testing.T) { + t.Parallel() + + // With a very high threshold, almost nothing should be clustered + observations := []*models.Observation{ + {ID: 1, Title: sql.NullString{String: "authentication implementation", Valid: true}}, + {ID: 2, Title: sql.NullString{String: "authentication update", Valid: true}}, + {ID: 3, Title: sql.NullString{String: "authentication refactor", Valid: true}}, + } + + // With threshold of 0.9, even similar observations shouldn't cluster + clustered := ClusterObservations(observations, 0.9) + + assert.Len(t, clustered, 3, "High threshold should prevent clustering") +} + +func TestClusterObservations_LowThreshold(t *testing.T) { + t.Parallel() + + // With a very low threshold, more things should be clustered + observations := []*models.Observation{ + {ID: 1, Title: sql.NullString{String: "authentication implementation details", Valid: true}}, + {ID: 2, Title: sql.NullString{String: "authentication security update", Valid: true}}, + {ID: 3, Title: sql.NullString{String: "something completely different topic", Valid: true}}, + } + + // With threshold of 0.1, partial overlap should cluster + clustered := ClusterObservations(observations, 0.1) + + // First two share "authentication", should likely cluster + assert.LessOrEqual(t, len(clustered), 3) +}