mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-06 23:13:50 +00:00
a81482d06a
MCP server (5 fixes):
- Move semaphore acquisition inside goroutine so main loop stays
responsive when all slots are taken
- Add 10s write timeout to sendResponse to prevent pipe deadlock
when Claude Code pauses reading stdout
- Send fallback JSON-RPC error when json.Marshal fails instead of
silently swallowing the error and leaving caller waiting forever
- Silence unknown notification methods (req.ID == nil) instead of
sending unsolicited error responses that may desync the host
- Return MCP isError content for tool failures instead of top-level
JSON-RPC error, matching the MCP specification
Vector/embedding (3 fixes):
- Move EmbedBatchWithContext call before writeMu.Lock in AddDocuments
so ONNX inference runs outside the write lock
- Replace singleflight.Do with DoChan + ctx select in both
getOrComputeEmbedding and UnifiedSearch so callers can bail out
independently when their context expires
- Add activeQueries atomic counter; skip cache warming when user
queries are in-flight; reduce warming timeout from 5s to 2s
Hooks (4 fixes):
- Cap EnsureWorkerRunning to 15s hard deadline with context; reduce
StartupTimeout from 30s to 10s; reduce port-in-use retries
- Fix nil dereference panic in user-prompt hook when initResult is
nil (non-JSON worker response); use comma-ok assertions
- Use package-level hookClient/healthClient with DisableKeepAlives
to prevent FD leaks in short-lived hook processes
- Set SysProcAttr{Setpgid: true} to detach worker from hook process
group, preventing kill-cascade from Claude Code
Worker/DB (3 fixes):
- Replace os.Exit(0) in MCP config watcher with context cancellation
for clean protocol shutdown
- Add 60s context.WithTimeout around ProcessObservation calls in
processAllSessions to prevent hung CLI subprocesses from blocking
the queue processor forever
- Set explicit PRAGMA wal_autocheckpoint=1000 and add PASSIVE WAL
checkpoint to Optimize() to prevent checkpoint stalls
Adds 20+ regression tests across all fix areas.
3421 lines
86 KiB
Go
3421 lines
86 KiB
Go
// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic.
|
||
package mcp
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/stretchr/testify/suite"
|
||
)
|
||
|
||
// timelineParams is a test-only decoding target for the JSON argument object
// of a timeline request. Each field maps to the JSON key named in its tag.
type timelineParams struct {
	// String-valued arguments.
	Query    string `json:"query"`
	Project  string `json:"project"`
	ObsType  string `json:"obs_type"`
	Concepts string `json:"concepts"`
	Files    string `json:"files"`
	Format   string `json:"format"`

	// Numeric arguments keyed in snake_case / lower case.
	AnchorID int64 `json:"anchor_id"`
	Before   int   `json:"before"`
	After    int   `json:"after"`

	// Numeric arguments keyed in camelCase, unlike the keys above.
	DateStart int64 `json:"dateStart"`
	DateEnd   int64 `json:"dateEnd"`
}
|
||
|
||
// =============================================================================
|
||
// TEST SUITE
|
||
// =============================================================================
|
||
|
||
// ServerSuite is a test suite for MCP Server operations.
|
||
type ServerSuite struct {
|
||
suite.Suite
|
||
}
|
||
|
||
func TestServerSuite(t *testing.T) {
|
||
suite.Run(t, new(ServerSuite))
|
||
}
|
||
|
||
// TestNewServer tests server creation.
|
||
func (s *ServerSuite) TestNewServer() {
|
||
server := NewServer(nil, "http://localhost:37777", "test-project", "1.0.0")
|
||
s.NotNil(server)
|
||
s.Nil(server.client)
|
||
s.Equal("1.0.0", server.version)
|
||
s.Equal("http://localhost:37777", server.workerURL)
|
||
s.Equal("test-project", server.project)
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Request/Response Structs
|
||
// =============================================================================
|
||
|
||
// TestRequest tests Request struct JSON marshaling.
|
||
func TestRequest(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
expected string
|
||
req Request
|
||
}{
|
||
// ===== GOOD CASES =====
|
||
{
|
||
name: "initialize request",
|
||
req: Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "initialize",
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":1,"method":"initialize"}`,
|
||
},
|
||
{
|
||
name: "tools/list request",
|
||
req: Request{
|
||
JSONRPC: "2.0",
|
||
ID: "abc",
|
||
Method: "tools/list",
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":"abc","method":"tools/list"}`,
|
||
},
|
||
{
|
||
name: "tools/call with params",
|
||
req: Request{
|
||
JSONRPC: "2.0",
|
||
ID: 2,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"search","arguments":{}}`),
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`,
|
||
},
|
||
// ===== EDGE CASES =====
|
||
{
|
||
name: "request with nil ID",
|
||
req: Request{
|
||
JSONRPC: "2.0",
|
||
ID: nil,
|
||
Method: "initialize",
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":null,"method":"initialize"}`,
|
||
},
|
||
{
|
||
name: "request with float ID",
|
||
req: Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1.5,
|
||
Method: "test",
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":1.5,"method":"test"}`,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
data, err := json.Marshal(tt.req)
|
||
require.NoError(t, err)
|
||
assert.JSONEq(t, tt.expected, string(data))
|
||
|
||
// Test unmarshaling
|
||
var parsed Request
|
||
err = json.Unmarshal(data, &parsed)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, tt.req.JSONRPC, parsed.JSONRPC)
|
||
assert.Equal(t, tt.req.Method, parsed.Method)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestResponse tests Response struct JSON marshaling.
|
||
func TestResponse(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
resp Response
|
||
expected string
|
||
}{
|
||
// ===== GOOD CASES =====
|
||
{
|
||
name: "success response",
|
||
resp: Response{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Result: map[string]string{"status": "ok"},
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`,
|
||
},
|
||
{
|
||
name: "error response",
|
||
resp: Response{
|
||
JSONRPC: "2.0",
|
||
ID: 2,
|
||
Error: &Error{
|
||
Code: -32600,
|
||
Message: "Invalid Request",
|
||
},
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}`,
|
||
},
|
||
{
|
||
name: "error with data",
|
||
resp: Response{
|
||
JSONRPC: "2.0",
|
||
ID: 3,
|
||
Error: &Error{
|
||
Code: -32602,
|
||
Message: "Invalid params",
|
||
Data: "missing field",
|
||
},
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`,
|
||
},
|
||
// ===== EDGE CASES =====
|
||
{
|
||
name: "response with nil ID",
|
||
resp: Response{
|
||
JSONRPC: "2.0",
|
||
ID: nil,
|
||
Result: "ok",
|
||
},
|
||
expected: `{"jsonrpc":"2.0","id":null,"result":"ok"}`,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
data, err := json.Marshal(tt.resp)
|
||
require.NoError(t, err)
|
||
assert.JSONEq(t, tt.expected, string(data))
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestError tests Error struct.
|
||
func TestError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
expected string
|
||
name string
|
||
err Error
|
||
}{
|
||
{
|
||
name: "parse error",
|
||
err: Error{
|
||
Code: -32700,
|
||
Message: "Parse error",
|
||
},
|
||
expected: `{"code":-32700,"message":"Parse error"}`,
|
||
},
|
||
{
|
||
name: "method not found",
|
||
err: Error{
|
||
Code: -32601,
|
||
Message: "Method not found",
|
||
},
|
||
expected: `{"code":-32601,"message":"Method not found"}`,
|
||
},
|
||
{
|
||
name: "invalid params",
|
||
err: Error{
|
||
Code: -32602,
|
||
Message: "Invalid params",
|
||
Data: "details here",
|
||
},
|
||
expected: `{"code":-32602,"message":"Invalid params","data":"details here"}`,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
data, err := json.Marshal(tt.err)
|
||
require.NoError(t, err)
|
||
assert.JSONEq(t, tt.expected, string(data))
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestToolCallParams tests ToolCallParams struct.
|
||
func TestToolCallParams(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
expected ToolCallParams
|
||
}{
|
||
{
|
||
name: "search tool call",
|
||
input: `{"name":"search","arguments":{"query":"test"}}`,
|
||
expected: ToolCallParams{
|
||
Name: "search",
|
||
Arguments: json.RawMessage(`{"query":"test"}`),
|
||
},
|
||
},
|
||
{
|
||
name: "decisions tool call",
|
||
input: `{"name":"decisions","arguments":{"query":"auth"}}`,
|
||
expected: ToolCallParams{
|
||
Name: "decisions",
|
||
Arguments: json.RawMessage(`{"query":"auth"}`),
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
var params ToolCallParams
|
||
err := json.Unmarshal([]byte(tt.input), ¶ms)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, tt.expected.Name, params.Name)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestTool tests Tool struct.
|
||
func TestTool(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tool := Tool{
|
||
Name: "search",
|
||
Description: "Search observations",
|
||
InputSchema: map[string]any{
|
||
"type": "object",
|
||
"properties": map[string]any{
|
||
"query": map[string]any{"type": "string"},
|
||
},
|
||
},
|
||
}
|
||
|
||
data, err := json.Marshal(tool)
|
||
require.NoError(t, err)
|
||
|
||
var parsed Tool
|
||
err = json.Unmarshal(data, &parsed)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "search", parsed.Name)
|
||
assert.Equal(t, "Search observations", parsed.Description)
|
||
}
|
||
|
||
// TestTimelineParamsStruct tests timelineParams struct.
|
||
func TestTimelineParamsStruct(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
expected timelineParams
|
||
}{
|
||
{
|
||
name: "with anchor_id",
|
||
input: `{"anchor_id":123,"before":5,"after":5}`,
|
||
expected: timelineParams{
|
||
AnchorID: 123,
|
||
Before: 5,
|
||
After: 5,
|
||
},
|
||
},
|
||
{
|
||
name: "with query",
|
||
input: `{"query":"test query","project":"my-project"}`,
|
||
expected: timelineParams{
|
||
Query: "test query",
|
||
Project: "my-project",
|
||
},
|
||
},
|
||
{
|
||
name: "full params",
|
||
input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`,
|
||
expected: timelineParams{
|
||
AnchorID: 100,
|
||
Query: "search",
|
||
Before: 10,
|
||
After: 20,
|
||
Project: "proj",
|
||
ObsType: "bugfix",
|
||
Concepts: "security",
|
||
Files: "main.go",
|
||
DateStart: 1234567890,
|
||
DateEnd: 9876543210,
|
||
Format: "full",
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
var params timelineParams
|
||
err := json.Unmarshal([]byte(tt.input), ¶ms)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, tt.expected.AnchorID, params.AnchorID)
|
||
assert.Equal(t, tt.expected.Query, params.Query)
|
||
assert.Equal(t, tt.expected.Project, params.Project)
|
||
})
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Server Handlers
|
||
// =============================================================================
|
||
|
||
// TestHandleInitialize tests the initialize handler.
|
||
func TestHandleInitialize(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.2.3")
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "initialize",
|
||
}
|
||
|
||
resp := server.handleInitialize(req)
|
||
|
||
assert.Equal(t, "2.0", resp.JSONRPC)
|
||
assert.Equal(t, 1, resp.ID)
|
||
assert.Nil(t, resp.Error)
|
||
assert.NotNil(t, resp.Result)
|
||
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok)
|
||
assert.Equal(t, "2024-11-05", result["protocolVersion"])
|
||
|
||
serverInfo, ok := result["serverInfo"].(map[string]any)
|
||
require.True(t, ok)
|
||
assert.Equal(t, "claude-mnemonic", serverInfo["name"])
|
||
assert.Equal(t, "1.2.3", serverInfo["version"])
|
||
}
|
||
|
||
// TestHandleToolsList tests the tools/list handler.
|
||
func TestHandleToolsList(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/list",
|
||
}
|
||
|
||
resp := server.handleToolsList(req)
|
||
|
||
assert.Equal(t, "2.0", resp.JSONRPC)
|
||
assert.Equal(t, 1, resp.ID)
|
||
assert.Nil(t, resp.Error)
|
||
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok)
|
||
|
||
tools, ok := result["tools"].([]Tool)
|
||
require.True(t, ok)
|
||
assert.NotEmpty(t, tools)
|
||
|
||
// Verify expected tools are present
|
||
toolNames := make(map[string]bool)
|
||
for _, tool := range tools {
|
||
toolNames[tool.Name] = true
|
||
}
|
||
|
||
expectedTools := []string{
|
||
"search", "timeline", "decisions", "changes",
|
||
"how_it_works", "find_by_concept", "find_by_file",
|
||
"find_by_type", "get_recent_context", "get_context_timeline",
|
||
"get_timeline_by_query",
|
||
}
|
||
|
||
for _, name := range expectedTools {
|
||
assert.True(t, toolNames[name], "expected tool %s to be present", name)
|
||
}
|
||
}
|
||
|
||
// TestHandleRequest tests request routing.
|
||
func TestHandleRequest(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
req *Request
|
||
name string
|
||
errorMessage string
|
||
errorCode int
|
||
expectError bool
|
||
}{
|
||
{
|
||
name: "initialize method",
|
||
req: &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "initialize",
|
||
},
|
||
expectError: false,
|
||
},
|
||
{
|
||
name: "tools/list method",
|
||
req: &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 2,
|
||
Method: "tools/list",
|
||
},
|
||
expectError: false,
|
||
},
|
||
{
|
||
name: "unknown method",
|
||
req: &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 3,
|
||
Method: "unknown_method",
|
||
},
|
||
expectError: true,
|
||
errorCode: -32601,
|
||
errorMessage: "Method not found",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
resp := server.handleRequest(ctx, tt.req)
|
||
|
||
assert.Equal(t, "2.0", resp.JSONRPC)
|
||
assert.Equal(t, tt.req.ID, resp.ID)
|
||
|
||
if tt.expectError {
|
||
require.NotNil(t, resp.Error)
|
||
assert.Equal(t, tt.errorCode, resp.Error.Code)
|
||
assert.Equal(t, tt.errorMessage, resp.Error.Message)
|
||
} else {
|
||
assert.Nil(t, resp.Error)
|
||
assert.NotNil(t, resp.Result)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleToolsCall_InvalidParams tests tools/call with invalid params.
|
||
func TestHandleToolsCall_InvalidParams(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`invalid json`),
|
||
}
|
||
|
||
resp := server.handleToolsCall(ctx, req)
|
||
|
||
require.NotNil(t, resp.Error)
|
||
assert.Equal(t, -32602, resp.Error.Code)
|
||
assert.Equal(t, "Invalid params", resp.Error.Message)
|
||
}
|
||
|
||
// TestCallTool_UnknownTool tests callTool with unknown tool name.
|
||
func TestCallTool_UnknownTool(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
_, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "unknown tool")
|
||
}
|
||
|
||
// TestCallTool_InvalidArgs tests callTool with invalid arguments.
|
||
func TestCallTool_InvalidArgs(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Invalid JSON is best-effort parsed; the call fails because worker is unavailable
|
||
_, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Server I/O
|
||
// =============================================================================
|
||
|
||
// TestSendResponse tests response sending.
|
||
func TestSendResponse(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var buf bytes.Buffer
|
||
server := &Server{
|
||
stdout: &buf,
|
||
}
|
||
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Result: map[string]string{"status": "ok"},
|
||
}
|
||
|
||
server.sendResponse(resp)
|
||
|
||
output := buf.String()
|
||
assert.Contains(t, output, `"jsonrpc":"2.0"`)
|
||
assert.Contains(t, output, `"id":1`)
|
||
assert.Contains(t, output, `"result"`)
|
||
}
|
||
|
||
// TestSendError tests error response sending.
|
||
func TestSendError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var buf bytes.Buffer
|
||
server := &Server{
|
||
stdout: &buf,
|
||
}
|
||
|
||
server.sendError(1, -32700, "Parse error", "details")
|
||
|
||
output := buf.String()
|
||
assert.Contains(t, output, `"error"`)
|
||
assert.Contains(t, output, `-32700`)
|
||
assert.Contains(t, output, `"Parse error"`)
|
||
}
|
||
|
||
// TestRun_ParseError tests Run with invalid JSON input.
|
||
func TestRun_ParseError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
stdin := strings.NewReader("invalid json\n")
|
||
|
||
server := &Server{
|
||
stdin: stdin,
|
||
stdout: &stdout,
|
||
}
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
output := stdout.String()
|
||
assert.Contains(t, output, `"error"`)
|
||
assert.Contains(t, output, `-32700`)
|
||
assert.Contains(t, output, `"Parse error"`)
|
||
}
|
||
|
||
// TestRun_EmptyLine tests Run skips empty lines.
|
||
func TestRun_EmptyLine(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
stdin := strings.NewReader("\n\n")
|
||
|
||
server := &Server{
|
||
stdin: stdin,
|
||
stdout: &stdout,
|
||
}
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
// Should be empty - no responses for empty lines
|
||
assert.Empty(t, stdout.String())
|
||
}
|
||
|
||
// TestRun_ValidRequest tests Run with a valid request.
|
||
func TestRun_ValidRequest(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
req := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||
stdin := strings.NewReader(req + "\n")
|
||
|
||
server := &Server{
|
||
stdin: stdin,
|
||
stdout: &stdout,
|
||
version: "1.0.0",
|
||
}
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
output := stdout.String()
|
||
assert.Contains(t, output, `"jsonrpc":"2.0"`)
|
||
assert.Contains(t, output, `"result"`)
|
||
assert.Contains(t, output, `"protocolVersion"`)
|
||
}
|
||
|
||
// TestRun_MultipleRequests tests Run with multiple sequential requests.
|
||
func TestRun_MultipleRequests(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||
req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
|
||
stdin := strings.NewReader(req1 + "\n" + req2 + "\n")
|
||
|
||
server := &Server{
|
||
stdin: stdin,
|
||
stdout: &stdout,
|
||
version: "1.0.0",
|
||
}
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
output := stdout.String()
|
||
// Should contain responses for both requests
|
||
assert.Contains(t, output, `"id":1`)
|
||
assert.Contains(t, output, `"id":2`)
|
||
}
|
||
|
||
// TestRunMixedRequests tests Run with mixed valid and invalid requests.
|
||
func TestRunMixedRequests(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||
req2 := `invalid json`
|
||
req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`
|
||
stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n")
|
||
|
||
server := &Server{
|
||
stdin: stdin,
|
||
stdout: &stdout,
|
||
version: "1.0.0",
|
||
}
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
output := stdout.String()
|
||
// Should have responses for all three requests
|
||
assert.Contains(t, output, `"id":1`)
|
||
assert.Contains(t, output, `"error"`) // Parse error for invalid json
|
||
assert.Contains(t, output, `"id":3`)
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Handler Parameter Validation
|
||
// =============================================================================
|
||
|
||
// TestHandleFindRelatedObservations_Validation tests parameter validation.
|
||
func TestHandleFindRelatedObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "worker unavailable",
|
||
args: `{"id": 1}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "missing id",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleFindRelatedProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleFindSimilarObservations_Validation tests parameter validation.
|
||
func TestHandleFindSimilarObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing query",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "query is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
{
|
||
name: "nil vector client",
|
||
args: `{"query": "test"}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleFindSimilarProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetPatterns_Validation tests parameter validation.
|
||
func TestHandleGetPatterns_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetPatternsProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleBulkDeleteObservations_Validation tests parameter validation.
|
||
func TestHandleBulkDeleteObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing ids",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "ids is required",
|
||
},
|
||
{
|
||
name: "empty ids array",
|
||
args: `{"ids": []}`,
|
||
wantErr: true,
|
||
errContains: "ids is required",
|
||
},
|
||
{
|
||
name: "too many ids",
|
||
args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "delete")
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleBulkMarkSuperseded_Validation tests parameter validation.
|
||
func TestHandleBulkMarkSuperseded_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing ids",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "ids is required",
|
||
},
|
||
{
|
||
name: "empty ids array",
|
||
args: `{"ids": []}`,
|
||
wantErr: true,
|
||
errContains: "ids is required",
|
||
},
|
||
{
|
||
name: "too many ids",
|
||
args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "supersede")
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleBulkBoostObservations_Validation tests parameter validation.
|
||
func TestHandleBulkBoostObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing ids",
|
||
args: `{"boost": 0.1}`,
|
||
wantErr: true,
|
||
errContains: "ids is required",
|
||
},
|
||
{
|
||
name: "boost out of range low",
|
||
args: `{"ids": [1], "boost": -1.5}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "boost out of range high",
|
||
args: `{"ids": [1], "boost": 1.5}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "too many ids",
|
||
args: `{"ids": [` + strings.Repeat("1,", 1001) + `1], "boost": 0.1}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "boost")
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleTriggerMaintenance_Validation tests that nil service returns error.
|
||
func TestHandleTriggerMaintenance_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
_, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil)
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleGetMaintenanceStats_Validation tests that nil service returns error.
|
||
func TestHandleGetMaintenanceStats_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
_, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""})
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleMergeObservations_Validation tests parameter validation.
|
||
func TestHandleMergeObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing source_id",
|
||
args: `{"target_id": 2}`,
|
||
wantErr: true,
|
||
errContains: "source_id and target_id are required",
|
||
},
|
||
{
|
||
name: "missing target_id",
|
||
args: `{"source_id": 1}`,
|
||
wantErr: true,
|
||
errContains: "source_id and target_id are required",
|
||
},
|
||
{
|
||
name: "same source and target",
|
||
args: `{"source_id": 1, "target_id": 1}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "worker unavailable with boost",
|
||
args: `{"source_id": 1, "target_id": 2, "boost": 0.6}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleMergeProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetObservation_Validation tests parameter validation.
|
||
func TestHandleGetObservation_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetObservationProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleEditObservation_Validation tests parameter validation.
|
||
func TestHandleEditObservation_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{"title": "new title"}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "invalid scope",
|
||
args: `{"id": 1, "scope": "invalid"}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleEditObservationProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetObservationQuality_Validation tests parameter validation.
|
||
func TestHandleGetObservationQuality_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetObservationQualityProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleSuggestConsolidations_Validation tests parameter validation.
|
||
func TestHandleSuggestConsolidations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "min_similarity too low",
|
||
args: `{"min_similarity": 0.3}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "min_similarity too high",
|
||
args: `{"min_similarity": 1.5}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleTagObservation_Validation tests parameter validation.
|
||
func TestHandleTagObservation_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{"tags": ["tag1"]}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "missing tags",
|
||
args: `{"id": 1}`,
|
||
wantErr: true,
|
||
errContains: "tags is required",
|
||
},
|
||
{
|
||
name: "invalid mode",
|
||
args: `{"id": 1, "tags": ["tag1"], "mode": "invalid"}`,
|
||
wantErr: true,
|
||
errContains: "mode must be 'add', 'remove', or 'set'",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleTagObservationProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetObservationsByTag_Validation tests parameter validation.
|
||
func TestHandleGetObservationsByTag_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing tag",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "tag is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetObservationsByTagProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleBatchTagByPattern_Validation tests parameter validation.
|
||
func TestHandleBatchTagByPattern_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing pattern",
|
||
args: `{"tags": ["tag1"]}`,
|
||
wantErr: true,
|
||
errContains: "pattern is required",
|
||
},
|
||
{
|
||
name: "missing tags",
|
||
args: `{"pattern": "test"}`,
|
||
wantErr: true,
|
||
errContains: "tags is required",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleBatchTagProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleExplainSearchRanking_Validation tests parameter validation.
|
||
func TestHandleExplainSearchRanking_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing query",
|
||
args: `{"top_n": 5}`,
|
||
wantErr: true,
|
||
errContains: "query is required",
|
||
},
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleExplainSearchProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetObservationRelationships_Validation tests parameter validation.
|
||
func TestHandleGetObservationRelationships_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "negative id",
|
||
args: `{"id": -1}`,
|
||
wantErr: true,
|
||
errContains: "id is required and must be positive",
|
||
},
|
||
{
|
||
name: "nil relation store",
|
||
args: `{"id": 1}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetObservationScoringBreakdown_Validation tests parameter validation.
|
||
func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "missing id",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "id is required",
|
||
},
|
||
{
|
||
name: "negative id",
|
||
args: `{"id": -1}`,
|
||
wantErr: true,
|
||
errContains: "id is required and must be positive",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetScoringBreakdownProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
} else {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
|
||
func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
_, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "invalid timeline params")
|
||
}
|
||
|
||
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
|
||
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Empty query should return empty results (no anchor found)
|
||
result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`))
|
||
require.NoError(t, err)
|
||
assert.Contains(t, result, `"observations":[]`)
|
||
}
|
||
|
||
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
|
||
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
_, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "invalid timeline params")
|
||
}
|
||
|
||
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
|
||
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// No anchor_id and no query should return empty result JSON
|
||
result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`))
|
||
require.NoError(t, err)
|
||
assert.NotEmpty(t, result)
|
||
assert.Contains(t, result, `"observations":[]`)
|
||
}
|
||
|
||
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
|
||
func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// With anchor_id = 0, should return empty result JSON
|
||
result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{"anchor_id": 0}`))
|
||
require.NoError(t, err)
|
||
assert.NotEmpty(t, result)
|
||
assert.Contains(t, result, `"observations":[]`)
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Additional Tool Operations
|
||
// =============================================================================
|
||
|
||
// TestJSONRPCErrorCodes tests standard JSON-RPC error codes.
|
||
func TestJSONRPCErrorCodes(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
errorCodes := map[string]int{
|
||
"Parse error": -32700,
|
||
"Invalid Request": -32600,
|
||
"Method not found": -32601,
|
||
"Invalid params": -32602,
|
||
"Internal error": -32603,
|
||
}
|
||
|
||
for msg, code := range errorCodes {
|
||
t.Run(msg, func(t *testing.T) {
|
||
t.Parallel()
|
||
err := Error{Code: code, Message: msg}
|
||
assert.Equal(t, code, err.Code)
|
||
assert.Equal(t, msg, err.Message)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestToolListContainsExpectedSchemas tests that tool schemas are valid.
|
||
func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/list",
|
||
}
|
||
|
||
resp := server.handleToolsList(req)
|
||
result := resp.Result.(map[string]any)
|
||
tools := result["tools"].([]Tool)
|
||
|
||
for _, tool := range tools {
|
||
assert.NotEmpty(t, tool.Name)
|
||
assert.NotEmpty(t, tool.Description)
|
||
assert.NotNil(t, tool.InputSchema)
|
||
|
||
// Check schema has type
|
||
schema := tool.InputSchema
|
||
_, hasType := schema["type"]
|
||
assert.True(t, hasType, "tool %s schema should have type", tool.Name)
|
||
}
|
||
}
|
||
|
||
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
|
||
// After Fix #5: tool errors use Result with isError, not top-level Error.
|
||
func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`),
|
||
}
|
||
|
||
resp := server.handleToolsCall(ctx, req)
|
||
assert.Nil(t, resp.Error, "tool errors must not use top-level Error (MCP spec)")
|
||
require.NotNil(t, resp.Result)
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok)
|
||
assert.Equal(t, true, result["isError"])
|
||
content := result["content"].([]map[string]any)
|
||
assert.Contains(t, content[0]["text"], "unknown tool")
|
||
}
|
||
|
||
// TestCallTool_ToolNameRecognition tests that valid tool names are recognized.
|
||
func TestCallTool_ToolNameRecognition(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/list",
|
||
}
|
||
|
||
resp := server.handleToolsList(req)
|
||
result := resp.Result.(map[string]any)
|
||
tools := result["tools"].([]Tool)
|
||
|
||
// Verify all expected tools are registered
|
||
expectedTools := map[string]bool{
|
||
"search": true,
|
||
"timeline": true,
|
||
"decisions": true,
|
||
"changes": true,
|
||
"how_it_works": true,
|
||
"find_by_concept": true,
|
||
"find_by_file": true,
|
||
"find_by_type": true,
|
||
"get_recent_context": true,
|
||
"get_context_timeline": true,
|
||
"get_timeline_by_query": true,
|
||
"find_related_observations": true,
|
||
"find_similar_observations": true,
|
||
"get_patterns": true,
|
||
"get_memory_stats": true,
|
||
"bulk_delete_observations": true,
|
||
"bulk_mark_superseded": true,
|
||
"bulk_boost_observations": true,
|
||
"trigger_maintenance": true,
|
||
"get_maintenance_stats": true,
|
||
"merge_observations": true,
|
||
"get_observation": true,
|
||
"edit_observation": true,
|
||
"get_observation_quality": true,
|
||
"suggest_consolidations": true,
|
||
"tag_observation": true,
|
||
"get_observations_by_tag": true,
|
||
"get_temporal_trends": true,
|
||
"get_data_quality_report": true,
|
||
"batch_tag_by_pattern": true,
|
||
"explain_search_ranking": true,
|
||
"export_observations": true,
|
||
"check_system_health": true,
|
||
"analyze_search_patterns": true,
|
||
"get_observation_relationships": true,
|
||
"get_observation_scoring_breakdown": true,
|
||
"analyze_observation_importance": true,
|
||
}
|
||
|
||
foundTools := make(map[string]bool)
|
||
for _, tool := range tools {
|
||
foundTools[tool.Name] = true
|
||
}
|
||
|
||
for name := range expectedTools {
|
||
assert.True(t, foundTools[name], "tool %s should be registered", name)
|
||
}
|
||
}
|
||
|
||
// TestTimelineParamsStruct_Complete tests complete timelineParams parsing.
|
||
func TestTimelineParamsStruct_Complete(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
input := `{
|
||
"anchor_id": 100,
|
||
"query": "test query",
|
||
"before": 5,
|
||
"after": 15,
|
||
"project": "my-project",
|
||
"obs_type": "bugfix",
|
||
"concepts": "security,auth",
|
||
"files": "main.go,handler.go",
|
||
"dateStart": 1700000000000,
|
||
"dateEnd": 1700100000000,
|
||
"format": "full"
|
||
}`
|
||
|
||
var params timelineParams
|
||
err := json.Unmarshal([]byte(input), ¶ms)
|
||
require.NoError(t, err)
|
||
|
||
assert.Equal(t, int64(100), params.AnchorID)
|
||
assert.Equal(t, "test query", params.Query)
|
||
assert.Equal(t, 5, params.Before)
|
||
assert.Equal(t, 15, params.After)
|
||
assert.Equal(t, "my-project", params.Project)
|
||
assert.Equal(t, "bugfix", params.ObsType)
|
||
assert.Equal(t, "security,auth", params.Concepts)
|
||
assert.Equal(t, "main.go,handler.go", params.Files)
|
||
assert.Equal(t, int64(1700000000000), params.DateStart)
|
||
assert.Equal(t, int64(1700100000000), params.DateEnd)
|
||
assert.Equal(t, "full", params.Format)
|
||
}
|
||
|
||
// TestServerStdinStdoutConfig tests that server stdin/stdout can be configured.
|
||
func TestServerStdinStdoutConfig(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var stdout bytes.Buffer
|
||
var stdin bytes.Buffer
|
||
|
||
server := &Server{
|
||
stdin: &stdin,
|
||
stdout: &stdout,
|
||
version: "test-version",
|
||
}
|
||
|
||
assert.Equal(t, &stdin, server.stdin)
|
||
assert.Equal(t, &stdout, server.stdout)
|
||
assert.Equal(t, "test-version", server.version)
|
||
}
|
||
|
||
// TestResponseIDTypes tests that response IDs can be various types.
|
||
func TestResponseIDTypes(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
id any
|
||
name string
|
||
}{
|
||
{name: "integer id", id: 1},
|
||
{name: "string id", id: "abc-123"},
|
||
{name: "float id", id: 1.5},
|
||
{name: "null id", id: nil},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
var buf bytes.Buffer
|
||
server := &Server{stdout: &buf}
|
||
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: tt.id,
|
||
Result: "ok",
|
||
}
|
||
|
||
server.sendResponse(resp)
|
||
output := buf.String()
|
||
assert.Contains(t, output, `"jsonrpc":"2.0"`)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestServerFields tests Server struct fields.
|
||
func TestServerFields(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "http://localhost:37777", "test", "2.0.0")
|
||
|
||
assert.Equal(t, "2.0.0", server.version)
|
||
assert.Nil(t, server.client)
|
||
assert.Equal(t, "http://localhost:37777", server.workerURL)
|
||
assert.NotNil(t, server.stdin)
|
||
assert.NotNil(t, server.stdout)
|
||
}
|
||
|
||
// TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID.
|
||
func TestRequestUnmarshalWithNullID(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}`
|
||
|
||
var req Request
|
||
err := json.Unmarshal([]byte(input), &req)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "2.0", req.JSONRPC)
|
||
assert.Nil(t, req.ID)
|
||
assert.Equal(t, "initialize", req.Method)
|
||
}
|
||
|
||
// TestResponseWithNullError tests Response without error.
|
||
func TestResponseWithNullError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
resp := Response{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Result: "success",
|
||
Error: nil,
|
||
}
|
||
|
||
data, err := json.Marshal(resp)
|
||
require.NoError(t, err)
|
||
assert.Contains(t, string(data), `"result":"success"`)
|
||
assert.NotContains(t, string(data), `"error"`)
|
||
}
|
||
|
||
// TestErrorWithNilData tests Error without data.
|
||
func TestErrorWithNilData(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
err := Error{
|
||
Code: -32600,
|
||
Message: "Invalid Request",
|
||
Data: nil,
|
||
}
|
||
|
||
data, errMarshal := json.Marshal(err)
|
||
require.NoError(t, errMarshal)
|
||
assert.Contains(t, string(data), `"code":-32600`)
|
||
assert.Contains(t, string(data), `"message":"Invalid Request"`)
|
||
assert.NotContains(t, string(data), `"data"`)
|
||
}
|
||
|
||
// TestToolInputSchema tests that tool input schemas have required fields.
|
||
func TestToolInputSchema(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/list",
|
||
}
|
||
|
||
resp := server.handleToolsList(req)
|
||
result := resp.Result.(map[string]any)
|
||
tools := result["tools"].([]Tool)
|
||
|
||
for _, tool := range tools {
|
||
schema := tool.InputSchema
|
||
schemaType, ok := schema["type"]
|
||
assert.True(t, ok, "tool %s schema should have type", tool.Name)
|
||
assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name)
|
||
|
||
// All tools should have properties
|
||
_, hasProperties := schema["properties"]
|
||
assert.True(t, hasProperties, "tool %s should have properties", tool.Name)
|
||
}
|
||
}
|
||
|
||
// TestCallTool_UnknownToolName tests callTool with various unknown tool names.
|
||
func TestCallTool_UnknownToolName(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
unknownTools := []string{
|
||
"invalid_tool",
|
||
"nonexistent",
|
||
"search_v2",
|
||
"timeline_special",
|
||
}
|
||
|
||
for _, name := range unknownTools {
|
||
t.Run(name, func(t *testing.T) {
|
||
t.Parallel()
|
||
result, err := server.callTool(ctx, name, json.RawMessage(`{}`))
|
||
assert.Error(t, err)
|
||
assert.Empty(t, result)
|
||
assert.Contains(t, err.Error(), "unknown tool")
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestTimelineParamsStruct_Validation tests timelineParams struct field validation.
|
||
func TestTimelineParamsStruct_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
json string
|
||
wantOK bool
|
||
}{
|
||
{"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true},
|
||
{"valid with query only", `{"query":"test query"}`, true},
|
||
{"empty params", `{}`, true},
|
||
{"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true},
|
||
{"invalid json", `{invalid`, false},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
var params timelineParams
|
||
err := json.Unmarshal([]byte(tt.json), ¶ms)
|
||
if tt.wantOK {
|
||
assert.NoError(t, err)
|
||
} else {
|
||
assert.Error(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleToolsCall_EmptyParams tests tools/call with empty params.
|
||
// After Fix #5: tool errors use Result with isError, not top-level Error.
|
||
func TestHandleToolsCall_EmptyParams(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`{}`),
|
||
}
|
||
|
||
resp := server.handleToolsCall(ctx, req)
|
||
|
||
// Empty name goes through callTool default branch -> isError
|
||
assert.Nil(t, resp.Error, "tool errors must use isError in Result")
|
||
require.NotNil(t, resp.Result)
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok)
|
||
assert.Equal(t, true, result["isError"])
|
||
}
|
||
|
||
// TestSendResponse_WithError tests sendResponse with an error response.
|
||
func TestSendResponse_WithError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var buf bytes.Buffer
|
||
server := &Server{stdout: &buf}
|
||
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Error: &Error{Code: -32600, Message: "Invalid Request"},
|
||
}
|
||
|
||
server.sendResponse(resp)
|
||
|
||
output := buf.String()
|
||
assert.Contains(t, output, `"error"`)
|
||
assert.Contains(t, output, `-32600`)
|
||
}
|
||
|
||
// TestSendResponse_NilID tests sendResponse with nil ID.
|
||
func TestSendResponse_NilID(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var buf bytes.Buffer
|
||
server := &Server{stdout: &buf}
|
||
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: nil,
|
||
Result: "notification response",
|
||
}
|
||
|
||
server.sendResponse(resp)
|
||
|
||
output := buf.String()
|
||
assert.Contains(t, output, `"id":null`)
|
||
}
|
||
|
||
// TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments.
|
||
func TestToolCallParamsWithComplexArgs(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
input := `{
|
||
"name": "search",
|
||
"arguments": {
|
||
"query": "authentication bug",
|
||
"project": "my-project",
|
||
"limit": 10,
|
||
"type": "observations"
|
||
}
|
||
}`
|
||
|
||
var params ToolCallParams
|
||
err := json.Unmarshal([]byte(input), ¶ms)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "search", params.Name)
|
||
assert.NotEmpty(t, params.Arguments)
|
||
}
|
||
|
||
// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error.
|
||
// After Fix #5: tool errors use Result with isError, not top-level Error.
|
||
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"very_unknown_tool_name","arguments":{}}`),
|
||
}
|
||
|
||
resp := server.handleToolsCall(ctx, req)
|
||
|
||
assert.Equal(t, "2.0", resp.JSONRPC)
|
||
assert.Equal(t, 1, resp.ID)
|
||
assert.Nil(t, resp.Error, "tool errors must use isError in Result, not top-level Error")
|
||
require.NotNil(t, resp.Result)
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok)
|
||
assert.Equal(t, true, result["isError"])
|
||
}
|
||
|
||
// =============================================================================
|
||
// TESTS FOR Handler Defaults
|
||
// =============================================================================
|
||
|
||
// TestHandleTimeline_Defaults tests timeline default values.
|
||
func TestHandleTimeline_Defaults(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Test that handleTimeline sets default before/after values
|
||
params := timelineParams{
|
||
AnchorID: 0,
|
||
Query: "",
|
||
Before: 0,
|
||
After: 0,
|
||
}
|
||
|
||
// Simulate the default value assignment from handleTimeline
|
||
if params.Before <= 0 {
|
||
params.Before = 10
|
||
}
|
||
if params.After <= 0 {
|
||
params.After = 10
|
||
}
|
||
|
||
assert.Equal(t, 10, params.Before)
|
||
assert.Equal(t, 10, params.After)
|
||
}
|
||
|
||
// TestHandleGetTemporalTrends_Validation tests parameter validation.
|
||
func TestHandleGetTemporalTrends_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetTemporalTrendsProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetDataQualityReport_Validation tests parameter validation.
|
||
func TestHandleGetDataQualityReport_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleGetDataQualityProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleExportObservations_Validation tests parameter validation.
|
||
func TestHandleExportObservations_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleExportProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleAnalyzeSearchPatterns_Validation tests parameter validation.
|
||
func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "worker unavailable",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
errContains: "worker unavailable",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.proxyGetRaw(ctx, "/api/search/analytics", nil)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleAnalyzeObservationImportance_Validation tests parameter validation.
|
||
func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleAnalyzeImportanceProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
if tt.errContains != "" {
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleGetMemoryStats_NilStores tests GetMemoryStats with nil stores.
|
||
func TestHandleGetMemoryStats_NilStores(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when worker is unavailable (nil client)
|
||
_, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""})
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleCheckSystemHealth_NilStores tests CheckSystemHealth with nil stores.
|
||
func TestHandleCheckSystemHealth_NilStores(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when worker is unavailable (nil client)
|
||
_, err := server.proxyGetRaw(ctx, "/api/selfcheck", nil)
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// =============================================================================
|
||
// COMPREHENSIVE callTool TESTS
|
||
// =============================================================================
|
||
|
||
// TestCallTool_AllSpecialTools tests all special tool cases in callTool switch.
|
||
func TestCallTool_AllSpecialTools(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Tests for tools that can work without stores or have nil guards
|
||
tests := []struct {
|
||
name string
|
||
toolName string
|
||
args string
|
||
wantErr bool
|
||
checkPanic bool
|
||
}{
|
||
// Tools that work with nil stores
|
||
{
|
||
name: "get_memory_stats",
|
||
toolName: "get_memory_stats",
|
||
args: `{}`,
|
||
wantErr: true, // worker unavailable with nil client
|
||
},
|
||
{
|
||
name: "check_system_health",
|
||
toolName: "check_system_health",
|
||
args: `{}`,
|
||
wantErr: true, // worker unavailable with nil client
|
||
},
|
||
// Tools that need stores but have parameter validation first
|
||
{
|
||
name: "find_related_observations - invalid json",
|
||
toolName: "find_related_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "find_related_observations - missing id",
|
||
toolName: "find_related_observations",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "find_similar_observations - invalid json",
|
||
toolName: "find_similar_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "find_similar_observations - missing query",
|
||
toolName: "find_similar_observations",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_patterns - invalid json",
|
||
toolName: "get_patterns",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_delete_observations - invalid json",
|
||
toolName: "bulk_delete_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_delete_observations - missing ids",
|
||
toolName: "bulk_delete_observations",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_mark_superseded - invalid json",
|
||
toolName: "bulk_mark_superseded",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_mark_superseded - missing ids",
|
||
toolName: "bulk_mark_superseded",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_boost_observations - invalid json",
|
||
toolName: "bulk_boost_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "bulk_boost_observations - missing ids",
|
||
toolName: "bulk_boost_observations",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "merge_observations - invalid json",
|
||
toolName: "merge_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "merge_observations - missing source_ids",
|
||
toolName: "merge_observations",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation - invalid json",
|
||
toolName: "get_observation",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation - missing id",
|
||
toolName: "get_observation",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "edit_observation - invalid json",
|
||
toolName: "edit_observation",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "edit_observation - missing id",
|
||
toolName: "edit_observation",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_quality - invalid json",
|
||
toolName: "get_observation_quality",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_quality - missing id",
|
||
toolName: "get_observation_quality",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "suggest_consolidations - invalid json",
|
||
toolName: "suggest_consolidations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "tag_observation - invalid json",
|
||
toolName: "tag_observation",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "tag_observation - missing id",
|
||
toolName: "tag_observation",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observations_by_tag - invalid json",
|
||
toolName: "get_observations_by_tag",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observations_by_tag - missing tag",
|
||
toolName: "get_observations_by_tag",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_temporal_trends - invalid json",
|
||
toolName: "get_temporal_trends",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_data_quality_report - invalid json",
|
||
toolName: "get_data_quality_report",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "batch_tag_by_pattern - invalid json",
|
||
toolName: "batch_tag_by_pattern",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "batch_tag_by_pattern - missing pattern",
|
||
toolName: "batch_tag_by_pattern",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "explain_search_ranking - invalid json",
|
||
toolName: "explain_search_ranking",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "explain_search_ranking - missing query",
|
||
toolName: "explain_search_ranking",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "export_observations - invalid json",
|
||
toolName: "export_observations",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "analyze_search_patterns - invalid json",
|
||
toolName: "analyze_search_patterns",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_relationships - invalid json",
|
||
toolName: "get_observation_relationships",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_relationships - missing id",
|
||
toolName: "get_observation_relationships",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_scoring_breakdown - invalid json",
|
||
toolName: "get_observation_scoring_breakdown",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "get_observation_scoring_breakdown - missing id",
|
||
toolName: "get_observation_scoring_breakdown",
|
||
args: `{}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "analyze_observation_importance - invalid json",
|
||
toolName: "analyze_observation_importance",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
result, err := server.callTool(ctx, tt.toolName, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
} else {
|
||
require.NoError(t, err)
|
||
assert.NotEmpty(t, result)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCallTool_SearchTools tests search-based tools in callTool.
|
||
func TestCallTool_SearchTools(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// All search tools should fail when worker is unavailable (nil client)
|
||
searchTools := []string{
|
||
"search",
|
||
"decisions",
|
||
"changes",
|
||
"how_it_works",
|
||
"find_by_concept",
|
||
"find_by_file",
|
||
"find_by_type",
|
||
"get_recent_context",
|
||
}
|
||
|
||
for _, toolName := range searchTools {
|
||
t.Run(toolName+"_worker_unavailable", func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.callTool(ctx, toolName, json.RawMessage(`{"query":"test"}`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
})
|
||
}
|
||
|
||
// Timeline tools should handle invalid JSON with a parse error
|
||
timelineTools := []string{
|
||
"timeline",
|
||
"get_context_timeline",
|
||
"get_timeline_by_query",
|
||
}
|
||
|
||
for _, toolName := range timelineTools {
|
||
t.Run(toolName+"_invalid_json", func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.callTool(ctx, toolName, json.RawMessage(`{invalid`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "invalid")
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleTriggerMaintenance_NilService tests trigger_maintenance with nil service.
|
||
func TestHandleTriggerMaintenance_NilService(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when maintenanceService is nil
|
||
_, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil)
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleGetMaintenanceStats_NilService tests get_maintenance_stats with nil service.
|
||
func TestHandleGetMaintenanceStats_NilService(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when maintenanceService is nil
|
||
_, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""})
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleTimeline_ParameterDefaultsNew tests timeline parameter defaults.
|
||
func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Invalid JSON should fail
|
||
_, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "invalid timeline params")
|
||
}
|
||
|
||
// TestHandleTimelineByQuery_ValidationExtended tests timeline_by_query validation.
|
||
func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid timeline params",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleTimelineProxy(ctx, json.RawMessage(tt.args))
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestHandleSuggestConsolidations_ValidationExtended tests suggest_consolidations validation.
|
||
func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args string
|
||
errContains string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "invalid json",
|
||
args: `{invalid`,
|
||
wantErr: true,
|
||
errContains: "invalid arguments",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
_, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), tt.errContains)
|
||
})
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// NIL GUARD HANDLER TESTS
|
||
// =============================================================================
|
||
|
||
// TestHandleFindSimilarObservations_NilVectorClient tests nil vector client handling.
|
||
func TestHandleFindSimilarObservations_NilVectorClient(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when client is nil (worker unavailable)
|
||
_, err := server.handleFindSimilarProxy(ctx, json.RawMessage(`{"query": "test query"}`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// TestHandleGetObservationRelationships_NilRelationStore tests nil relation store handling.
|
||
func TestHandleGetObservationRelationships_NilRelationStore(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should return error when worker is unavailable
|
||
_, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(`{"id": 123}`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "worker unavailable")
|
||
}
|
||
|
||
// =============================================================================
|
||
// MORE PARAM LIMIT TESTS
|
||
// =============================================================================
|
||
|
||
// TestHandleBulkBoostObservations_EmptyIDs tests the empty IDs validation.
|
||
func TestHandleBulkBoostObservations_EmptyIDs(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Empty IDs should return error
|
||
_, err := server.handleBulkStatusProxy(ctx, json.RawMessage(`{"ids": []}`), "boost")
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "ids is required")
|
||
}
|
||
|
||
// TestHandleMergeObservations_MissingIDs tests merge with missing IDs.
|
||
func TestHandleMergeObservations_MissingIDs(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// source_id and target_id are required
|
||
_, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 0, "target_id": 0}`))
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "source_id and target_id are required")
|
||
}
|
||
|
||
// TestHandleMergeObservations_WorkerUnavailable tests merge when worker is down.
|
||
func TestHandleMergeObservations_WorkerUnavailable(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Should fail when worker is unavailable (nil client)
|
||
_, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2}`))
|
||
require.Error(t, err)
|
||
}
|
||
|
||
// =============================================================================
|
||
// REGRESSION TESTS — Fix #45: Concurrent request dispatching
|
||
// =============================================================================
|
||
|
||
// collectResponses reads newline-delimited JSON objects from r and returns
// them in arrival order. It stops once n objects have been decoded or r is
// exhausted, whichever comes first.
//
// Blank lines are skipped. Lines that do not decode as a JSON object are
// logged via t.Logf and skipped. A read error (including a line longer than
// the 1 MiB limit set below) is also logged rather than failing the test, so
// callers should assert on the number of responses returned.
func collectResponses(t *testing.T, r io.Reader, n int) []map[string]any {
	t.Helper()

	// bufio.Scanner refuses tokens above 64 KiB by default and then stops
	// scanning; raise the ceiling so one large response line does not
	// silently end collection.
	const maxLineBytes = 1 << 20

	results := make([]map[string]any, 0, n)
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineBytes)
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			continue
		}
		var resp map[string]any
		if err := json.Unmarshal(line, &resp); err != nil {
			t.Logf("collectResponses: bad JSON line: %s", line)
			continue
		}
		results = append(results, resp)
		if len(results) >= n {
			break
		}
	}
	// Scan returning false hides the reason; surface it so a short result is
	// diagnosable instead of looking like a clean EOF.
	if err := scanner.Err(); err != nil {
		t.Logf("collectResponses: read error after %d responses: %v", len(results), err)
	}
	return results
}
|
||
|
||
// writeRequests writes each entry of reqs to w followed by a newline, then
// closes w.
//
// Callers in this file start it with `go writeRequests(...)`, so it must not
// call t.FailNow (directly or through require): FailNow is only valid on the
// goroutine running the test. A failed write is therefore reported with
// t.Errorf and ends the loop. w is closed on every path so the reader on the
// other end always sees EOF instead of blocking forever.
func writeRequests(t *testing.T, w io.WriteCloser, reqs []string) {
	t.Helper()
	// Close error is deliberately ignored: nothing useful can be done with it here.
	defer func() { _ = w.Close() }()

	for i, req := range reqs {
		if _, err := io.WriteString(w, req+"\n"); err != nil {
			t.Errorf("writeRequests: writing request %d: %v", i, err)
			return
		}
	}
}
|
||
|
||
// TestRun_ConcurrentRequests verifies multiple requests are processed concurrently
|
||
// and not serially. If they ran serially, 5 × 100ms = 500ms. Concurrent should be
|
||
// well under 400ms on any reasonable machine.
|
||
func TestRun_ConcurrentRequests(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const delay = 100 * time.Millisecond
|
||
const numRequests = 5
|
||
|
||
// Mock worker: every request sleeps delay then returns "{}"
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
time.Sleep(delay)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
_, _ = io.WriteString(w, `{}`)
|
||
}))
|
||
defer ts.Close()
|
||
|
||
stdinR, stdinW := io.Pipe()
|
||
stdoutR, stdoutW := io.Pipe()
|
||
|
||
server := &Server{
|
||
client: ts.Client(),
|
||
workerURL: ts.URL,
|
||
project: "test",
|
||
version: "1.0.0",
|
||
stdin: stdinR,
|
||
stdout: stdoutW,
|
||
}
|
||
|
||
// Build requests — use get_memory_stats which goes to GET /api/stats
|
||
reqs := make([]string, numRequests)
|
||
for i := range reqs {
|
||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||
data, err := json.Marshal(req)
|
||
require.NoError(t, err)
|
||
reqs[i] = string(data)
|
||
}
|
||
|
||
// Collect responses in background
|
||
var responses []map[string]any
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
responses = collectResponses(t, stdoutR, numRequests)
|
||
_ = stdoutR.Close()
|
||
})
|
||
|
||
start := time.Now()
|
||
|
||
// Write all requests then close stdin (triggers Run to drain and return)
|
||
go writeRequests(t, stdinW, reqs)
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
_ = stdoutW.Close()
|
||
|
||
elapsed := time.Since(start)
|
||
wg.Wait()
|
||
|
||
// All responses received
|
||
assert.Len(t, responses, numRequests, "expected %d responses", numRequests)
|
||
|
||
// Concurrent execution: should be much less than numRequests × delay
|
||
serialTime := time.Duration(numRequests) * delay
|
||
assert.Less(t, elapsed, serialTime*4/5,
|
||
"elapsed %v not significantly less than serial %v — requests may be sequential", elapsed, serialTime)
|
||
|
||
// Each response has correct jsonrpc field
|
||
for _, resp := range responses {
|
||
assert.Equal(t, "2.0", resp["jsonrpc"])
|
||
}
|
||
}
|
||
|
||
// TestRun_SlowRequestDoesNotBlockOthers is the core regression for #45.
|
||
// A slow search request must not block a fast stats request from being answered first.
|
||
func TestRun_SlowRequestDoesNotBlockOthers(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const slowDelay = 300 * time.Millisecond
|
||
|
||
// responseOrder records which request IDs responded, in arrival order
|
||
var mu sync.Mutex
|
||
var responseOrder []any
|
||
|
||
// Mock worker: /api/context/search is slow, everything else is fast
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if strings.Contains(r.URL.Path, "/context/search") {
|
||
time.Sleep(slowDelay)
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
_, _ = io.WriteString(w, `{}`)
|
||
}))
|
||
defer ts.Close()
|
||
|
||
stdinR, stdinW := io.Pipe()
|
||
stdoutR, stdoutW := io.Pipe()
|
||
|
||
// Intercept stdout to record response order before passing through
|
||
pr, pw := io.Pipe()
|
||
go func() {
|
||
scanner := bufio.NewScanner(pr)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
var resp map[string]any
|
||
if err := json.Unmarshal([]byte(line), &resp); err == nil {
|
||
mu.Lock()
|
||
responseOrder = append(responseOrder, resp["id"])
|
||
mu.Unlock()
|
||
}
|
||
_, _ = io.WriteString(stdoutW, line+"\n")
|
||
}
|
||
_ = stdoutW.Close()
|
||
}()
|
||
|
||
server := &Server{
|
||
client: ts.Client(),
|
||
workerURL: ts.URL,
|
||
project: "test",
|
||
version: "1.0.0",
|
||
stdin: stdinR,
|
||
stdout: pw,
|
||
}
|
||
|
||
// Request 1: slow search (id=1)
|
||
slowReq := Request{JSONRPC: "2.0", ID: 1, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"search","arguments":{"query":"anything"}}`)}
|
||
slowData, err := json.Marshal(slowReq)
|
||
require.NoError(t, err)
|
||
|
||
// Request 2: fast stats (id=2)
|
||
fastReq := Request{JSONRPC: "2.0", ID: 2, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||
fastData, err := json.Marshal(fastReq)
|
||
require.NoError(t, err)
|
||
|
||
// Collect 2 responses
|
||
var responses []map[string]any
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
responses = collectResponses(t, stdoutR, 2)
|
||
_ = stdoutR.Close()
|
||
})
|
||
|
||
// Write slow then fast, then close
|
||
go func() {
|
||
_, _ = io.WriteString(stdinW, string(slowData)+"\n")
|
||
// Small pause to ensure slow request goroutine is dispatched first
|
||
time.Sleep(10 * time.Millisecond)
|
||
_, _ = io.WriteString(stdinW, string(fastData)+"\n")
|
||
_ = stdinW.Close()
|
||
}()
|
||
|
||
runErr := server.Run(context.Background())
|
||
require.NoError(t, runErr)
|
||
|
||
wg.Wait()
|
||
|
||
require.Len(t, responses, 2, "expected 2 responses")
|
||
|
||
mu.Lock()
|
||
order := responseOrder
|
||
mu.Unlock()
|
||
|
||
require.Len(t, order, 2, "expected 2 recorded response IDs")
|
||
|
||
// The fast request (id=2) must arrive before the slow one (id=1)
|
||
assert.Equal(t, float64(2), order[0],
|
||
"fast request (id=2) should respond before slow request (id=1); got order %v", order)
|
||
assert.Equal(t, float64(1), order[1],
|
||
"slow request (id=1) should respond second; got order %v", order)
|
||
}
|
||
|
||
// TestRun_SemaphoreLimitsConcurrency verifies the semaphore cap (10) does not deadlock
|
||
// when more than 10 requests are sent and all eventually complete.
|
||
func TestRun_SemaphoreLimitsConcurrency(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const blockDelay = 200 * time.Millisecond
|
||
const numRequests = 15 // exceeds semaphore cap of 10
|
||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
time.Sleep(blockDelay)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
_, _ = io.WriteString(w, `{}`)
|
||
}))
|
||
defer ts.Close()
|
||
|
||
stdinR, stdinW := io.Pipe()
|
||
stdoutR, stdoutW := io.Pipe()
|
||
|
||
server := &Server{
|
||
client: ts.Client(),
|
||
workerURL: ts.URL,
|
||
project: "test",
|
||
version: "1.0.0",
|
||
stdin: stdinR,
|
||
stdout: stdoutW,
|
||
}
|
||
|
||
reqs := make([]string, numRequests)
|
||
for i := range reqs {
|
||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||
data, err := json.Marshal(req)
|
||
require.NoError(t, err)
|
||
reqs[i] = string(data)
|
||
}
|
||
|
||
var responses []map[string]any
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
responses = collectResponses(t, stdoutR, numRequests)
|
||
_ = stdoutR.Close()
|
||
})
|
||
|
||
start := time.Now()
|
||
go writeRequests(t, stdinW, reqs)
|
||
|
||
err := server.Run(context.Background())
|
||
require.NoError(t, err)
|
||
_ = stdoutW.Close()
|
||
|
||
elapsed := time.Since(start)
|
||
wg.Wait()
|
||
|
||
// All 15 responses received — no deadlock
|
||
assert.Len(t, responses, numRequests, "all %d requests must complete", numRequests)
|
||
|
||
// With semaphore=10 and 15 requests at 200ms each, we need at least 2 batches.
|
||
// Should complete in ~2×blockDelay not 15×blockDelay.
|
||
// Upper bound: 3×blockDelay gives comfortable headroom for scheduling.
|
||
upperBound := 3 * blockDelay * 2 // generous: 3 batches + 2× overhead factor
|
||
assert.Less(t, elapsed, upperBound,
|
||
"elapsed %v suggests sequential processing (15×%v = %v)", elapsed, blockDelay, time.Duration(numRequests)*blockDelay)
|
||
}
|
||
|
||
// TestRun_GracefulDrainOnCancel verifies that cancelling the context causes Run to
|
||
// drain in-flight requests (wg.Wait) before returning ctx.Canceled.
|
||
func TestRun_GracefulDrainOnCancel(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const reqDelay = 200 * time.Millisecond
|
||
const numRequests = 3
|
||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// Use a fixed sleep; the request-level context will be cancelled but the
|
||
// HTTP handler runs to completion independently (server-side).
|
||
time.Sleep(reqDelay)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
_, _ = io.WriteString(w, `{}`)
|
||
}))
|
||
defer ts.Close()
|
||
|
||
stdinR, stdinW := io.Pipe()
|
||
stdoutR, stdoutW := io.Pipe()
|
||
|
||
server := &Server{
|
||
client: ts.Client(),
|
||
workerURL: ts.URL,
|
||
project: "test",
|
||
version: "1.0.0",
|
||
stdin: stdinR,
|
||
stdout: stdoutW,
|
||
}
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
|
||
// Build requests
|
||
reqs := make([]string, numRequests)
|
||
for i := range reqs {
|
||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||
data, err := json.Marshal(req)
|
||
require.NoError(t, err)
|
||
reqs[i] = string(data)
|
||
}
|
||
|
||
// Write all requests before cancelling so they're all dispatched as goroutines
|
||
go func() {
|
||
for _, r := range reqs {
|
||
_, _ = io.WriteString(stdinW, r+"\n")
|
||
}
|
||
// Cancel after requests are dispatched but while they're still in-flight
|
||
time.Sleep(50 * time.Millisecond)
|
||
cancel()
|
||
// Leave stdin open — Run should return from the ctx.Done branch
|
||
}()
|
||
|
||
// Drain responses in background; we don't know exactly how many will complete
|
||
// because goroutines may get context cancelled on their HTTP calls too.
|
||
var responseMu sync.Mutex
|
||
var collectedResponses []map[string]any
|
||
var collectWg sync.WaitGroup
|
||
collectWg.Go(func() {
|
||
scanner := bufio.NewScanner(stdoutR)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if line == "" {
|
||
continue
|
||
}
|
||
var resp map[string]any
|
||
if err := json.Unmarshal([]byte(line), &resp); err == nil {
|
||
responseMu.Lock()
|
||
collectedResponses = append(collectedResponses, resp)
|
||
responseMu.Unlock()
|
||
}
|
||
}
|
||
})
|
||
|
||
runErr := server.Run(ctx)
|
||
_ = stdoutW.Close()
|
||
_ = stdinW.Close()
|
||
|
||
collectWg.Wait()
|
||
|
||
// Core assertion: Run returned the context cancellation error
|
||
assert.ErrorIs(t, runErr, context.Canceled,
|
||
"Run must return context.Canceled when context is cancelled")
|
||
|
||
// Run returned only after wg.Wait() drained goroutines.
|
||
// The goroutines may have returned errors (ctx cancelled HTTP calls) but
|
||
// the key invariant is Run itself did not panic and returned cleanly.
|
||
// Any responses that did complete should be valid JSON-RPC.
|
||
responseMu.Lock()
|
||
defer responseMu.Unlock()
|
||
for _, resp := range collectedResponses {
|
||
assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0")
|
||
}
|
||
}
|
||
|
||
// =============================================================================
|
||
// REGRESSION TESTS — Fix #1-#5
|
||
// =============================================================================
|
||
|
||
// blockingWriter is an io.Writer whose Write never returns. It is used to
// simulate a consumer that has stopped reading.
type blockingWriter struct {
	blocked chan struct{} // closed the first time Write is entered; may be nil
	once    sync.Once     // ensures blocked is closed only once across repeated Writes
}

// Write signals entry by closing bw.blocked (when non-nil) and then blocks
// forever, so it never returns a byte count or an error. The close is guarded
// by sync.Once: a second Write on the same writer would otherwise panic with
// "close of closed channel".
func (bw *blockingWriter) Write(p []byte) (int, error) {
	if bw.blocked != nil {
		bw.once.Do(func() { close(bw.blocked) })
	}
	select {} // block forever
}
|
||
|
||
// TestRun_SemaphoreDoesNotBlockMainLoop (Fix #1 regression) fills all semaphore
|
||
// slots with blocked requests, then sends a notification. The main loop must
|
||
// stay responsive and not hang on semaphore acquisition.
|
||
func TestRun_SemaphoreDoesNotBlockMainLoop(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
const maxConcurrent = 10
|
||
|
||
// Mock worker that blocks until test context is cancelled
|
||
handlerCtx, handlerCancel := context.WithCancel(context.Background())
|
||
defer handlerCancel()
|
||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
<-handlerCtx.Done() // block until test cleanup
|
||
w.Header().Set("Content-Type", "application/json")
|
||
_, _ = io.WriteString(w, `{}`)
|
||
}))
|
||
defer func() {
|
||
handlerCancel() // unblock all handlers first
|
||
ts.CloseClientConnections()
|
||
ts.Close()
|
||
}()
|
||
|
||
stdinR, stdinW := io.Pipe()
|
||
var stdout bytes.Buffer
|
||
|
||
server := &Server{
|
||
client: ts.Client(),
|
||
workerURL: ts.URL,
|
||
project: "test",
|
||
version: "1.0.0",
|
||
stdin: stdinR,
|
||
stdout: &stdout,
|
||
}
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
runDone := make(chan error, 1)
|
||
go func() {
|
||
runDone <- server.Run(ctx)
|
||
}()
|
||
|
||
// Send maxConcurrent+2 requests to fill all semaphore slots
|
||
for i := 0; i < maxConcurrent+2; i++ {
|
||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||
data, _ := json.Marshal(req)
|
||
_, err := io.WriteString(stdinW, string(data)+"\n")
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
// Give goroutines time to start and fill semaphore
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// Now send a notification — this must not hang the main loop
|
||
notifSent := make(chan struct{})
|
||
go func() {
|
||
notif := `{"jsonrpc":"2.0","method":"notifications/initialized"}` + "\n"
|
||
_, _ = io.WriteString(stdinW, notif)
|
||
close(notifSent)
|
||
}()
|
||
|
||
select {
|
||
case <-notifSent:
|
||
// Main loop accepted the notification write (stdin is a pipe, so
|
||
// the write completing means the reader consumed it)
|
||
case <-time.After(3 * time.Second):
|
||
t.Fatal("main loop blocked — semaphore acquisition is blocking the main goroutine")
|
||
}
|
||
|
||
// Clean up: cancel server context, close stdin
|
||
cancel()
|
||
_ = stdinW.Close()
|
||
|
||
// Wait for Run to finish
|
||
select {
|
||
case <-runDone:
|
||
case <-time.After(5 * time.Second):
|
||
// Acceptable — some goroutines may still be draining
|
||
}
|
||
}
|
||
|
||
// TestSendResponse_WriteTimeout (Fix #2 regression) verifies that sendResponse
|
||
// returns an error within a bounded time when the writer blocks forever.
|
||
func TestSendResponse_WriteTimeout(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
bw := &blockingWriter{blocked: make(chan struct{})}
|
||
server := &Server{stdout: bw}
|
||
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: 1,
|
||
Result: "ok",
|
||
}
|
||
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
done <- server.sendResponse(resp)
|
||
}()
|
||
|
||
select {
|
||
case err := <-done:
|
||
require.Error(t, err, "sendResponse must return error on write timeout")
|
||
assert.Contains(t, err.Error(), "write timeout")
|
||
case <-time.After(15 * time.Second):
|
||
t.Fatal("sendResponse hung forever — write timeout not working")
|
||
}
|
||
}
|
||
|
||
// TestSendResponse_MarshalError (Fix #3 regression) verifies that when
|
||
// json.Marshal fails, sendResponse sends a fallback error response and
|
||
// returns an error (instead of silently returning nil).
|
||
func TestSendResponse_MarshalError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var buf bytes.Buffer
|
||
server := &Server{stdout: &buf}
|
||
|
||
// Channels are not JSON-serializable — this will cause json.Marshal to fail
|
||
resp := &Response{
|
||
JSONRPC: "2.0",
|
||
ID: 42,
|
||
Result: make(chan int), // unserializable
|
||
}
|
||
|
||
err := server.sendResponse(resp)
|
||
|
||
// (a) Must return an error
|
||
require.Error(t, err, "sendResponse must return error when marshal fails")
|
||
assert.Contains(t, err.Error(), "marshal error")
|
||
|
||
// (b) Must have written a fallback JSON-RPC error to stdout
|
||
output := buf.String()
|
||
require.NotEmpty(t, output, "fallback response must be written to stdout")
|
||
|
||
var fallback map[string]any
|
||
require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(output)), &fallback),
|
||
"fallback must be valid JSON")
|
||
assert.Equal(t, "2.0", fallback["jsonrpc"])
|
||
assert.Equal(t, float64(42), fallback["id"], "fallback must preserve original request ID")
|
||
|
||
errObj, ok := fallback["error"].(map[string]any)
|
||
require.True(t, ok, "fallback must contain error object")
|
||
assert.Equal(t, float64(-32603), errObj["code"])
|
||
assert.Equal(t, "internal marshal error", errObj["message"])
|
||
}
|
||
|
||
// TestHandleRequest_UnknownNotification (Fix #4 regression) verifies that
|
||
// unknown notification methods (ID == nil) get no response, while unknown
|
||
// methods with an ID still get a -32601 error.
|
||
func TestHandleRequest_UnknownNotification(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
// Case 1: Unknown notification (no ID) — must return nil
|
||
notifReq := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: nil,
|
||
Method: "notifications/roots/list_changed",
|
||
}
|
||
resp := server.handleRequest(ctx, notifReq)
|
||
assert.Nil(t, resp, "unknown notification must not produce a response")
|
||
|
||
// Case 2: Unknown method WITH an ID — must return error response
|
||
methodReq := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 99,
|
||
Method: "some/unknown/method",
|
||
}
|
||
resp = server.handleRequest(ctx, methodReq)
|
||
require.NotNil(t, resp, "unknown method with ID must produce an error response")
|
||
require.NotNil(t, resp.Error)
|
||
assert.Equal(t, -32601, resp.Error.Code)
|
||
assert.Equal(t, "Method not found", resp.Error.Message)
|
||
assert.Equal(t, 99, resp.ID)
|
||
}
|
||
|
||
// TestHandleToolsCall_ErrorUsesIsError (Fix #5 regression) verifies that when
|
||
// callTool returns an error, the response uses Result with isError:true instead
|
||
// of top-level Error field (per MCP spec).
|
||
func TestHandleToolsCall_ErrorUsesIsError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
server := NewServer(nil, "", "", "1.0.0")
|
||
ctx := context.Background()
|
||
|
||
req := &Request{
|
||
JSONRPC: "2.0",
|
||
ID: 7,
|
||
Method: "tools/call",
|
||
Params: json.RawMessage(`{"name":"nonexistent_tool","arguments":{}}`),
|
||
}
|
||
|
||
resp := server.handleToolsCall(ctx, req)
|
||
|
||
// (a) Response must NOT have top-level Error
|
||
assert.Nil(t, resp.Error, "tool errors must not use top-level JSON-RPC Error")
|
||
|
||
// (b) Response must have Result with isError: true
|
||
require.NotNil(t, resp.Result, "tool error response must have Result")
|
||
result, ok := resp.Result.(map[string]any)
|
||
require.True(t, ok, "Result must be a map")
|
||
assert.Equal(t, true, result["isError"], "Result must contain isError: true")
|
||
|
||
// (c) Result.content[0].text must contain the error message
|
||
content, ok := result["content"].([]map[string]any)
|
||
require.True(t, ok, "Result.content must be []map[string]any")
|
||
require.Len(t, content, 1)
|
||
assert.Equal(t, "text", content[0]["type"])
|
||
errText, ok := content[0]["text"].(string)
|
||
require.True(t, ok)
|
||
assert.Contains(t, errText, "unknown tool: nonexistent_tool")
|
||
}
|