mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
test: add regression tests for #45 hang fixes
- MCP server: 4 tests verifying concurrent dispatch, slow-request isolation, semaphore limiting, and graceful drain on cancel
- Embedding: 4 tests verifying context-aware mutex cancellation, uncontended success, batch cancellation, and cleanup after cancel
- Vector client: 3 tests for acquireRLockWithContext cancel, success, and cleanup goroutine correctness
- Worker handlers: 1 test verifying handleSearchByPrompt inherits request context cancellation (skips without FTS5)

12 regression tests total covering the four fix areas.
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -310,6 +313,146 @@ func TestEmbed_Deterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Regression tests: context-aware mutex (Fix #45) ---
|
||||
|
||||
// asBGE casts the service model to *bgeModel for direct mutex access.
|
||||
// Tests are in the same package so this is safe.
|
||||
func asBGE(t *testing.T, svc *Service) *bgeModel {
|
||||
t.Helper()
|
||||
m, ok := svc.model.(*bgeModel)
|
||||
require.True(t, ok, "model is not *bgeModel — test invariant broken")
|
||||
return m
|
||||
}
|
||||
|
||||
// holdMutex locks m.mu in a background goroutine and returns a release func.
|
||||
// The returned ready channel is closed once the lock is held.
|
||||
func holdMutex(m *bgeModel) (ready <-chan struct{}, release func()) {
|
||||
ch := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.mu.Lock()
|
||||
close(ch) // signal: lock acquired
|
||||
<-done // wait for release signal
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
return ch, func() { close(done) }
|
||||
}
|
||||
|
||||
// TestEmbedWithContext_CancelWhileWaitingForMutex is the core regression test.
|
||||
// If the mutex is held and the context times out, EmbedWithContext must return
|
||||
// immediately with a context error — not block until the mutex is released.
|
||||
func TestEmbedWithContext_CancelWhileWaitingForMutex(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
m := asBGE(t, svc)
|
||||
|
||||
// Hold the mutex to simulate a stuck ONNX call.
|
||||
ready, release := holdMutex(m)
|
||||
<-ready // ensure lock is held before proceeding
|
||||
defer release()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err = svc.EmbedWithContext(ctx, "test text")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Must return a context error.
|
||||
require.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled),
|
||||
"expected context error, got: %v", err)
|
||||
|
||||
// Must return quickly (well under the 30 s default; allow 2× the timeout for CI slack).
|
||||
assert.Less(t, elapsed, 200*time.Millisecond,
|
||||
"EmbedWithContext blocked too long: %v", elapsed)
|
||||
}
|
||||
|
||||
// TestEmbedWithContext_SuccessWhenUncontended verifies normal operation still works.
|
||||
func TestEmbedWithContext_SuccessWhenUncontended(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
emb, err := svc.EmbedWithContext(context.Background(), "hello world")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, emb, EmbeddingDim)
|
||||
|
||||
var sum float32
|
||||
for _, v := range emb {
|
||||
sum += v * v
|
||||
}
|
||||
assert.Greater(t, sum, float32(0), "embedding should not be all zeros")
|
||||
}
|
||||
|
||||
// TestEmbedBatchWithContext_CancelDuringBatch verifies batch embedding respects
|
||||
// context cancellation while blocked on mutex acquisition.
|
||||
func TestEmbedBatchWithContext_CancelDuringBatch(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
m := asBGE(t, svc)
|
||||
|
||||
ready, release := holdMutex(m)
|
||||
<-ready
|
||||
defer release()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err = svc.EmbedBatchWithContext(ctx, []string{"a", "b", "c"})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled),
|
||||
"expected context error, got: %v", err)
|
||||
assert.Less(t, elapsed, 200*time.Millisecond,
|
||||
"EmbedBatchWithContext blocked too long: %v", elapsed)
|
||||
}
|
||||
|
||||
// TestEmbedWithContext_CleanupAfterCancel verifies the cleanup goroutine in
|
||||
// acquireMutex properly unlocks the mutex after context cancellation,
|
||||
// so subsequent calls do not deadlock.
|
||||
func TestEmbedWithContext_CleanupAfterCancel(t *testing.T) {
|
||||
svc, err := NewService()
|
||||
require.NoError(t, err)
|
||||
defer svc.Close()
|
||||
|
||||
m := asBGE(t, svc)
|
||||
|
||||
// --- first call: context expires while mutex is held ---
|
||||
ready, release := holdMutex(m)
|
||||
<-ready
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, firstErr := svc.EmbedWithContext(ctx, "should fail")
|
||||
require.Error(t, firstErr)
|
||||
assert.True(t,
|
||||
errors.Is(firstErr, context.DeadlineExceeded) || errors.Is(firstErr, context.Canceled),
|
||||
"expected context error on first call, got: %v", firstErr)
|
||||
|
||||
// Release the held mutex so the cleanup goroutine inside acquireMutex can finish.
|
||||
release()
|
||||
|
||||
// Give the cleanup goroutine a moment to acquire-and-release the mutex.
|
||||
// 50 ms is generous; the goroutine only has to lock+unlock with no contention.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// --- second call: mutex should be free, no deadlock ---
|
||||
emb, secondErr := svc.EmbedWithContext(context.Background(), "should work")
|
||||
require.NoError(t, secondErr, "second call should succeed after cleanup goroutine released mutex")
|
||||
assert.Len(t, emb, EmbeddingDim)
|
||||
}
|
||||
|
||||
// Helper function to calculate cosine similarity
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) {
|
||||
|
||||
@@ -2,11 +2,17 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -2798,3 +2804,371 @@ func TestHandleMergeObservations_WorkerUnavailable(t *testing.T) {
|
||||
_, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2}`))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// REGRESSION TESTS — Fix #45: Concurrent request dispatching
|
||||
// =============================================================================
|
||||
|
||||
// collectResponses reads newline-delimited JSON objects from r until n of them
// have been decoded or r is exhausted (EOF or a read error), and returns the
// decoded objects in arrival order.
//
// Blank lines are skipped. Lines that do not decode into a JSON object are
// logged and skipped. A scanner error (for example a line longer than the
// scanner's token limit) is logged instead of being dropped, so a short result
// can be explained from the test output. When n <= 0 nothing is read and an
// empty slice is returned.
func collectResponses(t *testing.T, r io.Reader, n int) []map[string]any {
	t.Helper()

	// Nothing requested: do not block on r, and avoid make() with a negative
	// capacity, which would panic.
	if n <= 0 {
		return []map[string]any{}
	}

	results := make([]map[string]any, 0, n)
	scanner := bufio.NewScanner(r)
	for len(results) < n && scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			continue
		}

		var resp map[string]any
		if err := json.Unmarshal([]byte(line), &resp); err != nil {
			t.Logf("collectResponses: bad JSON line: %s", line)
			continue
		}
		results = append(results, resp)
	}

	// Scan returning false is either clean EOF (Err is nil) or a real failure.
	if err := scanner.Err(); err != nil {
		t.Logf("collectResponses: stopped after %d of %d responses: %v", len(results), n, err)
	}
	return results
}
|
||||
|
||||
// writeRequests writes each entry of reqs to w as one newline-terminated line
// and then closes w.
//
// The tests start it with `go`, next to the code reading the other end. For
// that reason a write failure is reported with t.Errorf rather than a
// FailNow-based assertion (FailNow must only be called from the goroutine
// running the test). After a failure no further requests are written, but w is
// still closed so the reader sees EOF instead of blocking forever.
func writeRequests(t *testing.T, w io.WriteCloser, reqs []string) {
	t.Helper()

	// Close unconditionally; the close error is ignored as best-effort cleanup.
	defer func() { _ = w.Close() }()

	for i, req := range reqs {
		if _, err := io.WriteString(w, req+"\n"); err != nil {
			t.Errorf("writeRequests: writing request %d of %d: %v", i+1, len(reqs), err)
			return
		}
	}
}
|
||||
|
||||
// TestRun_ConcurrentRequests verifies multiple requests are processed concurrently
|
||||
// and not serially. If they ran serially, 5 × 100ms = 500ms. Concurrent should be
|
||||
// well under 400ms on any reasonable machine.
|
||||
func TestRun_ConcurrentRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const delay = 100 * time.Millisecond
|
||||
const numRequests = 5
|
||||
|
||||
// Mock worker: every request sleeps delay then returns "{}"
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(delay)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
stdinR, stdinW := io.Pipe()
|
||||
stdoutR, stdoutW := io.Pipe()
|
||||
|
||||
server := &Server{
|
||||
client: ts.Client(),
|
||||
workerURL: ts.URL,
|
||||
project: "test",
|
||||
version: "1.0.0",
|
||||
stdin: stdinR,
|
||||
stdout: stdoutW,
|
||||
}
|
||||
|
||||
// Build requests — use get_memory_stats which goes to GET /api/stats
|
||||
reqs := make([]string, numRequests)
|
||||
for i := range reqs {
|
||||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
reqs[i] = string(data)
|
||||
}
|
||||
|
||||
// Collect responses in background
|
||||
var responses []map[string]any
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
responses = collectResponses(t, stdoutR, numRequests)
|
||||
_ = stdoutR.Close()
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Write all requests then close stdin (triggers Run to drain and return)
|
||||
go writeRequests(t, stdinW, reqs)
|
||||
|
||||
err := server.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
_ = stdoutW.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
wg.Wait()
|
||||
|
||||
// All responses received
|
||||
assert.Len(t, responses, numRequests, "expected %d responses", numRequests)
|
||||
|
||||
// Concurrent execution: should be much less than numRequests × delay
|
||||
serialTime := time.Duration(numRequests) * delay
|
||||
assert.Less(t, elapsed, serialTime*4/5,
|
||||
"elapsed %v not significantly less than serial %v — requests may be sequential", elapsed, serialTime)
|
||||
|
||||
// Each response has correct jsonrpc field
|
||||
for _, resp := range responses {
|
||||
assert.Equal(t, "2.0", resp["jsonrpc"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestRun_SlowRequestDoesNotBlockOthers is the core regression for #45.
|
||||
// A slow search request must not block a fast stats request from being answered first.
|
||||
func TestRun_SlowRequestDoesNotBlockOthers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const slowDelay = 300 * time.Millisecond
|
||||
|
||||
// responseOrder records which request IDs responded, in arrival order
|
||||
var mu sync.Mutex
|
||||
var responseOrder []any
|
||||
|
||||
// Mock worker: /api/context/search is slow, everything else is fast
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/context/search") {
|
||||
time.Sleep(slowDelay)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
stdinR, stdinW := io.Pipe()
|
||||
stdoutR, stdoutW := io.Pipe()
|
||||
|
||||
// Intercept stdout to record response order before passing through
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(pr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &resp); err == nil {
|
||||
mu.Lock()
|
||||
responseOrder = append(responseOrder, resp["id"])
|
||||
mu.Unlock()
|
||||
}
|
||||
_, _ = io.WriteString(stdoutW, line+"\n")
|
||||
}
|
||||
_ = stdoutW.Close()
|
||||
}()
|
||||
|
||||
server := &Server{
|
||||
client: ts.Client(),
|
||||
workerURL: ts.URL,
|
||||
project: "test",
|
||||
version: "1.0.0",
|
||||
stdin: stdinR,
|
||||
stdout: pw,
|
||||
}
|
||||
|
||||
// Request 1: slow search (id=1)
|
||||
slowReq := Request{JSONRPC: "2.0", ID: 1, Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"search","arguments":{"query":"anything"}}`)}
|
||||
slowData, err := json.Marshal(slowReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Request 2: fast stats (id=2)
|
||||
fastReq := Request{JSONRPC: "2.0", ID: 2, Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||||
fastData, err := json.Marshal(fastReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect 2 responses
|
||||
var responses []map[string]any
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
responses = collectResponses(t, stdoutR, 2)
|
||||
_ = stdoutR.Close()
|
||||
})
|
||||
|
||||
// Write slow then fast, then close
|
||||
go func() {
|
||||
_, _ = io.WriteString(stdinW, string(slowData)+"\n")
|
||||
// Small pause to ensure slow request goroutine is dispatched first
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
_, _ = io.WriteString(stdinW, string(fastData)+"\n")
|
||||
_ = stdinW.Close()
|
||||
}()
|
||||
|
||||
runErr := server.Run(context.Background())
|
||||
require.NoError(t, runErr)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
require.Len(t, responses, 2, "expected 2 responses")
|
||||
|
||||
mu.Lock()
|
||||
order := responseOrder
|
||||
mu.Unlock()
|
||||
|
||||
require.Len(t, order, 2, "expected 2 recorded response IDs")
|
||||
|
||||
// The fast request (id=2) must arrive before the slow one (id=1)
|
||||
assert.Equal(t, float64(2), order[0],
|
||||
"fast request (id=2) should respond before slow request (id=1); got order %v", order)
|
||||
assert.Equal(t, float64(1), order[1],
|
||||
"slow request (id=1) should respond second; got order %v", order)
|
||||
}
|
||||
|
||||
// TestRun_SemaphoreLimitsConcurrency verifies the semaphore cap (10) does not deadlock
|
||||
// when more than 10 requests are sent and all eventually complete.
|
||||
func TestRun_SemaphoreLimitsConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const blockDelay = 200 * time.Millisecond
|
||||
const numRequests = 15 // exceeds semaphore cap of 10
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(blockDelay)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
stdinR, stdinW := io.Pipe()
|
||||
stdoutR, stdoutW := io.Pipe()
|
||||
|
||||
server := &Server{
|
||||
client: ts.Client(),
|
||||
workerURL: ts.URL,
|
||||
project: "test",
|
||||
version: "1.0.0",
|
||||
stdin: stdinR,
|
||||
stdout: stdoutW,
|
||||
}
|
||||
|
||||
reqs := make([]string, numRequests)
|
||||
for i := range reqs {
|
||||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
reqs[i] = string(data)
|
||||
}
|
||||
|
||||
var responses []map[string]any
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
responses = collectResponses(t, stdoutR, numRequests)
|
||||
_ = stdoutR.Close()
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
go writeRequests(t, stdinW, reqs)
|
||||
|
||||
err := server.Run(context.Background())
|
||||
require.NoError(t, err)
|
||||
_ = stdoutW.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
wg.Wait()
|
||||
|
||||
// All 15 responses received — no deadlock
|
||||
assert.Len(t, responses, numRequests, "all %d requests must complete", numRequests)
|
||||
|
||||
// With semaphore=10 and 15 requests at 200ms each, we need at least 2 batches.
|
||||
// Should complete in ~2×blockDelay not 15×blockDelay.
|
||||
// Upper bound: 3×blockDelay gives comfortable headroom for scheduling.
|
||||
upperBound := 3 * blockDelay * 2 // generous: 3 batches + 2× overhead factor
|
||||
assert.Less(t, elapsed, upperBound,
|
||||
"elapsed %v suggests sequential processing (15×%v = %v)", elapsed, blockDelay, time.Duration(numRequests)*blockDelay)
|
||||
}
|
||||
|
||||
// TestRun_GracefulDrainOnCancel verifies that cancelling the context causes Run to
|
||||
// drain in-flight requests (wg.Wait) before returning ctx.Canceled.
|
||||
func TestRun_GracefulDrainOnCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const reqDelay = 200 * time.Millisecond
|
||||
const numRequests = 3
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use a fixed sleep; the request-level context will be cancelled but the
|
||||
// HTTP handler runs to completion independently (server-side).
|
||||
time.Sleep(reqDelay)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
stdinR, stdinW := io.Pipe()
|
||||
stdoutR, stdoutW := io.Pipe()
|
||||
|
||||
server := &Server{
|
||||
client: ts.Client(),
|
||||
workerURL: ts.URL,
|
||||
project: "test",
|
||||
version: "1.0.0",
|
||||
stdin: stdinR,
|
||||
stdout: stdoutW,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Build requests
|
||||
reqs := make([]string, numRequests)
|
||||
for i := range reqs {
|
||||
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
|
||||
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
reqs[i] = string(data)
|
||||
}
|
||||
|
||||
// Write all requests before cancelling so they're all dispatched as goroutines
|
||||
go func() {
|
||||
for _, r := range reqs {
|
||||
_, _ = io.WriteString(stdinW, r+"\n")
|
||||
}
|
||||
// Cancel after requests are dispatched but while they're still in-flight
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
// Leave stdin open — Run should return from the ctx.Done branch
|
||||
}()
|
||||
|
||||
// Drain responses in background; we don't know exactly how many will complete
|
||||
// because goroutines may get context cancelled on their HTTP calls too.
|
||||
var responseMu sync.Mutex
|
||||
var collectedResponses []map[string]any
|
||||
var collectWg sync.WaitGroup
|
||||
collectWg.Go(func() {
|
||||
scanner := bufio.NewScanner(stdoutR)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &resp); err == nil {
|
||||
responseMu.Lock()
|
||||
collectedResponses = append(collectedResponses, resp)
|
||||
responseMu.Unlock()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
runErr := server.Run(ctx)
|
||||
_ = stdoutW.Close()
|
||||
_ = stdinW.Close()
|
||||
|
||||
collectWg.Wait()
|
||||
|
||||
// Core assertion: Run returned the context cancellation error
|
||||
assert.ErrorIs(t, runErr, context.Canceled,
|
||||
"Run must return context.Canceled when context is cancelled")
|
||||
|
||||
// Run returned only after wg.Wait() drained goroutines.
|
||||
// The goroutines may have returned errors (ctx cancelled HTTP calls) but
|
||||
// the key invariant is Run itself did not panic and returned cleanly.
|
||||
// Any responses that did complete should be valid JSON-RPC.
|
||||
responseMu.Lock()
|
||||
defer responseMu.Unlock()
|
||||
for _, resp := range collectedResponses {
|
||||
assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
@@ -1901,3 +1903,113 @@ func TestExtractObservationIDs_GlobalScope(t *testing.T) {
|
||||
assert.Len(t, ids, 1)
|
||||
assert.Equal(t, int64(123), ids[0])
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// REGRESSION TESTS FOR acquireRLockWithContext (Fix #45)
|
||||
// =============================================================================
|
||||
|
||||
// TestAcquireRLockWithContext_Cancel verifies that when a write lock is held
|
||||
// and the context times out, acquireRLockWithContext returns context.DeadlineExceeded
|
||||
// promptly and the cleanup goroutine eventually releases the lock.
|
||||
func TestAcquireRLockWithContext_Cancel(t *testing.T) {
|
||||
var mu sync.RWMutex
|
||||
|
||||
// Hold write lock so any RLock() call blocks.
|
||||
locked := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
go func() {
|
||||
mu.Lock()
|
||||
close(locked)
|
||||
<-release
|
||||
mu.Unlock()
|
||||
}()
|
||||
<-locked // write lock is held
|
||||
|
||||
// Context with a tight deadline — must expire before we release the write lock.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
err := acquireRLockWithContext(ctx, &mu)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded, "should return DeadlineExceeded")
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "should return within ~100ms of deadline")
|
||||
|
||||
// Release the write lock so the cleanup goroutine can finish.
|
||||
close(release)
|
||||
|
||||
// After the write lock is released the cleanup goroutine acquires+releases
|
||||
// the RLock. Wait long enough for it to drain.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Now an uncontended RLock should succeed immediately.
|
||||
ctx2 := context.Background()
|
||||
err2 := acquireRLockWithContext(ctx2, &mu)
|
||||
assert.NoError(t, err2, "should succeed when uncontended after cleanup")
|
||||
if err2 == nil {
|
||||
mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcquireRLockWithContext_Success verifies that an uncontended mutex is
|
||||
// acquired without error and can be properly unlocked.
|
||||
func TestAcquireRLockWithContext_Success(t *testing.T) {
|
||||
var mu sync.RWMutex
|
||||
|
||||
err := acquireRLockWithContext(context.Background(), &mu)
|
||||
assert.NoError(t, err, "should succeed on uncontended mutex")
|
||||
if err == nil {
|
||||
// Panics if not held — validates that the lock was actually taken.
|
||||
mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcquireRLockWithContext_CleanupOnCancel verifies that when
|
||||
// acquireRLockWithContext returns an error due to context cancellation, the
|
||||
// cleanup goroutine eventually releases the RLock so the mutex can be write-
|
||||
// locked again without deadlock.
|
||||
func TestAcquireRLockWithContext_CleanupOnCancel(t *testing.T) {
|
||||
var mu sync.RWMutex
|
||||
|
||||
// Hold write lock to force RLock to block.
|
||||
release := make(chan struct{})
|
||||
locked := make(chan struct{})
|
||||
go func() {
|
||||
mu.Lock()
|
||||
close(locked)
|
||||
<-release
|
||||
mu.Unlock()
|
||||
}()
|
||||
<-locked
|
||||
|
||||
// Context cancels after 10ms — way before we release the write lock.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := acquireRLockWithContext(ctx, &mu)
|
||||
assert.Error(t, err, "should fail due to cancelled context")
|
||||
|
||||
// Release the write lock; the cleanup goroutine inside acquireRLockWithContext
|
||||
// will now acquire the RLock and immediately release it.
|
||||
close(release)
|
||||
|
||||
// Give the cleanup goroutine time to run.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Acquire a write lock — would deadlock if cleanup goroutine left an RLock
|
||||
// dangling. Use a done channel and select to avoid hanging the test.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
close(done) //nolint:SA2001 // intentional: proves no deadlock from leaked RLock
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success — write lock acquired without deadlock.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("write lock acquisition timed out: cleanup goroutine may have leaked an RLock")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3024,3 +3024,42 @@ func TestHandleSessionStart(t *testing.T) {
|
||||
// May return various status codes depending on session state and endpoint
|
||||
assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest, http.StatusNotFound, http.StatusInternalServerError}, rec.Code)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// REGRESSION TESTS FOR handler request-scoped timeouts (Fix #45)
|
||||
// =============================================================================
|
||||
|
||||
// TestHandleSearchByPrompt_RespectsTimeout verifies that handleSearchByPrompt
|
||||
// uses r.Context() as the parent for its internal 15s timeout: if the request
|
||||
// context is already cancelled the handler must return quickly rather than
|
||||
// hanging for 15 seconds, and must not return StatusOK.
|
||||
func TestHandleSearchByPrompt_RespectsTimeout(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Pre-populate a project with an observation so there is DB work to do.
|
||||
createTestObservation(t, svc.observationStore, "timeout-test",
|
||||
"Authentication flow", "JWT token validation", []string{"security"})
|
||||
|
||||
// Build a request with a pre-cancelled context.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancelled before the request is even sent
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/api/context/search?project=timeout-test&query=authentication",
|
||||
nil).WithContext(ctx)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// The handler must not hang for 15 seconds — it inherits the cancelled ctx.
|
||||
assert.Less(t, elapsed, 5*time.Second,
|
||||
"handler should return quickly when request context is already cancelled")
|
||||
|
||||
// A cancelled-context request must not yield a successful 200 response.
|
||||
// Acceptable: any error status, or an empty/error body on 200 (DB returned nothing).
|
||||
// The key regression: it must NOT block for the full 15s timeout.
|
||||
t.Logf("handler returned status=%d in %v", rec.Code, elapsed)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user