Files
claude-mnemonic/internal/mcp/server_test.go
T

600 lines
14 KiB
Go

// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic.
package mcp
import (
"bytes"
"context"
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// TestServerSuite hooks ServerSuite into the standard `go test` runner.
func TestServerSuite(t *testing.T) {
	suite.Run(t, &ServerSuite{})
}

// ServerSuite groups MCP Server tests that run under testify's suite runner.
type ServerSuite struct {
	suite.Suite
}

// TestNewServer verifies that NewServer records the supplied search manager
// (nil here) and version string on the returned Server.
func (s *ServerSuite) TestNewServer() {
	srv := NewServer(nil, "1.0.0")

	s.NotNil(srv)
	s.Nil(srv.searchMgr)
	s.Equal("1.0.0", srv.version)
}
// TestRequest checks Request's JSON encoding against the expected JSON-RPC
// wire form, and that decoding those bytes yields a Request that re-encodes
// to the same JSON (so ID and Params survive, not just JSONRPC and Method).
func TestRequest(t *testing.T) {
	tests := []struct {
		name     string
		req      Request
		expected string
	}{
		{
			name: "initialize request",
			req: Request{
				JSONRPC: "2.0",
				ID:      1,
				Method:  "initialize",
			},
			expected: `{"jsonrpc":"2.0","id":1,"method":"initialize"}`,
		},
		{
			name: "tools/list request",
			req: Request{
				JSONRPC: "2.0",
				ID:      "abc",
				Method:  "tools/list",
			},
			expected: `{"jsonrpc":"2.0","id":"abc","method":"tools/list"}`,
		},
		{
			name: "tools/call with params",
			req: Request{
				JSONRPC: "2.0",
				ID:      2,
				Method:  "tools/call",
				Params:  json.RawMessage(`{"name":"search","arguments":{}}`),
			},
			expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			data, err := json.Marshal(tt.req)
			require.NoError(t, err)
			assert.JSONEq(t, tt.expected, string(data))

			// Decode the encoded bytes and check the scalar string fields.
			var parsed Request
			err = json.Unmarshal(data, &parsed)
			require.NoError(t, err)
			assert.Equal(t, tt.req.JSONRPC, parsed.JSONRPC)
			assert.Equal(t, tt.req.Method, parsed.Method)

			// ID and Params were previously not verified after decoding.
			// ID accepts both ints and strings, so it is interface-typed and
			// a numeric ID decodes as float64 (encoding/json's default), not
			// int. Comparing by re-encoding sidesteps that type mismatch.
			reencoded, err := json.Marshal(parsed)
			require.NoError(t, err)
			assert.JSONEq(t, tt.expected, string(reencoded))
		})
	}
}
// TestResponse verifies Response's JSON encoding for a successful result, a
// bare error, and an error carrying extra data. Unset result/error members
// are expected to be absent from the output.
func TestResponse(t *testing.T) {
	cases := []struct {
		name string
		in   Response
		want string
	}{
		{
			name: "success response",
			in:   Response{JSONRPC: "2.0", ID: 1, Result: map[string]string{"status": "ok"}},
			want: `{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`,
		},
		{
			name: "error response",
			in: Response{
				JSONRPC: "2.0",
				ID:      2,
				Error:   &Error{Code: -32600, Message: "Invalid Request"},
			},
			want: `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}`,
		},
		{
			name: "error with data",
			in: Response{
				JSONRPC: "2.0",
				ID:      3,
				Error:   &Error{Code: -32602, Message: "Invalid params", Data: "missing field"},
			},
			want: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			encoded, marshalErr := json.Marshal(tc.in)
			require.NoError(t, marshalErr)
			assert.JSONEq(t, tc.want, string(encoded))
		})
	}
}
// TestError verifies Error's JSON encoding for standard JSON-RPC codes, both
// without and with the optional data member.
func TestError(t *testing.T) {
	for _, tc := range []struct {
		name string
		in   Error
		want string
	}{
		{
			name: "parse error",
			in:   Error{Code: -32700, Message: "Parse error"},
			want: `{"code":-32700,"message":"Parse error"}`,
		},
		{
			name: "method not found",
			in:   Error{Code: -32601, Message: "Method not found"},
			want: `{"code":-32601,"message":"Method not found"}`,
		},
		{
			name: "invalid params",
			in:   Error{Code: -32602, Message: "Invalid params", Data: "details here"},
			want: `{"code":-32602,"message":"Invalid params","data":"details here"}`,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			encoded, marshalErr := json.Marshal(tc.in)
			require.NoError(t, marshalErr)
			assert.JSONEq(t, tc.want, string(encoded))
		})
	}
}
// TestToolCallParams checks that a tools/call params object decodes into
// ToolCallParams with both the tool name and its arguments preserved.
func TestToolCallParams(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected ToolCallParams
	}{
		{
			name:  "search tool call",
			input: `{"name":"search","arguments":{"query":"test"}}`,
			expected: ToolCallParams{
				Name:      "search",
				Arguments: json.RawMessage(`{"query":"test"}`),
			},
		},
		{
			name:  "decisions tool call",
			input: `{"name":"decisions","arguments":{"query":"auth"}}`,
			expected: ToolCallParams{
				Name:      "decisions",
				Arguments: json.RawMessage(`{"query":"auth"}`),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var params ToolCallParams
			err := json.Unmarshal([]byte(tt.input), &params)
			require.NoError(t, err)
			assert.Equal(t, tt.expected.Name, params.Name)

			// Arguments was declared in every case but previously never
			// asserted. Compare semantically via JSON so the check is
			// insensitive to whitespace in the raw bytes.
			wantArgs, err := json.Marshal(tt.expected.Arguments)
			require.NoError(t, err)
			gotArgs, err := json.Marshal(params.Arguments)
			require.NoError(t, err)
			assert.JSONEq(t, string(wantArgs), string(gotArgs))
		})
	}
}
// TestTool checks that a Tool, including its nested InputSchema, survives a
// JSON encode/decode round trip.
func TestTool(t *testing.T) {
	tool := Tool{
		Name:        "search",
		Description: "Search observations",
		InputSchema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"query": map[string]any{"type": "string"},
			},
		},
	}

	data, err := json.Marshal(tool)
	require.NoError(t, err)

	var parsed Tool
	err = json.Unmarshal(data, &parsed)
	require.NoError(t, err)
	assert.Equal(t, "search", parsed.Name)
	assert.Equal(t, "Search observations", parsed.Description)

	// InputSchema was previously not verified after decoding. Check its
	// top-level type directly, then re-encode the whole Tool so the nested
	// properties map is covered too.
	assert.Equal(t, "object", parsed.InputSchema["type"])
	reencoded, err := json.Marshal(parsed)
	require.NoError(t, err)
	assert.JSONEq(t, string(data), string(reencoded))
}
// TestTimelineParams checks that TimelineParams decodes each JSON field it
// supports: anchor_id, query, before, after, project, obs_type, concepts,
// files, dateStart, dateEnd and format.
func TestTimelineParams(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected TimelineParams
	}{
		{
			name:  "with anchor_id",
			input: `{"anchor_id":123,"before":5,"after":5}`,
			expected: TimelineParams{
				AnchorID: 123,
				Before:   5,
				After:    5,
			},
		},
		{
			name:  "with query",
			input: `{"query":"test query","project":"my-project"}`,
			expected: TimelineParams{
				Query:   "test query",
				Project: "my-project",
			},
		},
		{
			name:  "full params",
			input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`,
			expected: TimelineParams{
				AnchorID:  100,
				Query:     "search",
				Before:    10,
				After:     20,
				Project:   "proj",
				ObsType:   "bugfix",
				Concepts:  "security",
				Files:     "main.go",
				DateStart: 1234567890,
				DateEnd:   9876543210,
				Format:    "full",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var params TimelineParams
			err := json.Unmarshal([]byte(tt.input), &params)
			require.NoError(t, err)

			// Assert every field the table populates. Previously only AnchorID,
			// Query and Project were checked, so a wrong JSON tag on any other
			// field would have gone unnoticed.
			assert.Equal(t, tt.expected.AnchorID, params.AnchorID)
			assert.Equal(t, tt.expected.Query, params.Query)
			assert.Equal(t, tt.expected.Before, params.Before)
			assert.Equal(t, tt.expected.After, params.After)
			assert.Equal(t, tt.expected.Project, params.Project)
			assert.Equal(t, tt.expected.ObsType, params.ObsType)
			assert.Equal(t, tt.expected.Concepts, params.Concepts)
			assert.Equal(t, tt.expected.Files, params.Files)
			assert.Equal(t, tt.expected.DateStart, params.DateStart)
			assert.Equal(t, tt.expected.DateEnd, params.DateEnd)
			assert.Equal(t, tt.expected.Format, params.Format)
		})
	}
}
// TestHandleInitialize verifies that the initialize handler echoes the
// request ID and reports the protocol version plus the server's name and
// configured version.
func TestHandleInitialize(t *testing.T) {
	srv := NewServer(nil, "1.2.3")

	resp := srv.handleInitialize(&Request{JSONRPC: "2.0", ID: 1, Method: "initialize"})

	assert.Equal(t, "2.0", resp.JSONRPC)
	assert.Equal(t, 1, resp.ID)
	assert.Nil(t, resp.Error)
	assert.NotNil(t, resp.Result)

	payload, isMap := resp.Result.(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "2024-11-05", payload["protocolVersion"])

	info, isMap := payload["serverInfo"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "claude-mnemonic", info["name"])
	assert.Equal(t, "1.2.3", info["version"])
}
// TestHandleToolsList verifies that tools/list echoes the request ID, reports
// no error, and advertises every tool the server is expected to expose.
func TestHandleToolsList(t *testing.T) {
	srv := NewServer(nil, "1.0.0")

	resp := srv.handleToolsList(&Request{JSONRPC: "2.0", ID: 1, Method: "tools/list"})
	assert.Equal(t, "2.0", resp.JSONRPC)
	assert.Equal(t, 1, resp.ID)
	assert.Nil(t, resp.Error)

	payload, ok := resp.Result.(map[string]any)
	require.True(t, ok)
	tools, ok := payload["tools"].([]Tool)
	require.True(t, ok)
	assert.NotEmpty(t, tools)

	// Build a set of advertised names for the membership checks below.
	advertised := make(map[string]struct{}, len(tools))
	for _, tool := range tools {
		advertised[tool.Name] = struct{}{}
	}

	for _, want := range []string{
		"search", "timeline", "decisions", "changes",
		"how_it_works", "find_by_concept", "find_by_file",
		"find_by_type", "get_recent_context", "get_context_timeline",
		"get_timeline_by_query",
	} {
		_, found := advertised[want]
		assert.True(t, found, "expected tool %s to be present", want)
	}
}
// TestHandleRequest verifies that handleRequest routes known methods to a
// successful result and answers an unknown method with a JSON-RPC
// "Method not found" (-32601) error, always echoing the request ID.
func TestHandleRequest(t *testing.T) {
	srv := NewServer(nil, "1.0.0")
	ctx := context.Background()

	// wantError describes the error a case expects; nil means success.
	type wantError struct {
		code    int
		message string
	}

	cases := []struct {
		name string
		req  *Request
		err  *wantError
	}{
		{
			name: "initialize method",
			req:  &Request{JSONRPC: "2.0", ID: 1, Method: "initialize"},
		},
		{
			name: "tools/list method",
			req:  &Request{JSONRPC: "2.0", ID: 2, Method: "tools/list"},
		},
		{
			name: "unknown method",
			req:  &Request{JSONRPC: "2.0", ID: 3, Method: "unknown_method"},
			err:  &wantError{code: -32601, message: "Method not found"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp := srv.handleRequest(ctx, tc.req)
			assert.Equal(t, "2.0", resp.JSONRPC)
			assert.Equal(t, tc.req.ID, resp.ID)

			if tc.err == nil {
				assert.Nil(t, resp.Error)
				assert.NotNil(t, resp.Result)
				return
			}

			require.NotNil(t, resp.Error)
			assert.Equal(t, tc.err.code, resp.Error.Code)
			assert.Equal(t, tc.err.message, resp.Error.Message)
		})
	}
}
// TestHandleToolsCall_InvalidParams verifies that tools/call answers a
// params payload that is not valid JSON with a -32602 "Invalid params" error.
func TestHandleToolsCall_InvalidParams(t *testing.T) {
	srv := NewServer(nil, "1.0.0")
	badParams := json.RawMessage(`invalid json`)

	resp := srv.handleToolsCall(context.Background(), &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/call",
		Params:  badParams,
	})

	require.NotNil(t, resp.Error)
	assert.Equal(t, -32602, resp.Error.Code)
	assert.Equal(t, "Invalid params", resp.Error.Message)
}
// TestCallTool_UnknownTool verifies that callTool rejects a tool name it does
// not recognize with an error mentioning "unknown tool".
func TestCallTool_UnknownTool(t *testing.T) {
	srv := NewServer(nil, "1.0.0")

	_, callErr := srv.callTool(context.Background(), "nonexistent_tool", json.RawMessage(`{}`))

	require.Error(t, callErr)
	assert.Contains(t, callErr.Error(), "unknown tool")
}
// TestCallTool_InvalidArgs verifies that callTool reports an "invalid
// arguments" error when a known tool receives arguments that are not JSON.
func TestCallTool_InvalidArgs(t *testing.T) {
	srv := NewServer(nil, "1.0.0")
	malformed := json.RawMessage(`invalid json`)

	_, callErr := srv.callTool(context.Background(), "search", malformed)

	require.Error(t, callErr)
	assert.Contains(t, callErr.Error(), "invalid arguments")
}
// TestSendResponse verifies that sendResponse writes the encoded response,
// including its jsonrpc version, ID and result, to the server's stdout.
func TestSendResponse(t *testing.T) {
	var out bytes.Buffer
	srv := &Server{stdout: &out}

	srv.sendResponse(&Response{
		JSONRPC: "2.0",
		ID:      1,
		Result:  map[string]string{"status": "ok"},
	})

	written := out.String()
	for _, fragment := range []string{`"jsonrpc":"2.0"`, `"id":1`, `"result"`} {
		assert.Contains(t, written, fragment)
	}
}
// TestSendError verifies that sendError writes a JSON-RPC error response to
// the server's stdout carrying the given ID, code, message and data.
func TestSendError(t *testing.T) {
	var buf bytes.Buffer
	server := &Server{
		stdout: &buf,
	}

	server.sendError(1, -32700, "Parse error", "details")

	output := buf.String()
	assert.Contains(t, output, `"error"`)
	assert.Contains(t, output, `-32700`)
	assert.Contains(t, output, `"Parse error"`)
	// The request ID and the data argument were previously not asserted, so
	// dropping either of them from the response would have gone unnoticed.
	assert.Contains(t, output, `"id":1`)
	assert.Contains(t, output, `"details"`)
}
// TestRun_ParseError feeds Run a single non-JSON line and expects a -32700
// "Parse error" response on stdout. Run itself should return nil once the
// input is exhausted.
func TestRun_ParseError(t *testing.T) {
	var out bytes.Buffer
	srv := &Server{
		stdin:  strings.NewReader("invalid json\n"),
		stdout: &out,
	}

	require.NoError(t, srv.Run(context.Background()))

	got := out.String()
	for _, fragment := range []string{`"error"`, `-32700`, `"Parse error"`} {
		assert.Contains(t, got, fragment)
	}
}
// TestRun_EmptyLine verifies that Run writes nothing when its input consists
// only of blank lines.
func TestRun_EmptyLine(t *testing.T) {
	var out bytes.Buffer
	srv := &Server{stdin: strings.NewReader("\n\n"), stdout: &out}

	runErr := srv.Run(context.Background())
	require.NoError(t, runErr)

	// Blank lines are skipped, so no response should have been emitted.
	assert.Empty(t, out.String())
}
// TestRun_ValidRequest drives Run with one initialize request and expects a
// successful response containing the protocol version on stdout.
func TestRun_ValidRequest(t *testing.T) {
	const initRequest = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`

	var out bytes.Buffer
	srv := &Server{
		stdin:   strings.NewReader(initRequest + "\n"),
		stdout:  &out,
		version: "1.0.0",
	}

	require.NoError(t, srv.Run(context.Background()))

	got := out.String()
	for _, fragment := range []string{`"jsonrpc":"2.0"`, `"result"`, `"protocolVersion"`} {
		assert.Contains(t, got, fragment)
	}
}
// TestJSONRPCErrorCodes checks that each standard JSON-RPC 2.0 error code,
// carried in an Error, encodes to exactly {"code": ..., "message": ...}, with
// data omitted when it is unset.
//
// The previous version only read back the struct fields it had just set,
// which could never fail. It also ranged over a map, which made subtest order
// nondeterministic.
func TestJSONRPCErrorCodes(t *testing.T) {
	// A slice keeps subtest order stable across runs.
	errorCodes := []struct {
		msg  string
		code int
	}{
		{msg: "Parse error", code: -32700},
		{msg: "Invalid Request", code: -32600},
		{msg: "Method not found", code: -32601},
		{msg: "Invalid params", code: -32602},
		{msg: "Internal error", code: -32603},
	}

	for _, tc := range errorCodes {
		t.Run(tc.msg, func(t *testing.T) {
			rpcErr := Error{Code: tc.code, Message: tc.msg}
			assert.Equal(t, tc.code, rpcErr.Code)
			assert.Equal(t, tc.msg, rpcErr.Message)

			got, err := json.Marshal(rpcErr)
			require.NoError(t, err)

			want, err := json.Marshal(map[string]any{"code": tc.code, "message": tc.msg})
			require.NoError(t, err)
			assert.JSONEq(t, string(want), string(got))
		})
	}
}
// TestToolListContainsExpectedSchemas checks that every tool advertised by
// tools/list has a name, a description and an input schema declaring a
// "type" key.
func TestToolListContainsExpectedSchemas(t *testing.T) {
	server := NewServer(nil, "1.0.0")
	req := &Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "tools/list",
	}

	resp := server.handleToolsList(req)

	// Checked type assertions make an unexpected result shape fail the test
	// with a message instead of panicking (matches TestHandleToolsList).
	result, ok := resp.Result.(map[string]any)
	require.True(t, ok, "tools/list result should be a map[string]any")
	tools, ok := result["tools"].([]Tool)
	require.True(t, ok, `tools/list result should hold a []Tool under "tools"`)
	// Without this check the loop below would pass vacuously on an empty list.
	require.NotEmpty(t, tools)

	for _, tool := range tools {
		assert.NotEmpty(t, tool.Name)
		assert.NotEmpty(t, tool.Description)
		assert.NotNil(t, tool.InputSchema)

		// Each schema must at least declare its JSON Schema type.
		_, hasType := tool.InputSchema["type"]
		assert.True(t, hasType, "tool %s schema should have type", tool.Name)
	}
}