Files
claude-mnemonic/internal/mcp/server_test.go
T
lukaszraczylo a81482d06a fix: address 15 additional hang vectors found during deep audit (#45)
MCP server (5 fixes):
- Move semaphore acquisition inside goroutine so main loop stays
  responsive when all slots are taken
- Add 10s write timeout to sendResponse to prevent pipe deadlock
  when Claude Code pauses reading stdout
- Send fallback JSON-RPC error when json.Marshal fails instead of
  silently swallowing the error and leaving caller waiting forever
- Silence unknown notification methods (req.ID == nil) instead of
  sending unsolicited error responses that may desync the host
- Return MCP isError content for tool failures instead of top-level
  JSON-RPC error, matching the MCP specification

Vector/embedding (3 fixes):
- Move EmbedBatchWithContext call before writeMu.Lock in AddDocuments
  so ONNX inference runs outside the write lock
- Replace singleflight.Do with DoChan + ctx select in both
  getOrComputeEmbedding and UnifiedSearch so callers can bail out
  independently when their context expires
- Add activeQueries atomic counter; skip cache warming when user
  queries are in-flight; reduce warming timeout from 5s to 2s

Hooks (4 fixes):
- Cap EnsureWorkerRunning to 15s hard deadline with context; reduce
  StartupTimeout from 30s to 10s; reduce port-in-use retries
- Fix nil dereference panic in user-prompt hook when initResult is
  nil (non-JSON worker response); use comma-ok assertions
- Use package-level hookClient/healthClient with DisableKeepAlives
  to prevent FD leaks in short-lived hook processes
- Set SysProcAttr{Setpgid: true} to detach worker from hook process
  group, preventing kill-cascade from Claude Code

Worker/DB (3 fixes):
- Replace os.Exit(0) in MCP config watcher with context cancellation
  for clean protocol shutdown
- Add 60s context.WithTimeout around ProcessObservation calls in
  processAllSessions to prevent hung CLI subprocesses from blocking
  the queue processor forever
- Set explicit PRAGMA wal_autocheckpoint=1000 and add PASSIVE WAL
  checkpoint to Optimize() to prevent checkpoint stalls

Adds 20+ regression tests across all fix areas.
2026-05-26 14:29:34 +01:00

3421 lines
86 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic.
package mcp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// timelineParams is used in tests for timeline request parsing.
//
// It is a test-local type that only defines JSON tags; presumably it mirrors
// the argument shape accepted by the timeline tool (confirm against the
// handler). Note the mixed key conventions: most keys are snake_case
// ("obs_type", "anchor_id") while the date-range keys are camelCase
// ("dateStart", "dateEnd").
type timelineParams struct {
	Query     string `json:"query"`
	Project   string `json:"project"`
	ObsType   string `json:"obs_type"`
	Concepts  string `json:"concepts"`
	Files     string `json:"files"`
	Format    string `json:"format"`
	AnchorID  int64  `json:"anchor_id"`
	Before    int    `json:"before"`
	After     int    `json:"after"`
	DateStart int64  `json:"dateStart"`
	DateEnd   int64  `json:"dateEnd"`
}
// =============================================================================
// TEST SUITE
// =============================================================================
// ServerSuite groups testify-suite tests for MCP Server operations.
type ServerSuite struct {
	suite.Suite
}

// TestServerSuite is the `go test` entry point that runs every ServerSuite
// method through testify's suite runner.
func TestServerSuite(t *testing.T) {
	suite.Run(t, &ServerSuite{})
}
// TestNewServer tests server creation.
//
// The constructor arguments must be stored verbatim, and a nil client must
// stay nil. The nil check on the server itself is fatal (Require) so a nil
// return fails the test cleanly instead of panicking on the field accesses
// that follow.
func (s *ServerSuite) TestNewServer() {
	server := NewServer(nil, "http://localhost:37777", "test-project", "1.0.0")
	s.Require().NotNil(server)
	s.Nil(server.client)
	s.Equal("1.0.0", server.version)
	s.Equal("http://localhost:37777", server.workerURL)
	s.Equal("test-project", server.project)
}
// =============================================================================
// TESTS FOR Request/Response Structs
// =============================================================================
// TestRequest tests Request struct JSON marshaling.
//
// Each case marshals a Request and compares it with the expected wire form.
// It then decodes those bytes back into a Request and re-encodes the result.
// Comparing the re-encoded JSON with the expected form confirms that ID and
// Params survive the round trip.
//
// ID is compared through the wire form rather than with assert.Equal. The
// table assigns ints, strings, floats and nil to it, so it is interface-typed,
// and encoding/json decodes JSON numbers into interfaces as float64.
func TestRequest(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		expected string
		req      Request
	}{
		// ===== GOOD CASES =====
		{
			name: "initialize request",
			req: Request{
				JSONRPC: "2.0",
				ID:      1,
				Method:  "initialize",
			},
			expected: `{"jsonrpc":"2.0","id":1,"method":"initialize"}`,
		},
		{
			name: "tools/list request",
			req: Request{
				JSONRPC: "2.0",
				ID:      "abc",
				Method:  "tools/list",
			},
			expected: `{"jsonrpc":"2.0","id":"abc","method":"tools/list"}`,
		},
		{
			name: "tools/call with params",
			req: Request{
				JSONRPC: "2.0",
				ID:      2,
				Method:  "tools/call",
				Params:  json.RawMessage(`{"name":"search","arguments":{}}`),
			},
			expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`,
		},
		// ===== EDGE CASES =====
		{
			name: "request with nil ID",
			req: Request{
				JSONRPC: "2.0",
				ID:      nil,
				Method:  "initialize",
			},
			expected: `{"jsonrpc":"2.0","id":null,"method":"initialize"}`,
		},
		{
			name: "request with float ID",
			req: Request{
				JSONRPC: "2.0",
				ID:      1.5,
				Method:  "test",
			},
			expected: `{"jsonrpc":"2.0","id":1.5,"method":"test"}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			data, err := json.Marshal(tt.req)
			require.NoError(t, err)
			assert.JSONEq(t, tt.expected, string(data))

			// Test unmarshaling.
			var parsed Request
			err = json.Unmarshal(data, &parsed)
			require.NoError(t, err)
			assert.Equal(t, tt.req.JSONRPC, parsed.JSONRPC)
			assert.Equal(t, tt.req.Method, parsed.Method)

			// Re-encode the decoded value to cover ID and Params as well.
			reencoded, err := json.Marshal(parsed)
			require.NoError(t, err)
			assert.JSONEq(t, tt.expected, string(reencoded))
		})
	}
}
// TestResponse checks the JSON wire form of Response values: a plain result,
// an error without data, an error carrying data, and a response whose ID is
// nil (encoded as null).
func TestResponse(t *testing.T) {
	t.Parallel()

	type responseCase struct {
		name     string
		resp     Response
		expected string
	}

	cases := []responseCase{
		// ===== GOOD CASES =====
		{
			name: "success response",
			resp: Response{
				JSONRPC: "2.0",
				ID:      1,
				Result:  map[string]string{"status": "ok"},
			},
			expected: `{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`,
		},
		{
			name: "error response",
			resp: Response{
				JSONRPC: "2.0",
				ID:      2,
				Error:   &Error{Code: -32600, Message: "Invalid Request"},
			},
			expected: `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}`,
		},
		{
			name: "error with data",
			resp: Response{
				JSONRPC: "2.0",
				ID:      3,
				Error:   &Error{Code: -32602, Message: "Invalid params", Data: "missing field"},
			},
			expected: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`,
		},
		// ===== EDGE CASES =====
		{
			name: "response with nil ID",
			resp: Response{
				JSONRPC: "2.0",
				ID:      nil,
				Result:  "ok",
			},
			expected: `{"jsonrpc":"2.0","id":null,"result":"ok"}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			encoded, marshalErr := json.Marshal(tc.resp)
			require.NoError(t, marshalErr)
			assert.JSONEq(t, tc.expected, string(encoded))
		})
	}
}
// TestError checks the JSON encoding of Error. The "data" key appears only
// when Data is set.
func TestError(t *testing.T) {
	t.Parallel()

	type errorCase struct {
		expected string
		name     string
		err      Error
	}

	cases := []errorCase{
		{
			name:     "parse error",
			err:      Error{Code: -32700, Message: "Parse error"},
			expected: `{"code":-32700,"message":"Parse error"}`,
		},
		{
			name:     "method not found",
			err:      Error{Code: -32601, Message: "Method not found"},
			expected: `{"code":-32601,"message":"Method not found"}`,
		},
		{
			name:     "invalid params",
			err:      Error{Code: -32602, Message: "Invalid params", Data: "details here"},
			expected: `{"code":-32602,"message":"Invalid params","data":"details here"}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			encoded, marshalErr := json.Marshal(tc.err)
			require.NoError(t, marshalErr)
			assert.JSONEq(t, tc.expected, string(encoded))
		})
	}
}
// TestToolCallParams tests ToolCallParams struct.
//
// Each case decodes a tools/call params payload and checks both the tool name
// and its arguments. Arguments are compared through json.Marshal and JSONEq,
// so the check does not depend on the exact byte layout of the raw JSON.
func TestToolCallParams(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		input    string
		expected ToolCallParams
	}{
		{
			name:  "search tool call",
			input: `{"name":"search","arguments":{"query":"test"}}`,
			expected: ToolCallParams{
				Name:      "search",
				Arguments: json.RawMessage(`{"query":"test"}`),
			},
		},
		{
			name:  "decisions tool call",
			input: `{"name":"decisions","arguments":{"query":"auth"}}`,
			expected: ToolCallParams{
				Name:      "decisions",
				Arguments: json.RawMessage(`{"query":"auth"}`),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var params ToolCallParams
			err := json.Unmarshal([]byte(tt.input), &params)
			require.NoError(t, err)
			assert.Equal(t, tt.expected.Name, params.Name)

			wantArgs, err := json.Marshal(tt.expected.Arguments)
			require.NoError(t, err)
			gotArgs, err := json.Marshal(params.Arguments)
			require.NoError(t, err)
			assert.JSONEq(t, string(wantArgs), string(gotArgs))
		})
	}
}
// TestTool tests Tool struct.
//
// A Tool is marshaled and decoded back. Name, Description and InputSchema must
// all survive the round trip. InputSchema is compared as JSON because its
// nested maps come back from decoding as generic map[string]any values.
func TestTool(t *testing.T) {
	t.Parallel()
	tool := Tool{
		Name:        "search",
		Description: "Search observations",
		InputSchema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"query": map[string]any{"type": "string"},
			},
		},
	}
	data, err := json.Marshal(tool)
	require.NoError(t, err)
	var parsed Tool
	err = json.Unmarshal(data, &parsed)
	require.NoError(t, err)
	assert.Equal(t, "search", parsed.Name)
	assert.Equal(t, "Search observations", parsed.Description)

	wantSchema, err := json.Marshal(tool.InputSchema)
	require.NoError(t, err)
	gotSchema, err := json.Marshal(parsed.InputSchema)
	require.NoError(t, err)
	assert.JSONEq(t, string(wantSchema), string(gotSchema))
}
// TestTimelineParamsStruct tests timelineParams struct.
//
// Each case decodes a JSON payload and compares the whole decoded struct with
// the expected value. Fields absent from the input must stay at their zero
// value, and every field set in the "full params" case must decode correctly.
// This includes the camelCase date keys and the snake_case obs_type key.
func TestTimelineParamsStruct(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		input    string
		expected timelineParams
	}{
		{
			name:  "with anchor_id",
			input: `{"anchor_id":123,"before":5,"after":5}`,
			expected: timelineParams{
				AnchorID: 123,
				Before:   5,
				After:    5,
			},
		},
		{
			name:  "with query",
			input: `{"query":"test query","project":"my-project"}`,
			expected: timelineParams{
				Query:   "test query",
				Project: "my-project",
			},
		},
		{
			name:  "full params",
			input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`,
			expected: timelineParams{
				AnchorID:  100,
				Query:     "search",
				Before:    10,
				After:     20,
				Project:   "proj",
				ObsType:   "bugfix",
				Concepts:  "security",
				Files:     "main.go",
				DateStart: 1234567890,
				DateEnd:   9876543210,
				Format:    "full",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var params timelineParams
			err := json.Unmarshal([]byte(tt.input), &params)
			require.NoError(t, err)
			assert.Equal(t, tt.expected, params)
		})
	}
}
// =============================================================================
// TESTS FOR Server Handlers
// =============================================================================
// TestHandleInitialize verifies the initialize handshake response. It must
// echo the request ID and carry no error. Its result must advertise protocol
// version "2024-11-05" plus serverInfo naming claude-mnemonic at the version
// the server was built with.
func TestHandleInitialize(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.2.3")
	resp := srv.handleInitialize(&Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "initialize",
	})

	assert.Equal(t, "2.0", resp.JSONRPC)
	assert.Equal(t, 1, resp.ID)
	assert.Nil(t, resp.Error)
	assert.NotNil(t, resp.Result)

	payload, isMap := resp.Result.(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "2024-11-05", payload["protocolVersion"])

	info, isMap := payload["serverInfo"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "claude-mnemonic", info["name"])
	assert.Equal(t, "1.2.3", info["version"])
}
// TestHandleToolsList verifies that tools/list returns a non-empty []Tool under
// the "tools" key, and that it includes every core tool name the MCP server is
// expected to expose.
func TestHandleToolsList(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	resp := srv.handleToolsList(&Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/list",
	})

	assert.Equal(t, "2.0", resp.JSONRPC)
	assert.Equal(t, 1, resp.ID)
	assert.Nil(t, resp.Error)

	payload, isMap := resp.Result.(map[string]any)
	require.True(t, isMap)
	listed, isTools := payload["tools"].([]Tool)
	require.True(t, isTools)
	assert.NotEmpty(t, listed)

	// Index the advertised names once, then probe for each required tool.
	present := make(map[string]bool, len(listed))
	for i := range listed {
		present[listed[i].Name] = true
	}
	for _, want := range []string{
		"search", "timeline", "decisions", "changes",
		"how_it_works", "find_by_concept", "find_by_file",
		"find_by_type", "get_recent_context", "get_context_timeline",
		"get_timeline_by_query",
	} {
		assert.True(t, present[want], "expected tool %s to be present", want)
	}
}
// TestHandleRequest checks method routing in handleRequest. Known methods
// produce a result with no error, and an unknown method yields JSON-RPC error
// -32601 "Method not found". Every response echoes the request ID and uses
// JSON-RPC "2.0".
func TestHandleRequest(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type routeCase struct {
		req          *Request
		name         string
		errorMessage string
		errorCode    int
		expectError  bool
	}

	cases := []routeCase{
		{
			name:        "initialize method",
			req:         &Request{JSONRPC: "2.0", ID: 1, Method: "initialize"},
			expectError: false,
		},
		{
			name:        "tools/list method",
			req:         &Request{JSONRPC: "2.0", ID: 2, Method: "tools/list"},
			expectError: false,
		},
		{
			name:         "unknown method",
			req:          &Request{JSONRPC: "2.0", ID: 3, Method: "unknown_method"},
			expectError:  true,
			errorCode:    -32601,
			errorMessage: "Method not found",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			resp := srv.handleRequest(ctx, tc.req)
			assert.Equal(t, "2.0", resp.JSONRPC)
			assert.Equal(t, tc.req.ID, resp.ID)

			if !tc.expectError {
				assert.Nil(t, resp.Error)
				assert.NotNil(t, resp.Result)
				return
			}
			require.NotNil(t, resp.Error)
			assert.Equal(t, tc.errorCode, resp.Error.Code)
			assert.Equal(t, tc.errorMessage, resp.Error.Message)
		})
	}
}
// TestHandleToolsCall_InvalidParams verifies the response to a tools/call
// request whose params are not valid JSON. It must be JSON-RPC error -32602
// "Invalid params".
func TestHandleToolsCall_InvalidParams(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	malformed := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/call",
		Params:  json.RawMessage(`invalid json`),
	}

	resp := srv.handleToolsCall(context.Background(), malformed)
	require.NotNil(t, resp.Error)
	assert.Equal(t, -32602, resp.Error.Code)
	assert.Equal(t, "Invalid params", resp.Error.Message)
}
// TestCallTool_UnknownTool verifies that callTool rejects a tool name it does
// not recognise with an "unknown tool" error.
func TestCallTool_UnknownTool(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	_, callErr := srv.callTool(context.Background(), "nonexistent_tool", json.RawMessage(`{}`))

	require.Error(t, callErr)
	assert.Contains(t, callErr.Error(), "unknown tool")
}
// TestCallTool_InvalidArgs calls a known tool with malformed arguments on a
// server with no worker client.
//
// Invalid JSON is parsed on a best-effort basis rather than rejected, so the
// call proceeds and fails with "worker unavailable".
func TestCallTool_InvalidArgs(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	_, callErr := srv.callTool(context.Background(), "search", json.RawMessage(`invalid json`))

	require.Error(t, callErr)
	assert.Contains(t, callErr.Error(), "worker unavailable")
}
// =============================================================================
// TESTS FOR Server I/O
// =============================================================================
// TestSendResponse verifies that sendResponse writes the encoded response
// (version, ID and result) to the server's stdout writer.
func TestSendResponse(t *testing.T) {
	t.Parallel()

	var sink bytes.Buffer
	srv := &Server{stdout: &sink}

	srv.sendResponse(&Response{
		JSONRPC: "2.0",
		ID:      1,
		Result:  map[string]string{"status": "ok"},
	})

	written := sink.String()
	for _, fragment := range []string{`"jsonrpc":"2.0"`, `"id":1`, `"result"`} {
		assert.Contains(t, written, fragment)
	}
}
// TestSendError verifies that sendError writes a JSON-RPC error object with
// the given code and message to the server's stdout writer.
func TestSendError(t *testing.T) {
	t.Parallel()

	var sink bytes.Buffer
	srv := &Server{stdout: &sink}

	srv.sendError(1, -32700, "Parse error", "details")

	written := sink.String()
	for _, fragment := range []string{`"error"`, `-32700`, `"Parse error"`} {
		assert.Contains(t, written, fragment)
	}
}
// TestRun_ParseError feeds a line of non-JSON to Run. Run must return nil once
// input is exhausted and must have answered with a -32700 "Parse error".
func TestRun_ParseError(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	srv := &Server{
		stdin:  strings.NewReader("invalid json\n"),
		stdout: &out,
	}

	require.NoError(t, srv.Run(context.Background()))

	written := out.String()
	for _, fragment := range []string{`"error"`, `-32700`, `"Parse error"`} {
		assert.Contains(t, written, fragment)
	}
}
// TestRun_EmptyLine verifies that Run ignores blank input lines entirely and
// writes nothing to stdout for them.
func TestRun_EmptyLine(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	srv := &Server{
		stdin:  strings.NewReader("\n\n"),
		stdout: &out,
	}

	require.NoError(t, srv.Run(context.Background()))
	assert.Empty(t, out.String())
}
// TestRun_ValidRequest sends a single initialize request through Run and
// expects a successful JSON-RPC result containing the protocol version.
func TestRun_ValidRequest(t *testing.T) {
	t.Parallel()

	const initReq = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`

	var out bytes.Buffer
	srv := &Server{
		stdin:   strings.NewReader(initReq + "\n"),
		stdout:  &out,
		version: "1.0.0",
	}

	require.NoError(t, srv.Run(context.Background()))

	written := out.String()
	for _, fragment := range []string{`"jsonrpc":"2.0"`, `"result"`, `"protocolVersion"`} {
		assert.Contains(t, written, fragment)
	}
}
// TestRun_MultipleRequests streams two newline-delimited requests into Run and
// checks that a response for each request ID appears in the output.
func TestRun_MultipleRequests(t *testing.T) {
	t.Parallel()

	input := strings.Join([]string{
		`{"jsonrpc":"2.0","id":1,"method":"initialize"}`,
		`{"jsonrpc":"2.0","id":2,"method":"tools/list"}`,
	}, "\n") + "\n"

	var out bytes.Buffer
	srv := &Server{
		stdin:   strings.NewReader(input),
		stdout:  &out,
		version: "1.0.0",
	}

	require.NoError(t, srv.Run(context.Background()))

	written := out.String()
	for _, fragment := range []string{`"id":1`, `"id":2`} {
		assert.Contains(t, written, fragment)
	}
}
// TestRunMixedRequests tests Run with mixed valid and invalid requests.
//
// A malformed line between two valid requests must not stop processing. Both
// valid requests are answered, and the malformed line gets a -32700 "Parse
// error" response. That response is pinned by code and message, not just by
// an "error" key, so an unrelated error cannot satisfy the check.
func TestRunMixedRequests(t *testing.T) {
	t.Parallel()
	var stdout bytes.Buffer
	req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
	req2 := `invalid json`
	req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`
	stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n")
	server := &Server{
		stdin:   stdin,
		stdout:  &stdout,
		version: "1.0.0",
	}
	err := server.Run(context.Background())
	require.NoError(t, err)
	output := stdout.String()
	// Should have responses for all three requests.
	assert.Contains(t, output, `"id":1`)
	assert.Contains(t, output, `"error"`)
	assert.Contains(t, output, `-32700`)
	assert.Contains(t, output, `"Parse error"`)
	assert.Contains(t, output, `"id":3`)
}
// =============================================================================
// TESTS FOR Handler Parameter Validation
// =============================================================================
// TestHandleFindRelatedObservations_Validation exercises argument handling in
// handleFindRelatedProxy on a server with no worker client. Malformed JSON and
// a missing id are rejected before any worker call. A well-formed request
// surfaces the "worker unavailable" error.
func TestHandleFindRelatedObservations_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "worker unavailable", args: `{"id": 1}`, wantErr: true, errContains: "worker unavailable"},
		{name: "missing id", args: `{}`, wantErr: true, errContains: "id is required"},
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleFindRelatedProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleFindSimilarObservations_Validation tests parameter validation.
//
// handleFindSimilarProxy rejects a missing query and malformed JSON itself. A
// request with a valid query is forwarded to the worker call. On a server with
// no worker client, that call fails with "worker unavailable".
func TestHandleFindSimilarObservations_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing query",
			args:        `{}`,
			wantErr:     true,
			errContains: "query is required",
		},
		{
			name:        "invalid json",
			args:        `{invalid`,
			wantErr:     true,
			errContains: "invalid arguments",
		},
		{
			// The proxy has no local vector client; a valid query goes
			// straight to the worker.
			name:        "valid query forwarded to worker",
			args:        `{"query": "test"}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleFindSimilarProxy(ctx, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleGetPatterns_Validation checks that handleGetPatternsProxy reports
// malformed JSON arguments as "invalid arguments".
func TestHandleGetPatterns_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetPatternsProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleBulkDeleteObservations_Validation tests parameter validation.
//
// handleBulkStatusProxy with action "delete" rejects missing or empty ids and
// malformed JSON itself. An oversized id list (1002 ids) is not rejected
// locally. It is forwarded to the worker call, which fails with "worker
// unavailable" because the server has no worker client.
func TestHandleBulkDeleteObservations_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing ids",
			args:        `{}`,
			wantErr:     true,
			errContains: "ids is required",
		},
		{
			name:        "empty ids array",
			args:        `{"ids": []}`,
			wantErr:     true,
			errContains: "ids is required",
		},
		{
			// Any size limit is left to the worker; the proxy forwards.
			name:        "too many ids forwarded to worker",
			args:        `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "invalid json",
			args:        `{invalid`,
			wantErr:     true,
			errContains: "invalid arguments",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "delete")
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleBulkMarkSuperseded_Validation tests parameter validation.
//
// handleBulkStatusProxy with action "supersede" rejects missing or empty ids
// itself. An oversized id list (1002 ids) is not rejected locally. It is
// forwarded to the worker call, which fails with "worker unavailable" because
// the server has no worker client.
func TestHandleBulkMarkSuperseded_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing ids",
			args:        `{}`,
			wantErr:     true,
			errContains: "ids is required",
		},
		{
			name:        "empty ids array",
			args:        `{"ids": []}`,
			wantErr:     true,
			errContains: "ids is required",
		},
		{
			// Any size limit is left to the worker; the proxy forwards.
			name:        "too many ids forwarded to worker",
			args:        `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "supersede")
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleBulkBoostObservations_Validation tests parameter validation.
//
// handleBulkStatusProxy with action "boost" rejects missing ids itself. Boost
// values outside [-1, 1] and oversized id lists are not rejected locally. They
// are forwarded to the worker call, which fails with "worker unavailable"
// because the server has no worker client.
func TestHandleBulkBoostObservations_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing ids",
			args:        `{"boost": 0.1}`,
			wantErr:     true,
			errContains: "ids is required",
		},
		{
			// Range checks are left to the worker; the proxy forwards.
			name:        "boost below range forwarded to worker",
			args:        `{"ids": [1], "boost": -1.5}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "boost above range forwarded to worker",
			args:        `{"ids": [1], "boost": 1.5}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "too many ids forwarded to worker",
			args:        `{"ids": [` + strings.Repeat("1,", 1001) + `1], "boost": 0.1}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "boost")
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleTriggerMaintenance_Validation checks that POSTing to the scoring
// recalculation endpoint fails with "worker unavailable" when the server was
// built without a worker client.
func TestHandleTriggerMaintenance_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	_, postErr := srv.proxyPostRaw(context.Background(), "/api/scoring/recalculate", nil)

	require.Error(t, postErr)
	assert.Contains(t, postErr.Error(), "worker unavailable")
}
// TestHandleGetMaintenanceStats_Validation checks that a GET of /api/stats
// fails with "worker unavailable" when the server was built without a worker
// client.
func TestHandleGetMaintenanceStats_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	query := map[string]string{"project": ""}
	_, getErr := srv.proxyGetRaw(context.Background(), "/api/stats", query)

	require.Error(t, getErr)
	assert.Contains(t, getErr.Error(), "worker unavailable")
}
// TestHandleMergeObservations_Validation tests parameter validation.
//
// handleMergeProxy rejects requests missing source_id or target_id itself. A
// request whose source and target are equal is not rejected locally. It is
// forwarded to the worker call, as is a normal merge. On a server with no
// worker client, both fail with "worker unavailable".
func TestHandleMergeObservations_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing source_id",
			args:        `{"target_id": 2}`,
			wantErr:     true,
			errContains: "source_id and target_id are required",
		},
		{
			name:        "missing target_id",
			args:        `{"source_id": 1}`,
			wantErr:     true,
			errContains: "source_id and target_id are required",
		},
		{
			// The source == target check is left to the worker.
			name:        "same source and target forwarded to worker",
			args:        `{"source_id": 1, "target_id": 1}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "worker unavailable with boost",
			args:        `{"source_id": 1, "target_id": 2, "boost": 0.6}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleMergeProxy(ctx, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleGetObservation_Validation checks that handleGetObservationProxy
// rejects a request with no id and one with malformed JSON, each with its own
// error message.
func TestHandleGetObservation_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing id", args: `{}`, wantErr: true, errContains: "id is required"},
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetObservationProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleEditObservation_Validation tests parameter validation.
//
// handleEditObservationProxy rejects a missing id and malformed JSON itself.
// An unrecognised scope value is not rejected locally. It is forwarded to the
// worker call, which fails with "worker unavailable" because the server has no
// worker client.
func TestHandleEditObservation_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing id",
			args:        `{"title": "new title"}`,
			wantErr:     true,
			errContains: "id is required",
		},
		{
			// Scope validation is left to the worker; the proxy forwards.
			name:        "invalid scope forwarded to worker",
			args:        `{"id": 1, "scope": "invalid"}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "invalid json",
			args:        `{invalid`,
			wantErr:     true,
			errContains: "invalid arguments",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleEditObservationProxy(ctx, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleGetObservationQuality_Validation checks that
// handleGetObservationQualityProxy rejects a request with no id and one with
// malformed JSON.
func TestHandleGetObservationQuality_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing id", args: `{}`, wantErr: true, errContains: "id is required"},
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetObservationQualityProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleSuggestConsolidations_Validation tests parameter validation.
//
// handleSuggestConsolidationsProxy rejects malformed JSON itself. A
// min_similarity outside the expected range is not rejected locally. It is
// forwarded to the worker call, which fails with "worker unavailable" because
// the server has no worker client.
func TestHandleSuggestConsolidations_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			// Range checks are left to the worker; the proxy forwards.
			name:        "low min_similarity forwarded to worker",
			args:        `{"min_similarity": 0.3}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "high min_similarity forwarded to worker",
			args:        `{"min_similarity": 1.5}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
		{
			name:        "invalid json",
			args:        `{invalid`,
			wantErr:     true,
			errContains: "invalid arguments",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleTagObservation_Validation checks the local argument checks of
// handleTagObservationProxy. It requires both id and tags, and restricts mode
// to add, remove or set.
func TestHandleTagObservation_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing id", args: `{"tags": ["tag1"]}`, wantErr: true, errContains: "id is required"},
		{name: "missing tags", args: `{"id": 1}`, wantErr: true, errContains: "tags is required"},
		{
			name:        "invalid mode",
			args:        `{"id": 1, "tags": ["tag1"], "mode": "invalid"}`,
			wantErr:     true,
			errContains: "mode must be 'add', 'remove', or 'set'",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleTagObservationProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleGetObservationsByTag_Validation checks that
// handleGetObservationsByTagProxy rejects a request with no tag and one with
// malformed JSON.
func TestHandleGetObservationsByTag_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing tag", args: `{}`, wantErr: true, errContains: "tag is required"},
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetObservationsByTagProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleBatchTagByPattern_Validation checks that handleBatchTagProxy
// requires both a pattern and a tags list.
func TestHandleBatchTagByPattern_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing pattern", args: `{"tags": ["tag1"]}`, wantErr: true, errContains: "pattern is required"},
		{name: "missing tags", args: `{"pattern": "test"}`, wantErr: true, errContains: "tags is required"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleBatchTagProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleExplainSearchRanking_Validation checks that
// handleExplainSearchProxy rejects a request with no query (even when top_n is
// given) and one with malformed JSON.
func TestHandleExplainSearchRanking_Validation(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	type validationCase struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}

	cases := []validationCase{
		{name: "missing query", args: `{"top_n": 5}`, wantErr: true, errContains: "query is required"},
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleExplainSearchProxy(ctx, json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleGetObservationRelationships_Validation checks argument validation in
// handleGetRelationshipsProxy (missing or non-positive id). It also checks that a
// well-formed request fails with "worker unavailable" when the server was built
// without a worker client.
func TestHandleGetObservationRelationships_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "missing id",
			args:        `{}`,
			wantErr:     true,
			errContains: "id is required",
		},
		{
			name:        "negative id",
			args:        `{"id": -1}`,
			wantErr:     true,
			errContains: "id is required and must be positive",
		},
		{
			// Valid arguments get past validation. The call then fails
			// because NewServer was given a nil client, so this case
			// exercises the worker path rather than any relation store.
			name:        "worker unavailable",
			args:        `{"id": 1}`,
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestHandleGetObservationScoringBreakdown_Validation checks that
// handleGetScoringBreakdownProxy rejects a missing or negative observation id.
func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "missing id", args: `{}`, wantErr: true, errContains: "id is required"},
		{name: "negative id", args: `{"id": -1}`, wantErr: true, errContains: "id is required and must be positive"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetScoringBreakdownProxy(context.Background(), json.RawMessage(tc.args))
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleTimeline_InvalidJSON verifies that malformed JSON sent to the
// timeline proxy is rejected with a parameter-parsing error.
func TestHandleTimeline_InvalidJSON(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	malformed := json.RawMessage(`{invalid`)
	_, err := srv.handleTimelineProxy(context.Background(), malformed)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid timeline params")
}
// TestHandleTimelineByQuery_EmptyQuery verifies that an empty argument object
// (no query, no anchor) succeeds and yields an empty observations list.
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	out, err := srv.handleTimelineProxy(context.Background(), json.RawMessage(`{}`))
	require.NoError(t, err)
	assert.Contains(t, out, `"observations":[]`)
}
// TestHandleTimelineByQuery_InvalidJSON verifies that the query-driven timeline
// path (which shares handleTimelineProxy) rejects malformed JSON.
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	_, err := srv.handleTimelineProxy(context.Background(), json.RawMessage(`{invalid`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid timeline params")
}
// TestHandleTimeline_NoAnchorNoQuery verifies that a timeline request with
// neither anchor_id nor query returns a non-empty JSON document holding an
// empty observations list.
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	out, err := srv.handleTimelineProxy(context.Background(), json.RawMessage(`{}`))
	require.NoError(t, err)
	assert.NotEmpty(t, out)
	assert.Contains(t, out, `"observations":[]`)
}
// TestHandleTimeline_WithDefaults verifies that an explicit anchor_id of 0 is
// treated like no anchor: the call succeeds with an empty observations list.
func TestHandleTimeline_WithDefaults(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	zeroAnchor := json.RawMessage(`{"anchor_id": 0}`)
	out, err := srv.handleTimelineProxy(context.Background(), zeroAnchor)
	require.NoError(t, err)
	assert.NotEmpty(t, out)
	assert.Contains(t, out, `"observations":[]`)
}
// =============================================================================
// TESTS FOR Additional Tool Operations
// =============================================================================
// TestJSONRPCErrorCodes checks that each standard JSON-RPC 2.0 error code and
// message survives a JSON marshal/unmarshal round-trip through the Error type.
// Reading fields straight back off a struct literal would prove nothing about
// what reaches the client.
func TestJSONRPCErrorCodes(t *testing.T) {
	t.Parallel()
	errorCodes := map[string]int{
		"Parse error":      -32700,
		"Invalid Request":  -32600,
		"Method not found": -32601,
		"Invalid params":   -32602,
		"Internal error":   -32603,
	}
	for msg, code := range errorCodes {
		t.Run(msg, func(t *testing.T) {
			t.Parallel()
			data, err := json.Marshal(Error{Code: code, Message: msg})
			require.NoError(t, err)
			var decoded Error
			require.NoError(t, json.Unmarshal(data, &decoded))
			assert.Equal(t, code, decoded.Code)
			assert.Equal(t, msg, decoded.Message)
		})
	}
}
// TestToolListContainsExpectedSchemas checks that tools/list returns at least
// one tool and that every tool has a name, a description and an input schema
// carrying a "type" key.
func TestToolListContainsExpectedSchemas(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	req := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/list",
	}
	resp := server.handleToolsList(req)
	// Comma-ok assertions make an unexpected result shape fail the test
	// instead of panicking it.
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "tools/list result should be map[string]any, got %T", resp.Result)
	tools, ok := result["tools"].([]Tool)
	require.True(t, ok, "tools/list result should hold []Tool, got %T", result["tools"])
	// Without this, an empty list would pass the loop below vacuously.
	require.NotEmpty(t, tools, "tools/list should return at least one tool")
	for _, tool := range tools {
		assert.NotEmpty(t, tool.Name)
		assert.NotEmpty(t, tool.Description)
		assert.NotNil(t, tool.InputSchema)
		_, hasType := tool.InputSchema["type"]
		assert.True(t, hasType, "tool %s schema should have type", tool.Name)
	}
}
// TestHandleToolsCall_UnknownTool checks that calling an unregistered tool is
// reported as a Result with isError=true whose first content item mentions
// "unknown tool", and not as a top-level JSON-RPC Error (MCP spec, Fix #5).
func TestHandleToolsCall_UnknownTool(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	req := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/call",
		Params:  json.RawMessage(`{"name":"unknown_tool","arguments":{}}`),
	}
	resp := server.handleToolsCall(ctx, req)
	assert.Nil(t, resp.Error, "tool errors must not use top-level Error (MCP spec)")
	require.NotNil(t, resp.Result)
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "result should be map[string]any, got %T", resp.Result)
	assert.Equal(t, true, result["isError"])
	// Guarded assertions: a malformed or empty content list must fail the
	// test, not panic on the type assertion or the [0] index.
	content, ok := result["content"].([]map[string]any)
	require.True(t, ok, "content should be []map[string]any, got %T", result["content"])
	require.NotEmpty(t, content, "isError result should carry at least one content item")
	text, ok := content[0]["text"].(string)
	require.True(t, ok, "content[0].text should be a string, got %T", content[0]["text"])
	assert.Contains(t, text, "unknown tool")
}
// TestCallTool_ToolNameRecognition checks that tools/list advertises every tool
// name the MCP server is expected to expose.
func TestCallTool_ToolNameRecognition(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	req := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/list",
	}
	resp := server.handleToolsList(req)
	// Comma-ok so an unexpected result shape fails rather than panics.
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "tools/list result should be map[string]any, got %T", resp.Result)
	tools, ok := result["tools"].([]Tool)
	require.True(t, ok, "tools/list result should hold []Tool, got %T", result["tools"])

	expectedTools := []string{
		"search",
		"timeline",
		"decisions",
		"changes",
		"how_it_works",
		"find_by_concept",
		"find_by_file",
		"find_by_type",
		"get_recent_context",
		"get_context_timeline",
		"get_timeline_by_query",
		"find_related_observations",
		"find_similar_observations",
		"get_patterns",
		"get_memory_stats",
		"bulk_delete_observations",
		"bulk_mark_superseded",
		"bulk_boost_observations",
		"trigger_maintenance",
		"get_maintenance_stats",
		"merge_observations",
		"get_observation",
		"edit_observation",
		"get_observation_quality",
		"suggest_consolidations",
		"tag_observation",
		"get_observations_by_tag",
		"get_temporal_trends",
		"get_data_quality_report",
		"batch_tag_by_pattern",
		"explain_search_ranking",
		"export_observations",
		"check_system_health",
		"analyze_search_patterns",
		"get_observation_relationships",
		"get_observation_scoring_breakdown",
		"analyze_observation_importance",
	}
	registered := make(map[string]struct{}, len(tools))
	for _, tool := range tools {
		registered[tool.Name] = struct{}{}
	}
	for _, name := range expectedTools {
		_, found := registered[name]
		assert.True(t, found, "tool %s should be registered", name)
	}
}
// TestTimelineParamsStruct_Complete decodes a payload that sets every
// timelineParams field and checks each field received its JSON value.
func TestTimelineParamsStruct_Complete(t *testing.T) {
	t.Parallel()
	const payload = `{
		"anchor_id": 100,
		"query": "test query",
		"before": 5,
		"after": 15,
		"project": "my-project",
		"obs_type": "bugfix",
		"concepts": "security,auth",
		"files": "main.go,handler.go",
		"dateStart": 1700000000000,
		"dateEnd": 1700100000000,
		"format": "full"
	}`
	var params timelineParams
	require.NoError(t, json.Unmarshal([]byte(payload), &params))

	for _, c := range []struct {
		field     string
		want, got any
	}{
		{"AnchorID", int64(100), params.AnchorID},
		{"Query", "test query", params.Query},
		{"Before", 5, params.Before},
		{"After", 15, params.After},
		{"Project", "my-project", params.Project},
		{"ObsType", "bugfix", params.ObsType},
		{"Concepts", "security,auth", params.Concepts},
		{"Files", "main.go,handler.go", params.Files},
		{"DateStart", int64(1700000000000), params.DateStart},
		{"DateEnd", int64(1700100000000), params.DateEnd},
		{"Format", "full", params.Format},
	} {
		assert.Equal(t, c.want, c.got, c.field)
	}
}
// TestServerStdinStdoutConfig confirms that a Server literal keeps the exact
// stdin/stdout values and version string it was constructed with.
func TestServerStdinStdoutConfig(t *testing.T) {
	t.Parallel()
	var in, out bytes.Buffer
	srv := &Server{version: "test-version", stdin: &in, stdout: &out}
	assert.Equal(t, &in, srv.stdin)
	assert.Equal(t, &out, srv.stdout)
	assert.Equal(t, "test-version", srv.version)
}
// TestResponseIDTypes checks that sendResponse serializes each JSON-RPC id kind
// (integer, string, float, null) verbatim. A reply whose id is dropped or
// re-encoded cannot be matched to its request by the client.
func TestResponseIDTypes(t *testing.T) {
	t.Parallel()
	tests := []struct {
		id     any
		name   string
		wantID string
	}{
		{name: "integer id", id: 1, wantID: `"id":1`},
		{name: "string id", id: "abc-123", wantID: `"id":"abc-123"`},
		{name: "float id", id: 1.5, wantID: `"id":1.5`},
		{name: "null id", id: nil, wantID: `"id":null`},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			server := &Server{stdout: &buf}
			resp := &Response{
				JSONRPC: "2.0",
				ID:      tt.id,
				Result:  "ok",
			}
			server.sendResponse(resp)
			output := buf.String()
			assert.Contains(t, output, `"jsonrpc":"2.0"`)
			assert.Contains(t, output, tt.wantID)
		})
	}
}
// TestServerFields checks what NewServer records. It should keep the version
// and worker URL it was given and leave client nil when nil is passed. It
// should also populate non-nil stdin/stdout.
func TestServerFields(t *testing.T) {
	t.Parallel()
	const workerURL = "http://localhost:37777"
	srv := NewServer(nil, workerURL, "test", "2.0.0")
	assert.Nil(t, srv.client)
	assert.Equal(t, workerURL, srv.workerURL)
	assert.Equal(t, "2.0.0", srv.version)
	assert.NotNil(t, srv.stdin)
	assert.NotNil(t, srv.stdout)
}
// TestRequestUnmarshalWithNullID ensures an explicit JSON null id decodes to a
// nil Request.ID while the remaining fields are still populated.
func TestRequestUnmarshalWithNullID(t *testing.T) {
	t.Parallel()
	var req Request
	raw := []byte(`{"jsonrpc":"2.0","id":null,"method":"initialize"}`)
	require.NoError(t, json.Unmarshal(raw, &req))
	assert.Equal(t, "2.0", req.JSONRPC)
	assert.Equal(t, "initialize", req.Method)
	assert.Nil(t, req.ID)
}
// TestResponseWithNullError verifies that a successful Response serializes its
// result and omits the "error" key entirely when Error is nil.
func TestResponseWithNullError(t *testing.T) {
	t.Parallel()
	raw, err := json.Marshal(Response{JSONRPC: "2.0", ID: 1, Result: "success"})
	require.NoError(t, err)
	encoded := string(raw)
	assert.Contains(t, encoded, `"result":"success"`)
	assert.NotContains(t, encoded, `"error"`)
}
// TestErrorWithNilData verifies that an Error without Data marshals its code and
// message and leaves out the "data" key.
func TestErrorWithNilData(t *testing.T) {
	t.Parallel()
	raw, err := json.Marshal(Error{Code: -32600, Message: "Invalid Request"})
	require.NoError(t, err)
	out := string(raw)
	for _, want := range []string{`"code":-32600`, `"message":"Invalid Request"`} {
		assert.Contains(t, out, want)
	}
	assert.NotContains(t, out, `"data"`)
}
// TestToolInputSchema checks that every tool returned by tools/list has an input
// schema of type "object" with a "properties" key.
func TestToolInputSchema(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	req := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/list",
	}
	resp := server.handleToolsList(req)
	// Comma-ok so an unexpected result shape fails rather than panics.
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "tools/list result should be map[string]any, got %T", resp.Result)
	tools, ok := result["tools"].([]Tool)
	require.True(t, ok, "tools/list result should hold []Tool, got %T", result["tools"])
	// Guard against a vacuous pass when no tools are registered.
	require.NotEmpty(t, tools, "tools/list should return at least one tool")
	for _, tool := range tools {
		schema := tool.InputSchema
		schemaType, ok := schema["type"]
		assert.True(t, ok, "tool %s schema should have type", tool.Name)
		assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name)
		_, hasProperties := schema["properties"]
		assert.True(t, hasProperties, "tool %s should have properties", tool.Name)
	}
}
// TestCallTool_UnknownToolName checks that callTool rejects several unregistered
// tool names with an "unknown tool" error and an empty result.
func TestCallTool_UnknownToolName(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	unknownTools := []string{
		"invalid_tool",
		"nonexistent",
		"search_v2",
		"timeline_special",
	}
	for _, name := range unknownTools {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			result, err := server.callTool(ctx, name, json.RawMessage(`{}`))
			// require, not assert: if err were nil, err.Error() below
			// would panic instead of reporting a clean failure.
			require.Error(t, err)
			assert.Empty(t, result)
			assert.Contains(t, err.Error(), "unknown tool")
		})
	}
}
// TestTimelineParamsStruct_Validation checks which raw payloads decode into
// timelineParams: well-formed objects (including {}) succeed and malformed JSON
// fails.
func TestTimelineParamsStruct_Validation(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name    string
		payload string
		wantOK  bool
	}{
		{"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true},
		{"valid with query only", `{"query":"test query"}`, true},
		{"empty params", `{}`, true},
		{"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true},
		{"invalid json", `{invalid`, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var params timelineParams
			err := json.Unmarshal([]byte(tc.payload), &params)
			if !tc.wantOK {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestHandleToolsCall_EmptyParams checks that tools/call without a tool name is
// answered with an isError result rather than a top-level JSON-RPC Error
// (Fix #5).
func TestHandleToolsCall_EmptyParams(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	resp := srv.handleToolsCall(context.Background(), &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/call",
		Params:  json.RawMessage(`{}`),
	})
	// An empty name falls through to callTool's default (unknown tool) branch.
	assert.Nil(t, resp.Error, "tool errors must use isError in Result")
	require.NotNil(t, resp.Result)
	payload, ok := resp.Result.(map[string]any)
	require.True(t, ok)
	assert.Equal(t, true, payload["isError"])
}
// TestSendResponse_WithError verifies that sendResponse writes an error
// response including the "error" member and its numeric code.
func TestSendResponse_WithError(t *testing.T) {
	t.Parallel()
	var out bytes.Buffer
	srv := &Server{stdout: &out}
	srv.sendResponse(&Response{
		JSONRPC: "2.0",
		ID:      1,
		Error:   &Error{Code: -32600, Message: "Invalid Request"},
	})
	written := out.String()
	assert.Contains(t, written, `"error"`)
	assert.Contains(t, written, `-32600`)
}
// TestSendResponse_NilID verifies that sendResponse keeps a nil id on the wire
// as an explicit "id":null instead of dropping the key.
func TestSendResponse_NilID(t *testing.T) {
	t.Parallel()
	var out bytes.Buffer
	srv := &Server{stdout: &out}
	srv.sendResponse(&Response{
		JSONRPC: "2.0",
		ID:      nil,
		Result:  "notification response",
	})
	assert.Contains(t, out.String(), `"id":null`)
}
// TestToolCallParamsWithComplexArgs checks that ToolCallParams keeps the tool
// name and a non-empty arguments payload when arguments is a nested object.
func TestToolCallParamsWithComplexArgs(t *testing.T) {
	t.Parallel()
	const payload = `{
		"name": "search",
		"arguments": {
			"query": "authentication bug",
			"project": "my-project",
			"limit": 10,
			"type": "observations"
		}
	}`
	var params ToolCallParams
	require.NoError(t, json.Unmarshal([]byte(payload), &params))
	assert.Equal(t, "search", params.Name)
	assert.NotEmpty(t, params.Arguments)
}
// TestHandleToolsCall_UnknownToolNameError checks the envelope returned for an
// unknown tool. It must echo jsonrpc and id, carry no top-level Error, and mark
// the Result with isError=true (Fix #5).
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	call := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/call",
		Params:  json.RawMessage(`{"name":"very_unknown_tool_name","arguments":{}}`),
	}
	resp := srv.handleToolsCall(context.Background(), call)
	assert.Equal(t, "2.0", resp.JSONRPC)
	assert.Equal(t, 1, resp.ID)
	assert.Nil(t, resp.Error, "tool errors must use isError in Result, not top-level Error")
	require.NotNil(t, resp.Result)
	payload, ok := resp.Result.(map[string]any)
	require.True(t, ok)
	assert.Equal(t, true, payload["isError"])
}
// =============================================================================
// TESTS FOR Handler Defaults
// =============================================================================
// TestHandleTimeline_Defaults re-applies, locally, the fallback that turns
// non-positive before/after counts into 10 (mirroring the default assignment in
// handleTimeline). It calls no server code, so it only documents the expected
// defaults; it does not exercise the handler itself.
func TestHandleTimeline_Defaults(t *testing.T) {
	t.Parallel()
	var params timelineParams // AnchorID, Query, Before and After all zero
	defaultTo10 := func(v *int) {
		if *v <= 0 {
			*v = 10
		}
	}
	defaultTo10(&params.Before)
	defaultTo10(&params.After)
	assert.Equal(t, 10, params.Before)
	assert.Equal(t, 10, params.After)
}
// TestHandleGetTemporalTrends_Validation checks that
// handleGetTemporalTrendsProxy rejects malformed JSON with "invalid arguments".
func TestHandleGetTemporalTrends_Validation(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetTemporalTrendsProxy(context.Background(), json.RawMessage(tc.args))
			// Only failure expectations are asserted by this table.
			if !tc.wantErr {
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleGetDataQualityReport_Validation checks that handleGetDataQualityProxy
// rejects malformed JSON with "invalid arguments".
func TestHandleGetDataQualityReport_Validation(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleGetDataQualityProxy(context.Background(), json.RawMessage(tc.args))
			// Only failure expectations are asserted by this table.
			if !tc.wantErr {
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleExportObservations_Validation checks that handleExportProxy rejects
// malformed JSON with "invalid arguments".
func TestHandleExportObservations_Validation(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleExportProxy(context.Background(), json.RawMessage(tc.args))
			// Only failure expectations are asserted by this table.
			if !tc.wantErr {
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleAnalyzeSearchPatterns_Validation checks that a GET to
// /api/search/analytics fails with "worker unavailable" when the server was
// built with a nil client. That endpoint is presumably what
// analyze_search_patterns proxies to; confirm against callTool.
// proxyGetRaw is called with nil query params, so the table carries no
// argument payload. An args field here would imply validation that never
// happens.
func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name        string
		errContains string
		wantErr     bool
	}{
		{
			name:        "worker unavailable",
			wantErr:     true,
			errContains: "worker unavailable",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			_, err := server.proxyGetRaw(ctx, "/api/search/analytics", nil)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
			}
		})
	}
}
// TestHandleAnalyzeObservationImportance_Validation checks that
// handleAnalyzeImportanceProxy rejects malformed JSON with "invalid arguments".
func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleAnalyzeImportanceProxy(context.Background(), json.RawMessage(tc.args))
			// Only failure expectations are asserted by this table.
			if !tc.wantErr {
				return
			}
			require.Error(t, err)
			if tc.errContains != "" {
				assert.Contains(t, err.Error(), tc.errContains)
			}
		})
	}
}
// TestHandleGetMemoryStats_NilStores checks that the /api/stats proxy call
// reports "worker unavailable" when the server was built with a nil client.
func TestHandleGetMemoryStats_NilStores(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	query := map[string]string{"project": ""}
	_, err := srv.proxyGetRaw(context.Background(), "/api/stats", query)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// TestHandleCheckSystemHealth_NilStores checks that the /api/selfcheck proxy
// call reports "worker unavailable" when the server was built with a nil
// client.
func TestHandleCheckSystemHealth_NilStores(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	_, err := srv.proxyGetRaw(context.Background(), "/api/selfcheck", nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// =============================================================================
// COMPREHENSIVE callTool TESTS
// =============================================================================
// TestCallTool_AllSpecialTools drives callTool's dispatch for the non-search
// tools with payloads that must fail. Each payload is either malformed JSON,
// missing a required argument, or (for the stats/health tools) sent to a server
// with no worker client.
func TestCallTool_AllSpecialTools(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()
	tests := []struct {
		name     string
		toolName string
		args     string
		wantErr  bool
	}{
		// No required arguments: these fail only because the worker is
		// unavailable (nil client).
		{"get_memory_stats", "get_memory_stats", `{}`, true},
		{"check_system_health", "check_system_health", `{}`, true},
		// Malformed payloads or payloads missing a required argument.
		{"find_related_observations - invalid json", "find_related_observations", `{invalid`, true},
		{"find_related_observations - missing id", "find_related_observations", `{}`, true},
		{"find_similar_observations - invalid json", "find_similar_observations", `{invalid`, true},
		{"find_similar_observations - missing query", "find_similar_observations", `{}`, true},
		{"get_patterns - invalid json", "get_patterns", `{invalid`, true},
		{"bulk_delete_observations - invalid json", "bulk_delete_observations", `{invalid`, true},
		{"bulk_delete_observations - missing ids", "bulk_delete_observations", `{}`, true},
		{"bulk_mark_superseded - invalid json", "bulk_mark_superseded", `{invalid`, true},
		{"bulk_mark_superseded - missing ids", "bulk_mark_superseded", `{}`, true},
		{"bulk_boost_observations - invalid json", "bulk_boost_observations", `{invalid`, true},
		{"bulk_boost_observations - missing ids", "bulk_boost_observations", `{}`, true},
		{"merge_observations - invalid json", "merge_observations", `{invalid`, true},
		{"merge_observations - missing source_ids", "merge_observations", `{}`, true},
		{"get_observation - invalid json", "get_observation", `{invalid`, true},
		{"get_observation - missing id", "get_observation", `{}`, true},
		{"edit_observation - invalid json", "edit_observation", `{invalid`, true},
		{"edit_observation - missing id", "edit_observation", `{}`, true},
		{"get_observation_quality - invalid json", "get_observation_quality", `{invalid`, true},
		{"get_observation_quality - missing id", "get_observation_quality", `{}`, true},
		{"suggest_consolidations - invalid json", "suggest_consolidations", `{invalid`, true},
		{"tag_observation - invalid json", "tag_observation", `{invalid`, true},
		{"tag_observation - missing id", "tag_observation", `{}`, true},
		{"get_observations_by_tag - invalid json", "get_observations_by_tag", `{invalid`, true},
		{"get_observations_by_tag - missing tag", "get_observations_by_tag", `{}`, true},
		{"get_temporal_trends - invalid json", "get_temporal_trends", `{invalid`, true},
		{"get_data_quality_report - invalid json", "get_data_quality_report", `{invalid`, true},
		{"batch_tag_by_pattern - invalid json", "batch_tag_by_pattern", `{invalid`, true},
		{"batch_tag_by_pattern - missing pattern", "batch_tag_by_pattern", `{}`, true},
		{"explain_search_ranking - invalid json", "explain_search_ranking", `{invalid`, true},
		{"explain_search_ranking - missing query", "explain_search_ranking", `{}`, true},
		{"export_observations - invalid json", "export_observations", `{invalid`, true},
		{"analyze_search_patterns - invalid json", "analyze_search_patterns", `{invalid`, true},
		{"get_observation_relationships - invalid json", "get_observation_relationships", `{invalid`, true},
		{"get_observation_relationships - missing id", "get_observation_relationships", `{}`, true},
		{"get_observation_scoring_breakdown - invalid json", "get_observation_scoring_breakdown", `{invalid`, true},
		{"get_observation_scoring_breakdown - missing id", "get_observation_scoring_breakdown", `{}`, true},
		{"analyze_observation_importance - invalid json", "analyze_observation_importance", `{invalid`, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result, err := server.callTool(ctx, tt.toolName, json.RawMessage(tt.args))
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				assert.NotEmpty(t, result)
			}
		})
	}
}
// TestCallTool_SearchTools covers callTool's search-style tools. With no worker
// client the search tools must fail with "worker unavailable", and the timeline
// tools must reject malformed JSON.
func TestCallTool_SearchTools(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	for _, tool := range []string{
		"search",
		"decisions",
		"changes",
		"how_it_works",
		"find_by_concept",
		"find_by_file",
		"find_by_type",
		"get_recent_context",
	} {
		t.Run(tool+"_worker_unavailable", func(t *testing.T) {
			t.Parallel()
			_, err := srv.callTool(ctx, tool, json.RawMessage(`{"query":"test"}`))
			require.Error(t, err)
			assert.Contains(t, err.Error(), "worker unavailable")
		})
	}

	for _, tool := range []string{
		"timeline",
		"get_context_timeline",
		"get_timeline_by_query",
	} {
		t.Run(tool+"_invalid_json", func(t *testing.T) {
			t.Parallel()
			_, err := srv.callTool(ctx, tool, json.RawMessage(`{invalid`))
			require.Error(t, err)
			assert.Contains(t, err.Error(), "invalid")
		})
	}
}
// TestHandleTriggerMaintenance_NilService checks that a POST to
// /api/scoring/recalculate fails with "worker unavailable" when the server was
// built with a nil client.
func TestHandleTriggerMaintenance_NilService(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	_, err := srv.proxyPostRaw(context.Background(), "/api/scoring/recalculate", nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// TestHandleGetMaintenanceStats_NilService checks that the /api/stats proxy call
// used for maintenance stats fails with "worker unavailable" when the server was
// built with a nil client.
func TestHandleGetMaintenanceStats_NilService(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	query := map[string]string{"project": ""}
	_, err := srv.proxyGetRaw(context.Background(), "/api/stats", query)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// TestHandleTimeline_ParameterDefaultsNew checks that handleTimelineProxy fails
// on malformed JSON before any parameter defaults could be applied.
func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	bad := json.RawMessage(`{invalid`)
	_, err := srv.handleTimelineProxy(context.Background(), bad)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid timeline params")
}
// TestHandleTimelineByQuery_ValidationExtended checks that the query-driven
// timeline path (handleTimelineProxy) rejects malformed JSON.
func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid timeline params"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleTimelineProxy(context.Background(), json.RawMessage(tc.args))
			if !tc.wantErr {
				return
			}
			require.Error(t, err)
			assert.Contains(t, err.Error(), tc.errContains)
		})
	}
}
// TestHandleSuggestConsolidations_ValidationExtended checks that
// handleSuggestConsolidationsProxy rejects malformed JSON. Every case in this
// table is asserted to fail.
func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	cases := []struct {
		name        string
		args        string
		errContains string
		wantErr     bool
	}{
		{name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			_, err := srv.handleSuggestConsolidationsProxy(context.Background(), json.RawMessage(tc.args))
			require.Error(t, err)
			assert.Contains(t, err.Error(), tc.errContains)
		})
	}
}
// =============================================================================
// NIL GUARD HANDLER TESTS
// =============================================================================
// TestHandleFindSimilarObservations_NilVectorClient checks that a valid
// find-similar request fails with "worker unavailable" when the server was
// built with a nil client.
func TestHandleFindSimilarObservations_NilVectorClient(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	args := json.RawMessage(`{"query": "test query"}`)
	_, err := srv.handleFindSimilarProxy(context.Background(), args)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// TestHandleGetObservationRelationships_NilRelationStore checks that a valid
// relationships request fails with "worker unavailable" when the server was
// built with a nil client.
func TestHandleGetObservationRelationships_NilRelationStore(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	args := json.RawMessage(`{"id": 123}`)
	_, err := srv.handleGetRelationshipsProxy(context.Background(), args)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "worker unavailable")
}
// =============================================================================
// MORE PARAMETER VALIDATION TESTS
// =============================================================================
// TestHandleBulkBoostObservations_EmptyIDs checks that the bulk "boost" status
// action rejects an empty ids list with "ids is required".
func TestHandleBulkBoostObservations_EmptyIDs(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	noIDs := json.RawMessage(`{"ids": []}`)
	_, err := srv.handleBulkStatusProxy(context.Background(), noIDs, "boost")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "ids is required")
}
// TestHandleMergeObservations_MissingIDs checks that handleMergeProxy treats
// zero source_id/target_id as missing and reports both as required.
func TestHandleMergeObservations_MissingIDs(t *testing.T) {
	t.Parallel()
	srv := NewServer(nil, "", "", "1.0.0")
	zeroIDs := json.RawMessage(`{"source_id": 0, "target_id": 0}`)
	_, err := srv.handleMergeProxy(context.Background(), zeroIDs)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "source_id and target_id are required")
}
// TestHandleMergeObservations_WorkerUnavailable checks that a merge with valid
// IDs still fails when the server has no worker client to forward it to.
func TestHandleMergeObservations_WorkerUnavailable(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	args := json.RawMessage(`{"source_id": 1, "target_id": 2}`)

	// Validation passes, but the nil client means the call cannot succeed.
	_, callErr := srv.handleMergeProxy(context.Background(), args)
	require.Error(t, callErr)
}
// =============================================================================
// REGRESSION TESTS — Fix #45: Concurrent request dispatching
// =============================================================================
// collectResponses reads newline-delimited JSON objects from r and returns the
// decoded objects once n have been collected or r stops yielding lines (EOF or
// a read error), whichever happens first. It takes no context: callers bound it
// by closing the write side of r.
//
// Blank lines are skipped; lines that fail to decode into a map are logged and
// skipped. A non-EOF read error (e.g. bufio.ErrTooLong for an oversized line) is
// logged so a short result can be diagnosed. A non-positive n returns an empty
// slice without reading. It reports only through t.Logf, so it is safe to run
// in a helper goroutine.
func collectResponses(t *testing.T, r io.Reader, n int) []map[string]any {
	t.Helper()
	results := make([]map[string]any, 0, max(n, 0))
	scanner := bufio.NewScanner(r)
	// Check the count before scanning so we never read past the n-th response
	// (and never read at all when n <= 0).
	for len(results) < n && scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			continue
		}
		var resp map[string]any
		if err := json.Unmarshal([]byte(line), &resp); err != nil {
			t.Logf("collectResponses: bad JSON line: %s", line)
			continue
		}
		results = append(results, resp)
	}
	if err := scanner.Err(); err != nil {
		t.Logf("collectResponses: read error after %d/%d responses: %v", len(results), n, err)
	}
	return results
}
// writeRequests writes each entry of reqs to w as a newline-terminated line and
// always closes w afterwards, so the reading side observes EOF.
//
// Callers run it in its own goroutine (go writeRequests(...)). It therefore
// reports failures with t.Errorf, which is safe from any goroutine, instead of
// require/t.FailNow, which must only be called from the test goroutine. On a
// write error it stops early; the deferred Close still runs, so a reader blocked
// on the other end of w is released instead of hanging the test.
func writeRequests(t *testing.T, w io.WriteCloser, reqs []string) {
	t.Helper()
	defer func() { _ = w.Close() }()
	for i, r := range reqs {
		if _, err := io.WriteString(w, r+"\n"); err != nil {
			t.Errorf("writeRequests: writing request %d: %v", i, err)
			return
		}
	}
}
// TestRun_ConcurrentRequests verifies that Run handles requests concurrently
// rather than serially. Every mock worker call sleeps for delay, so serial
// handling would take numRequests × delay (5 × 100ms = 500ms); the batch must
// finish in under 80% of that.
func TestRun_ConcurrentRequests(t *testing.T) {
	t.Parallel()
	const delay = 100 * time.Millisecond
	const numRequests = 5

	// Mock worker: every request sleeps delay then returns "{}".
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(delay)
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{}`)
	}))
	defer ts.Close()

	stdinR, stdinW := io.Pipe()
	stdoutR, stdoutW := io.Pipe()
	server := &Server{
		client:    ts.Client(),
		workerURL: ts.URL,
		project:   "test",
		version:   "1.0.0",
		stdin:     stdinR,
		stdout:    stdoutW,
	}

	// Build requests — use get_memory_stats which goes to GET /api/stats.
	reqs := make([]string, numRequests)
	for i := range reqs {
		req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
			Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
		data, err := json.Marshal(req)
		require.NoError(t, err)
		reqs[i] = string(data)
	}

	// Collect responses in background.
	var responses []map[string]any
	var wg sync.WaitGroup
	wg.Go(func() {
		responses = collectResponses(t, stdoutR, numRequests)
		_ = stdoutR.Close()
	})

	start := time.Now()
	// Write all requests then close stdin (triggers Run to drain and return).
	go writeRequests(t, stdinW, reqs)
	err := server.Run(context.Background())
	elapsed := time.Since(start)
	// Close stdout before checking err: the collector then always sees EOF and
	// exits, even when Run failed before producing every response.
	_ = stdoutW.Close()
	wg.Wait()
	require.NoError(t, err)

	// All responses received.
	assert.Len(t, responses, numRequests, "expected %d responses", numRequests)

	// Concurrent execution: should be much less than numRequests × delay.
	serialTime := time.Duration(numRequests) * delay
	assert.Less(t, elapsed, serialTime*4/5,
		"elapsed %v not significantly less than serial %v — requests may be sequential", elapsed, serialTime)

	// Each response has correct jsonrpc field.
	for _, resp := range responses {
		assert.Equal(t, "2.0", resp["jsonrpc"])
	}
}
// TestRun_SlowRequestDoesNotBlockOthers is the core regression for #45: a slow
// search request must not prevent a later, fast stats request from being
// answered first.
//
// Server output goes through an interceptor goroutine. It records response IDs
// in arrival order, then forwards each line to the collector.
func TestRun_SlowRequestDoesNotBlockOthers(t *testing.T) {
	t.Parallel()
	const slowDelay = 300 * time.Millisecond

	// responseOrder records which request IDs responded, in arrival order.
	var mu sync.Mutex
	var responseOrder []any

	// Mock worker: /api/context/search is slow, everything else is fast.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/context/search") {
			time.Sleep(slowDelay)
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{}`)
	}))
	defer ts.Close()

	stdinR, stdinW := io.Pipe()
	stdoutR, stdoutW := io.Pipe()

	// Intercept stdout to record response order before passing through. The
	// interceptor runs until pw is closed (after Run returns). It then closes
	// stdoutW so the collector sees EOF.
	pr, pw := io.Pipe()
	interceptDone := make(chan struct{})
	go func() {
		defer close(interceptDone)
		defer func() { _ = stdoutW.Close() }()
		scanner := bufio.NewScanner(pr)
		for scanner.Scan() {
			line := scanner.Text()
			var resp map[string]any
			if err := json.Unmarshal([]byte(line), &resp); err == nil {
				mu.Lock()
				responseOrder = append(responseOrder, resp["id"])
				mu.Unlock()
			}
			// A write error only means the collector already has what it needs
			// and closed stdoutR. Keep draining pr so the server never blocks.
			_, _ = io.WriteString(stdoutW, line+"\n")
		}
	}()

	server := &Server{
		client:    ts.Client(),
		workerURL: ts.URL,
		project:   "test",
		version:   "1.0.0",
		stdin:     stdinR,
		stdout:    pw,
	}

	// Request 1: slow search (id=1).
	slowReq := Request{JSONRPC: "2.0", ID: 1, Method: "tools/call",
		Params: json.RawMessage(`{"name":"search","arguments":{"query":"anything"}}`)}
	slowData, err := json.Marshal(slowReq)
	require.NoError(t, err)

	// Request 2: fast stats (id=2).
	fastReq := Request{JSONRPC: "2.0", ID: 2, Method: "tools/call",
		Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
	fastData, err := json.Marshal(fastReq)
	require.NoError(t, err)

	// Collect 2 responses.
	var responses []map[string]any
	var wg sync.WaitGroup
	wg.Go(func() {
		responses = collectResponses(t, stdoutR, 2)
		_ = stdoutR.Close()
	})

	// Write slow then fast, then close.
	go func() {
		_, _ = io.WriteString(stdinW, string(slowData)+"\n")
		// Small pause to ensure slow request goroutine is dispatched first.
		time.Sleep(10 * time.Millisecond)
		_, _ = io.WriteString(stdinW, string(fastData)+"\n")
		_ = stdinW.Close()
	}()

	runErr := server.Run(context.Background())
	// Run has returned, so the server writes nothing more to pw. Closing pw ends
	// the interceptor, which closes stdoutW. Neither helper goroutine can then
	// leak or hang, even if Run failed before producing both responses.
	_ = pw.Close()
	wg.Wait()
	<-interceptDone
	require.NoError(t, runErr)

	require.Len(t, responses, 2, "expected 2 responses")

	mu.Lock()
	order := responseOrder
	mu.Unlock()
	require.Len(t, order, 2, "expected 2 recorded response IDs")

	// The fast request (id=2) must arrive before the slow one (id=1).
	assert.Equal(t, float64(2), order[0],
		"fast request (id=2) should respond before slow request (id=1); got order %v", order)
	assert.Equal(t, float64(1), order[1],
		"slow request (id=1) should respond second; got order %v", order)
}
// TestRun_SemaphoreLimitsConcurrency sends more requests than the semaphore cap
// (10) and verifies three things. All requests eventually complete, so there is
// no deadlock. No more than the cap ever reach the worker at the same time. The
// batch still finishes far faster than serial processing would.
func TestRun_SemaphoreLimitsConcurrency(t *testing.T) {
	t.Parallel()
	const blockDelay = 200 * time.Millisecond
	const numRequests = 15 // exceeds semaphore cap of 10
	const semaphoreCap = 10

	// inFlight/peak track how many worker calls overlap. The counter is
	// decremented BEFORE the response is written. Otherwise a finished call
	// could still be counted when its released slot admits the next request.
	var (
		countMu  sync.Mutex
		inFlight int
		peak     int
	)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		countMu.Lock()
		inFlight++
		peak = max(peak, inFlight)
		countMu.Unlock()

		time.Sleep(blockDelay)

		countMu.Lock()
		inFlight--
		countMu.Unlock()

		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{}`)
	}))
	defer ts.Close()

	stdinR, stdinW := io.Pipe()
	stdoutR, stdoutW := io.Pipe()
	server := &Server{
		client:    ts.Client(),
		workerURL: ts.URL,
		project:   "test",
		version:   "1.0.0",
		stdin:     stdinR,
		stdout:    stdoutW,
	}

	reqs := make([]string, numRequests)
	for i := range reqs {
		req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
			Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
		data, err := json.Marshal(req)
		require.NoError(t, err)
		reqs[i] = string(data)
	}

	var responses []map[string]any
	var wg sync.WaitGroup
	wg.Go(func() {
		responses = collectResponses(t, stdoutR, numRequests)
		_ = stdoutR.Close()
	})

	start := time.Now()
	go writeRequests(t, stdinW, reqs)
	err := server.Run(context.Background())
	elapsed := time.Since(start)
	// Close stdout before checking err so the collector always exits.
	_ = stdoutW.Close()
	wg.Wait()
	require.NoError(t, err)

	// All 15 responses received — no deadlock.
	assert.Len(t, responses, numRequests, "all %d requests must complete", numRequests)

	// The semaphore must actually bound how many worker calls run at once.
	countMu.Lock()
	observedPeak := peak
	countMu.Unlock()
	assert.LessOrEqual(t, observedPeak, semaphoreCap,
		"at most %d worker calls may run concurrently; saw %d", semaphoreCap, observedPeak)

	// With cap 10 and 15 requests at 200ms each, at least 2 batches are needed,
	// so ~2×blockDelay is expected rather than 15×blockDelay. The bound is loose
	// on purpose: 3 batches × a 2× scheduling-overhead factor = 6×blockDelay.
	upperBound := 6 * blockDelay
	assert.Less(t, elapsed, upperBound,
		"elapsed %v suggests sequential processing (15×%v = %v)", elapsed, blockDelay, time.Duration(numRequests)*blockDelay)
}
// TestRun_GracefulDrainOnCancel verifies that cancelling the context causes Run to
// drain in-flight requests (wg.Wait) before returning ctx.Canceled.
//
// Flow: three get_memory_stats requests are written to stdin. The context is
// cancelled ~50ms later, while the 200ms mock worker calls are still running.
// Stdin is deliberately left open, so Run cannot exit via EOF. It must return
// because of the cancellation.
//
// NOTE(review): the drain itself is not observed directly. The test checks only
// that Run returns an error wrapping context.Canceled, and that every line that
// did reach stdout is JSON-RPC 2.0. How many responses arrive is
// timing-dependent and intentionally not asserted.
func TestRun_GracefulDrainOnCancel(t *testing.T) {
t.Parallel()
const reqDelay = 200 * time.Millisecond
const numRequests = 3
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use a fixed sleep; the request-level context will be cancelled but the
// HTTP handler runs to completion independently (server-side).
time.Sleep(reqDelay)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{}`)
}))
defer ts.Close()
stdinR, stdinW := io.Pipe()
stdoutR, stdoutW := io.Pipe()
server := &Server{
client: ts.Client(),
workerURL: ts.URL,
project: "test",
version: "1.0.0",
stdin: stdinR,
stdout: stdoutW,
}
// ctx is cancelled by the writer goroutine below; Run receives it directly.
ctx, cancel := context.WithCancel(context.Background())
// Build requests
reqs := make([]string, numRequests)
for i := range reqs {
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
data, err := json.Marshal(req)
require.NoError(t, err)
reqs[i] = string(data)
}
// Write all requests before cancelling so they're all dispatched as goroutines
go func() {
for _, r := range reqs {
_, _ = io.WriteString(stdinW, r+"\n")
}
// Cancel after requests are dispatched but while they're still in-flight
// (50ms is well inside the 200ms worker delay).
time.Sleep(50 * time.Millisecond)
cancel()
// Leave stdin open — Run should return from the ctx.Done branch
}()
// Drain responses in background; we don't know exactly how many will complete
// because goroutines may get context cancelled on their HTTP calls too.
// The collector runs until stdoutW is closed after Run returns.
var responseMu sync.Mutex
var collectedResponses []map[string]any
var collectWg sync.WaitGroup
collectWg.Go(func() {
scanner := bufio.NewScanner(stdoutR)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
var resp map[string]any
if err := json.Unmarshal([]byte(line), &resp); err == nil {
responseMu.Lock()
collectedResponses = append(collectedResponses, resp)
responseMu.Unlock()
}
}
})
runErr := server.Run(ctx)
// Only now close the pipes: stdoutW ends the collector; stdinW releases
// the stdin reader that was deliberately left open.
_ = stdoutW.Close()
_ = stdinW.Close()
collectWg.Wait()
// Core assertion: Run returned the context cancellation error
assert.ErrorIs(t, runErr, context.Canceled,
"Run must return context.Canceled when context is cancelled")
// Run returned only after wg.Wait() drained goroutines.
// The goroutines may have returned errors (ctx cancelled HTTP calls) but
// the key invariant is Run itself did not panic and returned cleanly.
// Any responses that did complete should be valid JSON-RPC.
responseMu.Lock()
defer responseMu.Unlock()
for _, resp := range collectedResponses {
assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0")
}
}
// =============================================================================
// REGRESSION TESTS — Fix #1-#5
// =============================================================================
// blockingWriter is an io.Writer whose Write never returns. It simulates a
// stdout pipe whose reader has stopped consuming.
//
// If blocked is non-nil, it is closed the first time any Write is entered, so
// a test can wait until a writer is stuck. The close is guarded by once, so
// later or concurrent Write calls (e.g. a retry or fallback write) block as
// well instead of panicking on a double close.
type blockingWriter struct {
	blocked chan struct{} // closed when Write is first entered
	once    sync.Once     // guards the close of blocked
}

// Write signals entry via blocked (once) and then blocks forever.
func (bw *blockingWriter) Write(p []byte) (int, error) {
	if bw.blocked != nil {
		bw.once.Do(func() { close(bw.blocked) })
	}
	select {} // block forever
}
// TestRun_SemaphoreDoesNotBlockMainLoop (Fix #1 regression) fills every
// semaphore slot (plus two extra requests) with requests whose worker calls
// block, then sends a notification. The main loop must keep reading stdin
// instead of hanging while it waits for a free slot.
func TestRun_SemaphoreDoesNotBlockMainLoop(t *testing.T) {
	t.Parallel()
	const maxConcurrent = 10

	// Mock worker that blocks until handlerCtx is cancelled.
	handlerCtx, handlerCancel := context.WithCancel(context.Background())
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		<-handlerCtx.Done() // block until released by the test
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{}`)
	}))
	defer func() {
		handlerCancel() // unblock all handlers first so ts.Close does not wait on them
		ts.CloseClientConnections()
		ts.Close()
	}()

	stdinR, stdinW := io.Pipe()
	var stdout bytes.Buffer
	server := &Server{
		client:    ts.Client(),
		workerURL: ts.URL,
		project:   "test",
		version:   "1.0.0",
		stdin:     stdinR,
		stdout:    &stdout,
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	runDone := make(chan error, 1)
	go func() {
		runDone <- server.Run(ctx)
	}()

	// Send maxConcurrent+2 requests to fill all semaphore slots.
	for i := range maxConcurrent + 2 {
		req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
			Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
		data, err := json.Marshal(req)
		require.NoError(t, err)
		_, err = io.WriteString(stdinW, string(data)+"\n")
		require.NoError(t, err)
	}

	// Give goroutines time to start and fill semaphore.
	time.Sleep(100 * time.Millisecond)

	// Now send a notification — this must not hang the main loop.
	notifSent := make(chan struct{})
	go func() {
		notif := `{"jsonrpc":"2.0","method":"notifications/initialized"}` + "\n"
		_, _ = io.WriteString(stdinW, notif)
		close(notifSent)
	}()

	select {
	case <-notifSent:
		// Main loop accepted the notification write (stdin is a pipe, so
		// the write completing means the reader consumed it).
	case <-time.After(3 * time.Second):
		t.Fatal("main loop blocked — semaphore acquisition is blocking the main goroutine")
	}

	// Clean up: cancel the server context and release the blocked worker
	// handlers, so in-flight requests can finish. Then close stdin so Run can
	// drain. Releasing the handlers here, not only in the deferred func, stops
	// Run from sitting on blocked requests until the timeout below.
	cancel()
	handlerCancel()
	_ = stdinW.Close()

	// Wait for Run to finish.
	select {
	case <-runDone:
	case <-time.After(5 * time.Second):
		// Acceptable — some goroutines may still be draining.
		t.Logf("Run did not return within 5s of cleanup; continuing")
	}
}
// TestSendResponse_WriteTimeout (Fix #2 regression) verifies that sendResponse
// returns an error within a bounded time when the writer blocks forever.
func TestSendResponse_WriteTimeout(t *testing.T) {
t.Parallel()
bw := &blockingWriter{blocked: make(chan struct{})}
server := &Server{stdout: bw}
resp := &Response{
JSONRPC: "2.0",
ID: 1,
Result: "ok",
}
done := make(chan error, 1)
go func() {
done <- server.sendResponse(resp)
}()
select {
case err := <-done:
require.Error(t, err, "sendResponse must return error on write timeout")
assert.Contains(t, err.Error(), "write timeout")
case <-time.After(15 * time.Second):
t.Fatal("sendResponse hung forever — write timeout not working")
}
}
// TestSendResponse_MarshalError (Fix #3 regression) checks the json.Marshal
// failure path of sendResponse. The function must report the error to its
// caller. It must also still write a well-formed JSON-RPC internal error
// (-32603) that carries the original request ID, so the peer is not left
// waiting.
func TestSendResponse_MarshalError(t *testing.T) {
	t.Parallel()

	var out bytes.Buffer
	srv := &Server{stdout: &out}

	// encoding/json cannot encode a channel, which forces Marshal to fail.
	sendErr := srv.sendResponse(&Response{
		JSONRPC: "2.0",
		ID:      42,
		Result:  make(chan int),
	})

	// The caller must learn about the failure...
	require.Error(t, sendErr, "sendResponse must return error when marshal fails")
	assert.Contains(t, sendErr.Error(), "marshal error")

	// ...and stdout must carry a fallback error response for the same ID.
	written := out.String()
	require.NotEmpty(t, written, "fallback response must be written to stdout")

	var fallback map[string]any
	decodeErr := json.Unmarshal([]byte(strings.TrimSpace(written)), &fallback)
	require.NoError(t, decodeErr, "fallback must be valid JSON")
	assert.Equal(t, "2.0", fallback["jsonrpc"])
	assert.Equal(t, float64(42), fallback["id"], "fallback must preserve original request ID")

	errObj, isObj := fallback["error"].(map[string]any)
	require.True(t, isObj, "fallback must contain error object")
	assert.Equal(t, float64(-32603), errObj["code"])
	assert.Equal(t, "internal marshal error", errObj["message"])
}
// TestHandleRequest_UnknownNotification (Fix #4 regression) checks how
// handleRequest treats unknown methods. Without an ID (a notification), the
// method is dropped with no response. With an ID, it still gets a JSON-RPC
// -32601 "Method not found" error echoing that ID.
func TestHandleRequest_UnknownNotification(t *testing.T) {
	t.Parallel()

	srv := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	// Notification (nil ID) for a method the server does not know: no reply.
	got := srv.handleRequest(ctx, &Request{
		JSONRPC: "2.0",
		ID:      nil,
		Method:  "notifications/roots/list_changed",
	})
	assert.Nil(t, got, "unknown notification must not produce a response")

	// Same situation but with an ID: the caller must get an error back.
	got = srv.handleRequest(ctx, &Request{
		JSONRPC: "2.0",
		ID:      99,
		Method:  "some/unknown/method",
	})
	require.NotNil(t, got, "unknown method with ID must produce an error response")
	require.NotNil(t, got.Error)
	assert.Equal(t, -32601, got.Error.Code)
	assert.Equal(t, "Method not found", got.Error.Message)
	assert.Equal(t, 99, got.ID)
}
// TestHandleToolsCall_ErrorUsesIsError (Fix #5 regression) checks how a failed
// tool call is reported. Here callTool fails because the tool name is unknown.
// handleToolsCall must report this through Result with isError:true and a
// single text content item, not through the top-level JSON-RPC Error field,
// per the MCP spec.
func TestHandleToolsCall_ErrorUsesIsError(t *testing.T) {
	t.Parallel()
	server := NewServer(nil, "", "", "1.0.0")
	ctx := context.Background()

	req := &Request{
		JSONRPC: "2.0",
		ID:      7,
		Method:  "tools/call",
		Params:  json.RawMessage(`{"name":"nonexistent_tool","arguments":{}}`),
	}
	resp := server.handleToolsCall(ctx, req)
	// Guard first: if a regression makes this nil, fail cleanly instead of
	// panicking on resp.Error below.
	require.NotNil(t, resp, "tools/call must always produce a response")

	// (a) Response must NOT have top-level Error.
	assert.Nil(t, resp.Error, "tool errors must not use top-level JSON-RPC Error")

	// (b) Response must have Result with isError: true.
	require.NotNil(t, resp.Result, "tool error response must have Result")
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "Result must be a map")
	assert.Equal(t, true, result["isError"], "Result must contain isError: true")

	// (c) Result.content[0].text must contain the error message.
	content, ok := result["content"].([]map[string]any)
	require.True(t, ok, "Result.content must be []map[string]any")
	require.Len(t, content, 1)
	assert.Equal(t, "text", content[0]["type"])
	errText, ok := content[0]["text"].(string)
	require.True(t, ok)
	assert.Contains(t, errText, "unknown tool: nonexistent_tool")
}