// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic. package mcp import ( "bufio" "bytes" "context" "encoding/json" "io" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) // timelineParams is used in tests for timeline request parsing. type timelineParams struct { Query string `json:"query"` Project string `json:"project"` ObsType string `json:"obs_type"` Concepts string `json:"concepts"` Files string `json:"files"` Format string `json:"format"` AnchorID int64 `json:"anchor_id"` Before int `json:"before"` After int `json:"after"` DateStart int64 `json:"dateStart"` DateEnd int64 `json:"dateEnd"` } // ============================================================================= // TEST SUITE // ============================================================================= // ServerSuite is a test suite for MCP Server operations. type ServerSuite struct { suite.Suite } func TestServerSuite(t *testing.T) { suite.Run(t, new(ServerSuite)) } // TestNewServer tests server creation. func (s *ServerSuite) TestNewServer() { server := NewServer(nil, "http://localhost:37777", "test-project", "1.0.0") s.NotNil(server) s.Nil(server.client) s.Equal("1.0.0", server.version) s.Equal("http://localhost:37777", server.workerURL) s.Equal("test-project", server.project) } // ============================================================================= // TESTS FOR Request/Response Structs // ============================================================================= // TestRequest tests Request struct JSON marshaling. func TestRequest(t *testing.T) { t.Parallel() tests := []struct { name string expected string req Request }{ // ===== GOOD CASES ===== { name: "initialize request", req: Request{ JSONRPC: "2.0", ID: 1, Method: "initialize", }, expected: `{"jsonrpc":"2.0","id":1,"method":"initialize"}`, }, { name: "tools/list request", req: Request{ JSONRPC: "2.0", ID: "abc", Method: "tools/list", }, expected: `{"jsonrpc":"2.0","id":"abc","method":"tools/list"}`, }, { name: "tools/call with params", req: Request{ JSONRPC: "2.0", ID: 2, Method: "tools/call", Params: json.RawMessage(`{"name":"search","arguments":{}}`), }, expected: `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"search","arguments":{}}}`, }, // ===== EDGE CASES ===== { name: "request with nil ID", req: Request{ JSONRPC: "2.0", ID: nil, Method: "initialize", }, expected: `{"jsonrpc":"2.0","id":null,"method":"initialize"}`, }, { name: "request with float ID", req: Request{ JSONRPC: "2.0", ID: 1.5, Method: "test", }, expected: `{"jsonrpc":"2.0","id":1.5,"method":"test"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() data, err := json.Marshal(tt.req) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) // Test unmarshaling var parsed Request err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, tt.req.JSONRPC, parsed.JSONRPC) assert.Equal(t, tt.req.Method, parsed.Method) }) } } // TestResponse tests Response struct JSON marshaling. func TestResponse(t *testing.T) { t.Parallel() tests := []struct { name string resp Response expected string }{ // ===== GOOD CASES ===== { name: "success response", resp: Response{ JSONRPC: "2.0", ID: 1, Result: map[string]string{"status": "ok"}, }, expected: `{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`, }, { name: "error response", resp: Response{ JSONRPC: "2.0", ID: 2, Error: &Error{ Code: -32600, Message: "Invalid Request", }, }, expected: `{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}`, }, { name: "error with data", resp: Response{ JSONRPC: "2.0", ID: 3, Error: &Error{ Code: -32602, Message: "Invalid params", Data: "missing field", }, }, expected: `{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Invalid params","data":"missing field"}}`, }, // ===== EDGE CASES ===== { name: "response with nil ID", resp: Response{ JSONRPC: "2.0", ID: nil, Result: "ok", }, expected: `{"jsonrpc":"2.0","id":null,"result":"ok"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() data, err := json.Marshal(tt.resp) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) }) } } // TestError tests Error struct. func TestError(t *testing.T) { t.Parallel() tests := []struct { expected string name string err Error }{ { name: "parse error", err: Error{ Code: -32700, Message: "Parse error", }, expected: `{"code":-32700,"message":"Parse error"}`, }, { name: "method not found", err: Error{ Code: -32601, Message: "Method not found", }, expected: `{"code":-32601,"message":"Method not found"}`, }, { name: "invalid params", err: Error{ Code: -32602, Message: "Invalid params", Data: "details here", }, expected: `{"code":-32602,"message":"Invalid params","data":"details here"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() data, err := json.Marshal(tt.err) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(data)) }) } } // TestToolCallParams tests ToolCallParams struct. func TestToolCallParams(t *testing.T) { t.Parallel() tests := []struct { name string input string expected ToolCallParams }{ { name: "search tool call", input: `{"name":"search","arguments":{"query":"test"}}`, expected: ToolCallParams{ Name: "search", Arguments: json.RawMessage(`{"query":"test"}`), }, }, { name: "decisions tool call", input: `{"name":"decisions","arguments":{"query":"auth"}}`, expected: ToolCallParams{ Name: "decisions", Arguments: json.RawMessage(`{"query":"auth"}`), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() var params ToolCallParams err := json.Unmarshal([]byte(tt.input), ¶ms) require.NoError(t, err) assert.Equal(t, tt.expected.Name, params.Name) }) } } // TestTool tests Tool struct. func TestTool(t *testing.T) { t.Parallel() tool := Tool{ Name: "search", Description: "Search observations", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ "query": map[string]any{"type": "string"}, }, }, } data, err := json.Marshal(tool) require.NoError(t, err) var parsed Tool err = json.Unmarshal(data, &parsed) require.NoError(t, err) assert.Equal(t, "search", parsed.Name) assert.Equal(t, "Search observations", parsed.Description) } // TestTimelineParamsStruct tests timelineParams struct. func TestTimelineParamsStruct(t *testing.T) { t.Parallel() tests := []struct { name string input string expected timelineParams }{ { name: "with anchor_id", input: `{"anchor_id":123,"before":5,"after":5}`, expected: timelineParams{ AnchorID: 123, Before: 5, After: 5, }, }, { name: "with query", input: `{"query":"test query","project":"my-project"}`, expected: timelineParams{ Query: "test query", Project: "my-project", }, }, { name: "full params", input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`, expected: timelineParams{ AnchorID: 100, Query: "search", Before: 10, After: 20, Project: "proj", ObsType: "bugfix", Concepts: "security", Files: "main.go", DateStart: 1234567890, DateEnd: 9876543210, Format: "full", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() var params timelineParams err := json.Unmarshal([]byte(tt.input), ¶ms) require.NoError(t, err) assert.Equal(t, tt.expected.AnchorID, params.AnchorID) assert.Equal(t, tt.expected.Query, params.Query) assert.Equal(t, tt.expected.Project, params.Project) }) } } // ============================================================================= // TESTS FOR Server Handlers // ============================================================================= // TestHandleInitialize tests the initialize handler. func TestHandleInitialize(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.2.3") req := &Request{ JSONRPC: "2.0", ID: 1, Method: "initialize", } resp := server.handleInitialize(req) assert.Equal(t, "2.0", resp.JSONRPC) assert.Equal(t, 1, resp.ID) assert.Nil(t, resp.Error) assert.NotNil(t, resp.Result) result, ok := resp.Result.(map[string]any) require.True(t, ok) assert.Equal(t, "2024-11-05", result["protocolVersion"]) serverInfo, ok := result["serverInfo"].(map[string]any) require.True(t, ok) assert.Equal(t, "claude-mnemonic", serverInfo["name"]) assert.Equal(t, "1.2.3", serverInfo["version"]) } // TestHandleToolsList tests the tools/list handler. func TestHandleToolsList(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/list", } resp := server.handleToolsList(req) assert.Equal(t, "2.0", resp.JSONRPC) assert.Equal(t, 1, resp.ID) assert.Nil(t, resp.Error) result, ok := resp.Result.(map[string]any) require.True(t, ok) tools, ok := result["tools"].([]Tool) require.True(t, ok) assert.NotEmpty(t, tools) // Verify expected tools are present toolNames := make(map[string]bool) for _, tool := range tools { toolNames[tool.Name] = true } expectedTools := []string{ "search", "timeline", "decisions", "changes", "how_it_works", "find_by_concept", "find_by_file", "find_by_type", "get_recent_context", "get_context_timeline", "get_timeline_by_query", } for _, name := range expectedTools { assert.True(t, toolNames[name], "expected tool %s to be present", name) } } // TestHandleRequest tests request routing. func TestHandleRequest(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { req *Request name string errorMessage string errorCode int expectError bool }{ { name: "initialize method", req: &Request{ JSONRPC: "2.0", ID: 1, Method: "initialize", }, expectError: false, }, { name: "tools/list method", req: &Request{ JSONRPC: "2.0", ID: 2, Method: "tools/list", }, expectError: false, }, { name: "unknown method", req: &Request{ JSONRPC: "2.0", ID: 3, Method: "unknown_method", }, expectError: true, errorCode: -32601, errorMessage: "Method not found", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() resp := server.handleRequest(ctx, tt.req) assert.Equal(t, "2.0", resp.JSONRPC) assert.Equal(t, tt.req.ID, resp.ID) if tt.expectError { require.NotNil(t, resp.Error) assert.Equal(t, tt.errorCode, resp.Error.Code) assert.Equal(t, tt.errorMessage, resp.Error.Message) } else { assert.Nil(t, resp.Error) assert.NotNil(t, resp.Result) } }) } } // TestHandleToolsCall_InvalidParams tests tools/call with invalid params. func TestHandleToolsCall_InvalidParams(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/call", Params: json.RawMessage(`invalid json`), } resp := server.handleToolsCall(ctx, req) require.NotNil(t, resp.Error) assert.Equal(t, -32602, resp.Error.Code) assert.Equal(t, "Invalid params", resp.Error.Message) } // TestCallTool_UnknownTool tests callTool with unknown tool name. func TestCallTool_UnknownTool(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`)) require.Error(t, err) assert.Contains(t, err.Error(), "unknown tool") } // TestCallTool_InvalidArgs tests callTool with invalid arguments. func TestCallTool_InvalidArgs(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Invalid JSON is best-effort parsed; the call fails because worker is unavailable _, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`)) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= // TESTS FOR Server I/O // ============================================================================= // TestSendResponse tests response sending. func TestSendResponse(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{ stdout: &buf, } resp := &Response{ JSONRPC: "2.0", ID: 1, Result: map[string]string{"status": "ok"}, } server.sendResponse(resp) output := buf.String() assert.Contains(t, output, `"jsonrpc":"2.0"`) assert.Contains(t, output, `"id":1`) assert.Contains(t, output, `"result"`) } // TestSendError tests error response sending. func TestSendError(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{ stdout: &buf, } server.sendError(1, -32700, "Parse error", "details") output := buf.String() assert.Contains(t, output, `"error"`) assert.Contains(t, output, `-32700`) assert.Contains(t, output, `"Parse error"`) } // TestRun_ParseError tests Run with invalid JSON input. func TestRun_ParseError(t *testing.T) { t.Parallel() var stdout bytes.Buffer stdin := strings.NewReader("invalid json\n") server := &Server{ stdin: stdin, stdout: &stdout, } err := server.Run(context.Background()) require.NoError(t, err) output := stdout.String() assert.Contains(t, output, `"error"`) assert.Contains(t, output, `-32700`) assert.Contains(t, output, `"Parse error"`) } // TestRun_EmptyLine tests Run skips empty lines. func TestRun_EmptyLine(t *testing.T) { t.Parallel() var stdout bytes.Buffer stdin := strings.NewReader("\n\n") server := &Server{ stdin: stdin, stdout: &stdout, } err := server.Run(context.Background()) require.NoError(t, err) // Should be empty - no responses for empty lines assert.Empty(t, stdout.String()) } // TestRun_ValidRequest tests Run with a valid request. func TestRun_ValidRequest(t *testing.T) { t.Parallel() var stdout bytes.Buffer req := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` stdin := strings.NewReader(req + "\n") server := &Server{ stdin: stdin, stdout: &stdout, version: "1.0.0", } err := server.Run(context.Background()) require.NoError(t, err) output := stdout.String() assert.Contains(t, output, `"jsonrpc":"2.0"`) assert.Contains(t, output, `"result"`) assert.Contains(t, output, `"protocolVersion"`) } // TestRun_MultipleRequests tests Run with multiple sequential requests. func TestRun_MultipleRequests(t *testing.T) { t.Parallel() var stdout bytes.Buffer req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` req2 := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` stdin := strings.NewReader(req1 + "\n" + req2 + "\n") server := &Server{ stdin: stdin, stdout: &stdout, version: "1.0.0", } err := server.Run(context.Background()) require.NoError(t, err) output := stdout.String() // Should contain responses for both requests assert.Contains(t, output, `"id":1`) assert.Contains(t, output, `"id":2`) } // TestRunMixedRequests tests Run with mixed valid and invalid requests. func TestRunMixedRequests(t *testing.T) { t.Parallel() var stdout bytes.Buffer req1 := `{"jsonrpc":"2.0","id":1,"method":"initialize"}` req2 := `invalid json` req3 := `{"jsonrpc":"2.0","id":3,"method":"tools/list"}` stdin := strings.NewReader(req1 + "\n" + req2 + "\n" + req3 + "\n") server := &Server{ stdin: stdin, stdout: &stdout, version: "1.0.0", } err := server.Run(context.Background()) require.NoError(t, err) output := stdout.String() // Should have responses for all three requests assert.Contains(t, output, `"id":1`) assert.Contains(t, output, `"error"`) // Parse error for invalid json assert.Contains(t, output, `"id":3`) } // ============================================================================= // TESTS FOR Handler Parameter Validation // ============================================================================= // TestHandleFindRelatedObservations_Validation tests parameter validation. func TestHandleFindRelatedObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "worker unavailable", args: `{"id": 1}`, wantErr: true, errContains: "worker unavailable", }, { name: "missing id", args: `{}`, wantErr: true, errContains: "id is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleFindRelatedProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleFindSimilarObservations_Validation tests parameter validation. func TestHandleFindSimilarObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing query", args: `{}`, wantErr: true, errContains: "query is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, { name: "nil vector client", args: `{"query": "test"}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleFindSimilarProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetPatterns_Validation tests parameter validation. func TestHandleGetPatterns_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetPatternsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleBulkDeleteObservations_Validation tests parameter validation. func TestHandleBulkDeleteObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing ids", args: `{}`, wantErr: true, errContains: "ids is required", }, { name: "empty ids array", args: `{"ids": []}`, wantErr: true, errContains: "ids is required", }, { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, wantErr: true, errContains: "worker unavailable", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "delete") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleBulkMarkSuperseded_Validation tests parameter validation. func TestHandleBulkMarkSuperseded_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing ids", args: `{}`, wantErr: true, errContains: "ids is required", }, { name: "empty ids array", args: `{"ids": []}`, wantErr: true, errContains: "ids is required", }, { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "supersede") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleBulkBoostObservations_Validation tests parameter validation. func TestHandleBulkBoostObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing ids", args: `{"boost": 0.1}`, wantErr: true, errContains: "ids is required", }, { name: "boost out of range low", args: `{"ids": [1], "boost": -1.5}`, wantErr: true, errContains: "worker unavailable", }, { name: "boost out of range high", args: `{"ids": [1], "boost": 1.5}`, wantErr: true, errContains: "worker unavailable", }, { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1], "boost": 0.1}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "boost") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleTriggerMaintenance_Validation tests that nil service returns error. func TestHandleTriggerMaintenance_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetMaintenanceStats_Validation tests that nil service returns error. func TestHandleGetMaintenanceStats_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleMergeObservations_Validation tests parameter validation. func TestHandleMergeObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing source_id", args: `{"target_id": 2}`, wantErr: true, errContains: "source_id and target_id are required", }, { name: "missing target_id", args: `{"source_id": 1}`, wantErr: true, errContains: "source_id and target_id are required", }, { name: "same source and target", args: `{"source_id": 1, "target_id": 1}`, wantErr: true, errContains: "worker unavailable", }, { name: "worker unavailable with boost", args: `{"source_id": 1, "target_id": 2, "boost": 0.6}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleMergeProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetObservation_Validation tests parameter validation. func TestHandleGetObservation_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{}`, wantErr: true, errContains: "id is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleEditObservation_Validation tests parameter validation. func TestHandleEditObservation_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{"title": "new title"}`, wantErr: true, errContains: "id is required", }, { name: "invalid scope", args: `{"id": 1, "scope": "invalid"}`, wantErr: true, errContains: "worker unavailable", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleEditObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetObservationQuality_Validation tests parameter validation. func TestHandleGetObservationQuality_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{}`, wantErr: true, errContains: "id is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetObservationQualityProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleSuggestConsolidations_Validation tests parameter validation. func TestHandleSuggestConsolidations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "min_similarity too low", args: `{"min_similarity": 0.3}`, wantErr: true, errContains: "worker unavailable", }, { name: "min_similarity too high", args: `{"min_similarity": 1.5}`, wantErr: true, errContains: "worker unavailable", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleTagObservation_Validation tests parameter validation. func TestHandleTagObservation_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{"tags": ["tag1"]}`, wantErr: true, errContains: "id is required", }, { name: "missing tags", args: `{"id": 1}`, wantErr: true, errContains: "tags is required", }, { name: "invalid mode", args: `{"id": 1, "tags": ["tag1"], "mode": "invalid"}`, wantErr: true, errContains: "mode must be 'add', 'remove', or 'set'", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleTagObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetObservationsByTag_Validation tests parameter validation. func TestHandleGetObservationsByTag_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing tag", args: `{}`, wantErr: true, errContains: "tag is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetObservationsByTagProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleBatchTagByPattern_Validation tests parameter validation. func TestHandleBatchTagByPattern_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing pattern", args: `{"tags": ["tag1"]}`, wantErr: true, errContains: "pattern is required", }, { name: "missing tags", args: `{"pattern": "test"}`, wantErr: true, errContains: "tags is required", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleBatchTagProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleExplainSearchRanking_Validation tests parameter validation. func TestHandleExplainSearchRanking_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing query", args: `{"top_n": 5}`, wantErr: true, errContains: "query is required", }, { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleExplainSearchProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetObservationRelationships_Validation tests parameter validation. func TestHandleGetObservationRelationships_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{}`, wantErr: true, errContains: "id is required", }, { name: "negative id", args: `{"id": -1}`, wantErr: true, errContains: "id is required and must be positive", }, { name: "nil relation store", args: `{"id": 1}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleGetObservationScoringBreakdown_Validation tests parameter validation. func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "missing id", args: `{}`, wantErr: true, errContains: "id is required", }, { name: "negative id", args: `{"id": -1}`, wantErr: true, errContains: "id is required and must be positive", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetScoringBreakdownProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } else { require.NoError(t, err) } }) } } // TestHandleTimeline_InvalidJSON tests timeline with invalid JSON. func TestHandleTimeline_InvalidJSON(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } // TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query. func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Empty query should return empty results (no anchor found) result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`)) require.NoError(t, err) assert.Contains(t, result, `"observations":[]`) } // TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } // TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query. func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // No anchor_id and no query should return empty result JSON result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`)) require.NoError(t, err) assert.NotEmpty(t, result) assert.Contains(t, result, `"observations":[]`) } // TestHandleTimeline_WithDefaults tests timeline default values are applied. func TestHandleTimeline_WithDefaults(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // With anchor_id = 0, should return empty result JSON result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{"anchor_id": 0}`)) require.NoError(t, err) assert.NotEmpty(t, result) assert.Contains(t, result, `"observations":[]`) } // ============================================================================= // TESTS FOR Additional Tool Operations // ============================================================================= // TestJSONRPCErrorCodes tests standard JSON-RPC error codes. func TestJSONRPCErrorCodes(t *testing.T) { t.Parallel() errorCodes := map[string]int{ "Parse error": -32700, "Invalid Request": -32600, "Method not found": -32601, "Invalid params": -32602, "Internal error": -32603, } for msg, code := range errorCodes { t.Run(msg, func(t *testing.T) { t.Parallel() err := Error{Code: code, Message: msg} assert.Equal(t, code, err.Code) assert.Equal(t, msg, err.Message) }) } } // TestToolListContainsExpectedSchemas tests that tool schemas are valid. func TestToolListContainsExpectedSchemas(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/list", } resp := server.handleToolsList(req) result := resp.Result.(map[string]any) tools := result["tools"].([]Tool) for _, tool := range tools { assert.NotEmpty(t, tool.Name) assert.NotEmpty(t, tool.Description) assert.NotNil(t, tool.InputSchema) // Check schema has type schema := tool.InputSchema _, hasType := schema["type"] assert.True(t, hasType, "tool %s schema should have type", tool.Name) } } // TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name. // After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_UnknownTool(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/call", Params: json.RawMessage(`{"name":"unknown_tool","arguments":{}}`), } resp := server.handleToolsCall(ctx, req) assert.Nil(t, resp.Error, "tool errors must not use top-level Error (MCP spec)") require.NotNil(t, resp.Result) result, ok := resp.Result.(map[string]any) require.True(t, ok) assert.Equal(t, true, result["isError"]) content := result["content"].([]map[string]any) assert.Contains(t, content[0]["text"], "unknown tool") } // TestCallTool_ToolNameRecognition tests that valid tool names are recognized. func TestCallTool_ToolNameRecognition(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/list", } resp := server.handleToolsList(req) result := resp.Result.(map[string]any) tools := result["tools"].([]Tool) // Verify all expected tools are registered expectedTools := map[string]bool{ "search": true, "timeline": true, "decisions": true, "changes": true, "how_it_works": true, "find_by_concept": true, "find_by_file": true, "find_by_type": true, "get_recent_context": true, "get_context_timeline": true, "get_timeline_by_query": true, "find_related_observations": true, "find_similar_observations": true, "get_patterns": true, "get_memory_stats": true, "bulk_delete_observations": true, "bulk_mark_superseded": true, "bulk_boost_observations": true, "trigger_maintenance": true, "get_maintenance_stats": true, "merge_observations": true, "get_observation": true, "edit_observation": true, "get_observation_quality": true, "suggest_consolidations": true, "tag_observation": true, "get_observations_by_tag": true, "get_temporal_trends": true, "get_data_quality_report": true, "batch_tag_by_pattern": true, "explain_search_ranking": true, "export_observations": true, "check_system_health": true, "analyze_search_patterns": true, "get_observation_relationships": true, "get_observation_scoring_breakdown": true, "analyze_observation_importance": true, } foundTools := make(map[string]bool) for _, tool := range tools { foundTools[tool.Name] = true } for name := range expectedTools { assert.True(t, foundTools[name], "tool %s should be registered", name) } } // TestTimelineParamsStruct_Complete tests complete timelineParams parsing. func TestTimelineParamsStruct_Complete(t *testing.T) { t.Parallel() input := `{ "anchor_id": 100, "query": "test query", "before": 5, "after": 15, "project": "my-project", "obs_type": "bugfix", "concepts": "security,auth", "files": "main.go,handler.go", "dateStart": 1700000000000, "dateEnd": 1700100000000, "format": "full" }` var params timelineParams err := json.Unmarshal([]byte(input), ¶ms) require.NoError(t, err) assert.Equal(t, int64(100), params.AnchorID) assert.Equal(t, "test query", params.Query) assert.Equal(t, 5, params.Before) assert.Equal(t, 15, params.After) assert.Equal(t, "my-project", params.Project) assert.Equal(t, "bugfix", params.ObsType) assert.Equal(t, "security,auth", params.Concepts) assert.Equal(t, "main.go,handler.go", params.Files) assert.Equal(t, int64(1700000000000), params.DateStart) assert.Equal(t, int64(1700100000000), params.DateEnd) assert.Equal(t, "full", params.Format) } // TestServerStdinStdoutConfig tests that server stdin/stdout can be configured. func TestServerStdinStdoutConfig(t *testing.T) { t.Parallel() var stdout bytes.Buffer var stdin bytes.Buffer server := &Server{ stdin: &stdin, stdout: &stdout, version: "test-version", } assert.Equal(t, &stdin, server.stdin) assert.Equal(t, &stdout, server.stdout) assert.Equal(t, "test-version", server.version) } // TestResponseIDTypes tests that response IDs can be various types. func TestResponseIDTypes(t *testing.T) { t.Parallel() tests := []struct { id any name string }{ {name: "integer id", id: 1}, {name: "string id", id: "abc-123"}, {name: "float id", id: 1.5}, {name: "null id", id: nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{stdout: &buf} resp := &Response{ JSONRPC: "2.0", ID: tt.id, Result: "ok", } server.sendResponse(resp) output := buf.String() assert.Contains(t, output, `"jsonrpc":"2.0"`) }) } } // TestServerFields tests Server struct fields. func TestServerFields(t *testing.T) { t.Parallel() server := NewServer(nil, "http://localhost:37777", "test", "2.0.0") assert.Equal(t, "2.0.0", server.version) assert.Nil(t, server.client) assert.Equal(t, "http://localhost:37777", server.workerURL) assert.NotNil(t, server.stdin) assert.NotNil(t, server.stdout) } // TestRequestUnmarshalWithNullID tests Request unmarshaling with null ID. func TestRequestUnmarshalWithNullID(t *testing.T) { t.Parallel() input := `{"jsonrpc":"2.0","id":null,"method":"initialize"}` var req Request err := json.Unmarshal([]byte(input), &req) require.NoError(t, err) assert.Equal(t, "2.0", req.JSONRPC) assert.Nil(t, req.ID) assert.Equal(t, "initialize", req.Method) } // TestResponseWithNullError tests Response without error. func TestResponseWithNullError(t *testing.T) { t.Parallel() resp := Response{ JSONRPC: "2.0", ID: 1, Result: "success", Error: nil, } data, err := json.Marshal(resp) require.NoError(t, err) assert.Contains(t, string(data), `"result":"success"`) assert.NotContains(t, string(data), `"error"`) } // TestErrorWithNilData tests Error without data. func TestErrorWithNilData(t *testing.T) { t.Parallel() err := Error{ Code: -32600, Message: "Invalid Request", Data: nil, } data, errMarshal := json.Marshal(err) require.NoError(t, errMarshal) assert.Contains(t, string(data), `"code":-32600`) assert.Contains(t, string(data), `"message":"Invalid Request"`) assert.NotContains(t, string(data), `"data"`) } // TestToolInputSchema tests that tool input schemas have required fields. func TestToolInputSchema(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/list", } resp := server.handleToolsList(req) result := resp.Result.(map[string]any) tools := result["tools"].([]Tool) for _, tool := range tools { schema := tool.InputSchema schemaType, ok := schema["type"] assert.True(t, ok, "tool %s schema should have type", tool.Name) assert.Equal(t, "object", schemaType, "tool %s schema type should be object", tool.Name) // All tools should have properties _, hasProperties := schema["properties"] assert.True(t, hasProperties, "tool %s should have properties", tool.Name) } } // TestCallTool_UnknownToolName tests callTool with various unknown tool names. func TestCallTool_UnknownToolName(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() unknownTools := []string{ "invalid_tool", "nonexistent", "search_v2", "timeline_special", } for _, name := range unknownTools { t.Run(name, func(t *testing.T) { t.Parallel() result, err := server.callTool(ctx, name, json.RawMessage(`{}`)) assert.Error(t, err) assert.Empty(t, result) assert.Contains(t, err.Error(), "unknown tool") }) } } // TestTimelineParamsStruct_Validation tests timelineParams struct field validation. func TestTimelineParamsStruct_Validation(t *testing.T) { t.Parallel() tests := []struct { name string json string wantOK bool }{ {"valid with anchor_id", `{"anchor_id":123,"before":5,"after":5}`, true}, {"valid with query only", `{"query":"test query"}`, true}, {"empty params", `{}`, true}, {"with all fields", `{"anchor_id":1,"query":"test","before":10,"after":10,"project":"proj","obs_type":"bugfix","format":"full"}`, true}, {"invalid json", `{invalid`, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() var params timelineParams err := json.Unmarshal([]byte(tt.json), ¶ms) if tt.wantOK { assert.NoError(t, err) } else { assert.Error(t, err) } }) } } // TestHandleToolsCall_EmptyParams tests tools/call with empty params. // After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_EmptyParams(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/call", Params: json.RawMessage(`{}`), } resp := server.handleToolsCall(ctx, req) // Empty name goes through callTool default branch -> isError assert.Nil(t, resp.Error, "tool errors must use isError in Result") require.NotNil(t, resp.Result) result, ok := resp.Result.(map[string]any) require.True(t, ok) assert.Equal(t, true, result["isError"]) } // TestSendResponse_WithError tests sendResponse with an error response. func TestSendResponse_WithError(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{stdout: &buf} resp := &Response{ JSONRPC: "2.0", ID: 1, Error: &Error{Code: -32600, Message: "Invalid Request"}, } server.sendResponse(resp) output := buf.String() assert.Contains(t, output, `"error"`) assert.Contains(t, output, `-32600`) } // TestSendResponse_NilID tests sendResponse with nil ID. func TestSendResponse_NilID(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{stdout: &buf} resp := &Response{ JSONRPC: "2.0", ID: nil, Result: "notification response", } server.sendResponse(resp) output := buf.String() assert.Contains(t, output, `"id":null`) } // TestToolCallParamsWithComplexArgs tests ToolCallParams with complex arguments. func TestToolCallParamsWithComplexArgs(t *testing.T) { t.Parallel() input := `{ "name": "search", "arguments": { "query": "authentication bug", "project": "my-project", "limit": 10, "type": "observations" } }` var params ToolCallParams err := json.Unmarshal([]byte(input), ¶ms) require.NoError(t, err) assert.Equal(t, "search", params.Name) assert.NotEmpty(t, params.Arguments) } // TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error. // After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ JSONRPC: "2.0", ID: 1, Method: "tools/call", Params: json.RawMessage(`{"name":"very_unknown_tool_name","arguments":{}}`), } resp := server.handleToolsCall(ctx, req) assert.Equal(t, "2.0", resp.JSONRPC) assert.Equal(t, 1, resp.ID) assert.Nil(t, resp.Error, "tool errors must use isError in Result, not top-level Error") require.NotNil(t, resp.Result) result, ok := resp.Result.(map[string]any) require.True(t, ok) assert.Equal(t, true, result["isError"]) } // ============================================================================= // TESTS FOR Handler Defaults // ============================================================================= // TestHandleTimeline_Defaults tests timeline default values. func TestHandleTimeline_Defaults(t *testing.T) { t.Parallel() // Test that handleTimeline sets default before/after values params := timelineParams{ AnchorID: 0, Query: "", Before: 0, After: 0, } // Simulate the default value assignment from handleTimeline if params.Before <= 0 { params.Before = 10 } if params.After <= 0 { params.After = 10 } assert.Equal(t, 10, params.Before) assert.Equal(t, 10, params.After) } // TestHandleGetTemporalTrends_Validation tests parameter validation. func TestHandleGetTemporalTrends_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetTemporalTrendsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } }) } } // TestHandleGetDataQualityReport_Validation tests parameter validation. func TestHandleGetDataQualityReport_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleGetDataQualityProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } }) } } // TestHandleExportObservations_Validation tests parameter validation. func TestHandleExportObservations_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleExportProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } }) } } // TestHandleAnalyzeSearchPatterns_Validation tests parameter validation. func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "worker unavailable", args: `{}`, wantErr: true, errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.proxyGetRaw(ctx, "/api/search/analytics", nil) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } }) } } // TestHandleAnalyzeObservationImportance_Validation tests parameter validation. func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleAnalyzeImportanceProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { assert.Contains(t, err.Error(), tt.errContains) } } }) } } // TestHandleGetMemoryStats_NilStores tests GetMemoryStats with nil stores. func TestHandleGetMemoryStats_NilStores(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when worker is unavailable (nil client) _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleCheckSystemHealth_NilStores tests CheckSystemHealth with nil stores. func TestHandleCheckSystemHealth_NilStores(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when worker is unavailable (nil client) _, err := server.proxyGetRaw(ctx, "/api/selfcheck", nil) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= // COMPREHENSIVE callTool TESTS // ============================================================================= // TestCallTool_AllSpecialTools tests all special tool cases in callTool switch. func TestCallTool_AllSpecialTools(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Tests for tools that can work without stores or have nil guards tests := []struct { name string toolName string args string wantErr bool checkPanic bool }{ // Tools that work with nil stores { name: "get_memory_stats", toolName: "get_memory_stats", args: `{}`, wantErr: true, // worker unavailable with nil client }, { name: "check_system_health", toolName: "check_system_health", args: `{}`, wantErr: true, // worker unavailable with nil client }, // Tools that need stores but have parameter validation first { name: "find_related_observations - invalid json", toolName: "find_related_observations", args: `{invalid`, wantErr: true, }, { name: "find_related_observations - missing id", toolName: "find_related_observations", args: `{}`, wantErr: true, }, { name: "find_similar_observations - invalid json", toolName: "find_similar_observations", args: `{invalid`, wantErr: true, }, { name: "find_similar_observations - missing query", toolName: "find_similar_observations", args: `{}`, wantErr: true, }, { name: "get_patterns - invalid json", toolName: "get_patterns", args: `{invalid`, wantErr: true, }, { name: "bulk_delete_observations - invalid json", toolName: "bulk_delete_observations", args: `{invalid`, wantErr: true, }, { name: "bulk_delete_observations - missing ids", toolName: "bulk_delete_observations", args: `{}`, wantErr: true, }, { name: "bulk_mark_superseded - invalid json", toolName: "bulk_mark_superseded", args: `{invalid`, wantErr: true, }, { name: "bulk_mark_superseded - missing ids", toolName: "bulk_mark_superseded", args: `{}`, wantErr: true, }, { name: "bulk_boost_observations - invalid json", toolName: "bulk_boost_observations", args: `{invalid`, wantErr: true, }, { name: "bulk_boost_observations - missing ids", toolName: "bulk_boost_observations", args: `{}`, wantErr: true, }, { name: "merge_observations - invalid json", toolName: "merge_observations", args: `{invalid`, wantErr: true, }, { name: "merge_observations - missing source_ids", toolName: "merge_observations", args: `{}`, wantErr: true, }, { name: "get_observation - invalid json", toolName: "get_observation", args: `{invalid`, wantErr: true, }, { name: "get_observation - missing id", toolName: "get_observation", args: `{}`, wantErr: true, }, { name: "edit_observation - invalid json", toolName: "edit_observation", args: `{invalid`, wantErr: true, }, { name: "edit_observation - missing id", toolName: "edit_observation", args: `{}`, wantErr: true, }, { name: "get_observation_quality - invalid json", toolName: "get_observation_quality", args: `{invalid`, wantErr: true, }, { name: "get_observation_quality - missing id", toolName: "get_observation_quality", args: `{}`, wantErr: true, }, { name: "suggest_consolidations - invalid json", toolName: "suggest_consolidations", args: `{invalid`, wantErr: true, }, { name: "tag_observation - invalid json", toolName: "tag_observation", args: `{invalid`, wantErr: true, }, { name: "tag_observation - missing id", toolName: "tag_observation", args: `{}`, wantErr: true, }, { name: "get_observations_by_tag - invalid json", toolName: "get_observations_by_tag", args: `{invalid`, wantErr: true, }, { name: "get_observations_by_tag - missing tag", toolName: "get_observations_by_tag", args: `{}`, wantErr: true, }, { name: "get_temporal_trends - invalid json", toolName: "get_temporal_trends", args: `{invalid`, wantErr: true, }, { name: "get_data_quality_report - invalid json", toolName: "get_data_quality_report", args: `{invalid`, wantErr: true, }, { name: "batch_tag_by_pattern - invalid json", toolName: "batch_tag_by_pattern", args: `{invalid`, wantErr: true, }, { name: "batch_tag_by_pattern - missing pattern", toolName: "batch_tag_by_pattern", args: `{}`, wantErr: true, }, { name: "explain_search_ranking - invalid json", toolName: "explain_search_ranking", args: `{invalid`, wantErr: true, }, { name: "explain_search_ranking - missing query", toolName: "explain_search_ranking", args: `{}`, wantErr: true, }, { name: "export_observations - invalid json", toolName: "export_observations", args: `{invalid`, wantErr: true, }, { name: "analyze_search_patterns - invalid json", toolName: "analyze_search_patterns", args: `{invalid`, wantErr: true, }, { name: "get_observation_relationships - invalid json", toolName: "get_observation_relationships", args: `{invalid`, wantErr: true, }, { name: "get_observation_relationships - missing id", toolName: "get_observation_relationships", args: `{}`, wantErr: true, }, { name: "get_observation_scoring_breakdown - invalid json", toolName: "get_observation_scoring_breakdown", args: `{invalid`, wantErr: true, }, { name: "get_observation_scoring_breakdown - missing id", toolName: "get_observation_scoring_breakdown", args: `{}`, wantErr: true, }, { name: "analyze_observation_importance - invalid json", toolName: "analyze_observation_importance", args: `{invalid`, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() result, err := server.callTool(ctx, tt.toolName, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) } else { require.NoError(t, err) assert.NotEmpty(t, result) } }) } } // TestCallTool_SearchTools tests search-based tools in callTool. func TestCallTool_SearchTools(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // All search tools should fail when worker is unavailable (nil client) searchTools := []string{ "search", "decisions", "changes", "how_it_works", "find_by_concept", "find_by_file", "find_by_type", "get_recent_context", } for _, toolName := range searchTools { t.Run(toolName+"_worker_unavailable", func(t *testing.T) { t.Parallel() _, err := server.callTool(ctx, toolName, json.RawMessage(`{"query":"test"}`)) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") }) } // Timeline tools should handle invalid JSON with a parse error timelineTools := []string{ "timeline", "get_context_timeline", "get_timeline_by_query", } for _, toolName := range timelineTools { t.Run(toolName+"_invalid_json", func(t *testing.T) { t.Parallel() _, err := server.callTool(ctx, toolName, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid") }) } } // TestHandleTriggerMaintenance_NilService tests trigger_maintenance with nil service. func TestHandleTriggerMaintenance_NilService(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when maintenanceService is nil _, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetMaintenanceStats_NilService tests get_maintenance_stats with nil service. func TestHandleGetMaintenanceStats_NilService(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when maintenanceService is nil _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleTimeline_ParameterDefaultsNew tests timeline parameter defaults. func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Invalid JSON should fail _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } // TestHandleTimelineByQuery_ValidationExtended tests timeline_by_query validation. func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid timeline params", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleTimelineProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) } }) } } // TestHandleSuggestConsolidations_ValidationExtended tests suggest_consolidations validation. func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { name string args string errContains string wantErr bool }{ { name: "invalid json", args: `{invalid`, wantErr: true, errContains: "invalid arguments", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() _, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args)) require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) }) } } // ============================================================================= // NIL GUARD HANDLER TESTS // ============================================================================= // TestHandleFindSimilarObservations_NilVectorClient tests nil vector client handling. func TestHandleFindSimilarObservations_NilVectorClient(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when client is nil (worker unavailable) _, err := server.handleFindSimilarProxy(ctx, json.RawMessage(`{"query": "test query"}`)) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetObservationRelationships_NilRelationStore tests nil relation store handling. func TestHandleGetObservationRelationships_NilRelationStore(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when worker is unavailable _, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(`{"id": 123}`)) require.Error(t, err) assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= // MORE PARAM LIMIT TESTS // ============================================================================= // TestHandleBulkBoostObservations_EmptyIDs tests the empty IDs validation. func TestHandleBulkBoostObservations_EmptyIDs(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Empty IDs should return error _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(`{"ids": []}`), "boost") require.Error(t, err) assert.Contains(t, err.Error(), "ids is required") } // TestHandleMergeObservations_MissingIDs tests merge with missing IDs. func TestHandleMergeObservations_MissingIDs(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // source_id and target_id are required _, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 0, "target_id": 0}`)) require.Error(t, err) assert.Contains(t, err.Error(), "source_id and target_id are required") } // TestHandleMergeObservations_WorkerUnavailable tests merge when worker is down. func TestHandleMergeObservations_WorkerUnavailable(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should fail when worker is unavailable (nil client) _, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2}`)) require.Error(t, err) } // ============================================================================= // REGRESSION TESTS — Fix #45: Concurrent request dispatching // ============================================================================= // collectResponses reads newline-delimited JSON responses from r until it has // collected n responses or the context is done. Returns collected responses. func collectResponses(t *testing.T, r io.Reader, n int) []map[string]any { t.Helper() results := make([]map[string]any, 0, n) scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { continue } var resp map[string]any if err := json.Unmarshal([]byte(line), &resp); err != nil { t.Logf("collectResponses: bad JSON line: %s", line) continue } results = append(results, resp) if len(results) >= n { break } } return results } // writeRequests writes newline-delimited JSON requests to w and closes it when done. func writeRequests(t *testing.T, w io.WriteCloser, reqs []string) { t.Helper() for _, r := range reqs { _, err := io.WriteString(w, r+"\n") require.NoError(t, err) } _ = w.Close() } // TestRun_ConcurrentRequests verifies multiple requests are processed concurrently // and not serially. If they ran serially, 5 × 100ms = 500ms. Concurrent should be // well under 400ms on any reasonable machine. func TestRun_ConcurrentRequests(t *testing.T) { t.Parallel() const delay = 100 * time.Millisecond const numRequests = 5 // Mock worker: every request sleeps delay then returns "{}" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(delay) w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{}`) })) defer ts.Close() stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() server := &Server{ client: ts.Client(), workerURL: ts.URL, project: "test", version: "1.0.0", stdin: stdinR, stdout: stdoutW, } // Build requests — use get_memory_stats which goes to GET /api/stats reqs := make([]string, numRequests) for i := range reqs { req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call", Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} data, err := json.Marshal(req) require.NoError(t, err) reqs[i] = string(data) } // Collect responses in background var responses []map[string]any var wg sync.WaitGroup wg.Go(func() { responses = collectResponses(t, stdoutR, numRequests) _ = stdoutR.Close() }) start := time.Now() // Write all requests then close stdin (triggers Run to drain and return) go writeRequests(t, stdinW, reqs) err := server.Run(context.Background()) require.NoError(t, err) _ = stdoutW.Close() elapsed := time.Since(start) wg.Wait() // All responses received assert.Len(t, responses, numRequests, "expected %d responses", numRequests) // Concurrent execution: should be much less than numRequests × delay serialTime := time.Duration(numRequests) * delay assert.Less(t, elapsed, serialTime*4/5, "elapsed %v not significantly less than serial %v — requests may be sequential", elapsed, serialTime) // Each response has correct jsonrpc field for _, resp := range responses { assert.Equal(t, "2.0", resp["jsonrpc"]) } } // TestRun_SlowRequestDoesNotBlockOthers is the core regression for #45. // A slow search request must not block a fast stats request from being answered first. func TestRun_SlowRequestDoesNotBlockOthers(t *testing.T) { t.Parallel() const slowDelay = 300 * time.Millisecond // responseOrder records which request IDs responded, in arrival order var mu sync.Mutex var responseOrder []any // Mock worker: /api/context/search is slow, everything else is fast ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/context/search") { time.Sleep(slowDelay) } w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{}`) })) defer ts.Close() stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() // Intercept stdout to record response order before passing through pr, pw := io.Pipe() go func() { scanner := bufio.NewScanner(pr) for scanner.Scan() { line := scanner.Text() var resp map[string]any if err := json.Unmarshal([]byte(line), &resp); err == nil { mu.Lock() responseOrder = append(responseOrder, resp["id"]) mu.Unlock() } _, _ = io.WriteString(stdoutW, line+"\n") } _ = stdoutW.Close() }() server := &Server{ client: ts.Client(), workerURL: ts.URL, project: "test", version: "1.0.0", stdin: stdinR, stdout: pw, } // Request 1: slow search (id=1) slowReq := Request{JSONRPC: "2.0", ID: 1, Method: "tools/call", Params: json.RawMessage(`{"name":"search","arguments":{"query":"anything"}}`)} slowData, err := json.Marshal(slowReq) require.NoError(t, err) // Request 2: fast stats (id=2) fastReq := Request{JSONRPC: "2.0", ID: 2, Method: "tools/call", Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} fastData, err := json.Marshal(fastReq) require.NoError(t, err) // Collect 2 responses var responses []map[string]any var wg sync.WaitGroup wg.Go(func() { responses = collectResponses(t, stdoutR, 2) _ = stdoutR.Close() }) // Write slow then fast, then close go func() { _, _ = io.WriteString(stdinW, string(slowData)+"\n") // Small pause to ensure slow request goroutine is dispatched first time.Sleep(10 * time.Millisecond) _, _ = io.WriteString(stdinW, string(fastData)+"\n") _ = stdinW.Close() }() runErr := server.Run(context.Background()) require.NoError(t, runErr) wg.Wait() require.Len(t, responses, 2, "expected 2 responses") mu.Lock() order := responseOrder mu.Unlock() require.Len(t, order, 2, "expected 2 recorded response IDs") // The fast request (id=2) must arrive before the slow one (id=1) assert.Equal(t, float64(2), order[0], "fast request (id=2) should respond before slow request (id=1); got order %v", order) assert.Equal(t, float64(1), order[1], "slow request (id=1) should respond second; got order %v", order) } // TestRun_SemaphoreLimitsConcurrency verifies the semaphore cap (10) does not deadlock // when more than 10 requests are sent and all eventually complete. func TestRun_SemaphoreLimitsConcurrency(t *testing.T) { t.Parallel() const blockDelay = 200 * time.Millisecond const numRequests = 15 // exceeds semaphore cap of 10 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(blockDelay) w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{}`) })) defer ts.Close() stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() server := &Server{ client: ts.Client(), workerURL: ts.URL, project: "test", version: "1.0.0", stdin: stdinR, stdout: stdoutW, } reqs := make([]string, numRequests) for i := range reqs { req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call", Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} data, err := json.Marshal(req) require.NoError(t, err) reqs[i] = string(data) } var responses []map[string]any var wg sync.WaitGroup wg.Go(func() { responses = collectResponses(t, stdoutR, numRequests) _ = stdoutR.Close() }) start := time.Now() go writeRequests(t, stdinW, reqs) err := server.Run(context.Background()) require.NoError(t, err) _ = stdoutW.Close() elapsed := time.Since(start) wg.Wait() // All 15 responses received — no deadlock assert.Len(t, responses, numRequests, "all %d requests must complete", numRequests) // With semaphore=10 and 15 requests at 200ms each, we need at least 2 batches. // Should complete in ~2×blockDelay not 15×blockDelay. // Upper bound: 3×blockDelay gives comfortable headroom for scheduling. upperBound := 3 * blockDelay * 2 // generous: 3 batches + 2× overhead factor assert.Less(t, elapsed, upperBound, "elapsed %v suggests sequential processing (15×%v = %v)", elapsed, blockDelay, time.Duration(numRequests)*blockDelay) } // TestRun_GracefulDrainOnCancel verifies that cancelling the context causes Run to // drain in-flight requests (wg.Wait) before returning ctx.Canceled. func TestRun_GracefulDrainOnCancel(t *testing.T) { t.Parallel() const reqDelay = 200 * time.Millisecond const numRequests = 3 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Use a fixed sleep; the request-level context will be cancelled but the // HTTP handler runs to completion independently (server-side). time.Sleep(reqDelay) w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{}`) })) defer ts.Close() stdinR, stdinW := io.Pipe() stdoutR, stdoutW := io.Pipe() server := &Server{ client: ts.Client(), workerURL: ts.URL, project: "test", version: "1.0.0", stdin: stdinR, stdout: stdoutW, } ctx, cancel := context.WithCancel(context.Background()) // Build requests reqs := make([]string, numRequests) for i := range reqs { req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call", Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} data, err := json.Marshal(req) require.NoError(t, err) reqs[i] = string(data) } // Write all requests before cancelling so they're all dispatched as goroutines go func() { for _, r := range reqs { _, _ = io.WriteString(stdinW, r+"\n") } // Cancel after requests are dispatched but while they're still in-flight time.Sleep(50 * time.Millisecond) cancel() // Leave stdin open — Run should return from the ctx.Done branch }() // Drain responses in background; we don't know exactly how many will complete // because goroutines may get context cancelled on their HTTP calls too. var responseMu sync.Mutex var collectedResponses []map[string]any var collectWg sync.WaitGroup collectWg.Go(func() { scanner := bufio.NewScanner(stdoutR) for scanner.Scan() { line := scanner.Text() if line == "" { continue } var resp map[string]any if err := json.Unmarshal([]byte(line), &resp); err == nil { responseMu.Lock() collectedResponses = append(collectedResponses, resp) responseMu.Unlock() } } }) runErr := server.Run(ctx) _ = stdoutW.Close() _ = stdinW.Close() collectWg.Wait() // Core assertion: Run returned the context cancellation error assert.ErrorIs(t, runErr, context.Canceled, "Run must return context.Canceled when context is cancelled") // Run returned only after wg.Wait() drained goroutines. // The goroutines may have returned errors (ctx cancelled HTTP calls) but // the key invariant is Run itself did not panic and returned cleanly. // Any responses that did complete should be valid JSON-RPC. responseMu.Lock() defer responseMu.Unlock() for _, resp := range collectedResponses { assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0") } } // ============================================================================= // REGRESSION TESTS — Fix #1-#5 // ============================================================================= // blockingWriter is an io.Writer that blocks forever on Write. type blockingWriter struct { blocked chan struct{} // closed when Write is entered } func (bw *blockingWriter) Write(p []byte) (int, error) { if bw.blocked != nil { close(bw.blocked) } select {} // block forever } // TestRun_SemaphoreDoesNotBlockMainLoop (Fix #1 regression) fills all semaphore // slots with blocked requests, then sends a notification. The main loop must // stay responsive and not hang on semaphore acquisition. func TestRun_SemaphoreDoesNotBlockMainLoop(t *testing.T) { t.Parallel() const maxConcurrent = 10 // Mock worker that blocks until test context is cancelled handlerCtx, handlerCancel := context.WithCancel(context.Background()) defer handlerCancel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-handlerCtx.Done() // block until test cleanup w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{}`) })) defer func() { handlerCancel() // unblock all handlers first ts.CloseClientConnections() ts.Close() }() stdinR, stdinW := io.Pipe() var stdout bytes.Buffer server := &Server{ client: ts.Client(), workerURL: ts.URL, project: "test", version: "1.0.0", stdin: stdinR, stdout: &stdout, } ctx, cancel := context.WithCancel(context.Background()) defer cancel() runDone := make(chan error, 1) go func() { runDone <- server.Run(ctx) }() // Send maxConcurrent+2 requests to fill all semaphore slots for i := 0; i < maxConcurrent+2; i++ { req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call", Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} data, _ := json.Marshal(req) _, err := io.WriteString(stdinW, string(data)+"\n") require.NoError(t, err) } // Give goroutines time to start and fill semaphore time.Sleep(100 * time.Millisecond) // Now send a notification — this must not hang the main loop notifSent := make(chan struct{}) go func() { notif := `{"jsonrpc":"2.0","method":"notifications/initialized"}` + "\n" _, _ = io.WriteString(stdinW, notif) close(notifSent) }() select { case <-notifSent: // Main loop accepted the notification write (stdin is a pipe, so // the write completing means the reader consumed it) case <-time.After(3 * time.Second): t.Fatal("main loop blocked — semaphore acquisition is blocking the main goroutine") } // Clean up: cancel server context, close stdin cancel() _ = stdinW.Close() // Wait for Run to finish select { case <-runDone: case <-time.After(5 * time.Second): // Acceptable — some goroutines may still be draining } } // TestSendResponse_WriteTimeout (Fix #2 regression) verifies that sendResponse // returns an error within a bounded time when the writer blocks forever. func TestSendResponse_WriteTimeout(t *testing.T) { t.Parallel() bw := &blockingWriter{blocked: make(chan struct{})} server := &Server{stdout: bw} resp := &Response{ JSONRPC: "2.0", ID: 1, Result: "ok", } done := make(chan error, 1) go func() { done <- server.sendResponse(resp) }() select { case err := <-done: require.Error(t, err, "sendResponse must return error on write timeout") assert.Contains(t, err.Error(), "write timeout") case <-time.After(15 * time.Second): t.Fatal("sendResponse hung forever — write timeout not working") } } // TestSendResponse_MarshalError (Fix #3 regression) verifies that when // json.Marshal fails, sendResponse sends a fallback error response and // returns an error (instead of silently returning nil). func TestSendResponse_MarshalError(t *testing.T) { t.Parallel() var buf bytes.Buffer server := &Server{stdout: &buf} // Channels are not JSON-serializable — this will cause json.Marshal to fail resp := &Response{ JSONRPC: "2.0", ID: 42, Result: make(chan int), // unserializable } err := server.sendResponse(resp) // (a) Must return an error require.Error(t, err, "sendResponse must return error when marshal fails") assert.Contains(t, err.Error(), "marshal error") // (b) Must have written a fallback JSON-RPC error to stdout output := buf.String() require.NotEmpty(t, output, "fallback response must be written to stdout") var fallback map[string]any require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(output)), &fallback), "fallback must be valid JSON") assert.Equal(t, "2.0", fallback["jsonrpc"]) assert.Equal(t, float64(42), fallback["id"], "fallback must preserve original request ID") errObj, ok := fallback["error"].(map[string]any) require.True(t, ok, "fallback must contain error object") assert.Equal(t, float64(-32603), errObj["code"]) assert.Equal(t, "internal marshal error", errObj["message"]) } // TestHandleRequest_UnknownNotification (Fix #4 regression) verifies that // unknown notification methods (ID == nil) get no response, while unknown // methods with an ID still get a -32601 error. func TestHandleRequest_UnknownNotification(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Case 1: Unknown notification (no ID) — must return nil notifReq := &Request{ JSONRPC: "2.0", ID: nil, Method: "notifications/roots/list_changed", } resp := server.handleRequest(ctx, notifReq) assert.Nil(t, resp, "unknown notification must not produce a response") // Case 2: Unknown method WITH an ID — must return error response methodReq := &Request{ JSONRPC: "2.0", ID: 99, Method: "some/unknown/method", } resp = server.handleRequest(ctx, methodReq) require.NotNil(t, resp, "unknown method with ID must produce an error response") require.NotNil(t, resp.Error) assert.Equal(t, -32601, resp.Error.Code) assert.Equal(t, "Method not found", resp.Error.Message) assert.Equal(t, 99, resp.ID) } // TestHandleToolsCall_ErrorUsesIsError (Fix #5 regression) verifies that when // callTool returns an error, the response uses Result with isError:true instead // of top-level Error field (per MCP spec). func TestHandleToolsCall_ErrorUsesIsError(t *testing.T) { t.Parallel() server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ JSONRPC: "2.0", ID: 7, Method: "tools/call", Params: json.RawMessage(`{"name":"nonexistent_tool","arguments":{}}`), } resp := server.handleToolsCall(ctx, req) // (a) Response must NOT have top-level Error assert.Nil(t, resp.Error, "tool errors must not use top-level JSON-RPC Error") // (b) Response must have Result with isError: true require.NotNil(t, resp.Result, "tool error response must have Result") result, ok := resp.Result.(map[string]any) require.True(t, ok, "Result must be a map") assert.Equal(t, true, result["isError"], "Result must contain isError: true") // (c) Result.content[0].text must contain the error message content, ok := result["content"].([]map[string]any) require.True(t, ok, "Result.content must be []map[string]any") require.Len(t, content, 1) assert.Equal(t, "text", content[0]["type"]) errText, ok := content[0]["text"].(string) require.True(t, ok) assert.Contains(t, errText, "unknown tool: nonexistent_tool") }