Files
claude-mnemonic/internal/embedding/service_test.go
T
lukaszraczylo de5796bbe6 test: add regression tests for #45 hang fixes
- MCP server: 4 tests verifying concurrent dispatch, slow-request
  isolation, semaphore limiting, and graceful drain on cancel
- Embedding: 4 tests verifying context-aware mutex cancellation,
  uncontended success, batch cancellation, and cleanup after cancel
- Vector client: 3 tests for acquireRLockWithContext cancel, success,
  and cleanup goroutine correctness
- Worker handlers: 1 test verifying handleSearchByPrompt inherits
  request context cancellation (skips without FTS5)

12 regression tests total covering the four fix areas.
2026-05-26 14:29:34 +01:00

478 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package embedding
import (
"context"
"errors"
"math"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestEmbeddingDim pins the EmbeddingDim constant to its expected value.
func TestEmbeddingDim(t *testing.T) {
	const want = 384
	assert.Equal(t, want, EmbeddingDim)
}
// TestNewService checks that NewService hands back a non-nil service whose
// accessors report a name, a version and the expected dimensionality.
func TestNewService(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	require.NotNil(t, service)
	defer service.Close()

	// Only exported accessors are used to inspect the freshly built service.
	name := service.Name()
	assert.NotEmpty(t, name)

	version := service.Version()
	assert.NotEmpty(t, version)

	dims := service.Dimensions()
	assert.Equal(t, EmbeddingDim, dims)
}
// TestEmbed_SingleText embeds one short string and checks that the result has
// EmbeddingDim components and is not the zero vector.
func TestEmbed_SingleText(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	vec, err := service.Embed("Hello, world!")
	require.NoError(t, err)
	assert.Len(t, vec, EmbeddingDim)

	// The squared norm is strictly positive as soon as any component is non-zero.
	var sqNorm float32
	for i := range vec {
		sqNorm += vec[i] * vec[i]
	}
	assert.Greater(t, sqNorm, float32(0), "Embedding should not be all zeros")
}
// TestEmbed_EmptyText checks that embedding the empty string succeeds and
// yields an EmbeddingDim-long vector whose components are all zero.
func TestEmbed_EmptyText(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	vec, err := service.Embed("")
	require.NoError(t, err)
	assert.Len(t, vec, EmbeddingDim)

	// Every component of the result is expected to be exactly zero.
	const zero = float32(0)
	for i := range vec {
		assert.Equal(t, zero, vec[i])
	}
}
// TestEmbed_SimilarTexts checks that a paraphrase of a sentence scores a higher
// cosine similarity against it than an unrelated sentence does, and that the
// paraphrase similarity exceeds 0.7.
func TestEmbed_SimilarTexts(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	const (
		baseText       = "The quick brown fox jumps over the lazy dog."
		paraphraseText = "A fast brown fox leaps over a sleepy dog."
		unrelatedText  = "Go programming language concurrency patterns."
	)

	baseVec, err := service.Embed(baseText)
	require.NoError(t, err)
	paraphraseVec, err := service.Embed(paraphraseText)
	require.NoError(t, err)
	unrelatedVec, err := service.Embed(unrelatedText)
	require.NoError(t, err)

	// Compare the base sentence against its paraphrase and against the unrelated text.
	simParaphrase := cosineSimilarity(baseVec, paraphraseVec)
	simUnrelated := cosineSimilarity(baseVec, unrelatedVec)

	assert.Greater(t, simParaphrase, simUnrelated, "Similar sentences should have higher similarity than dissimilar ones")
	assert.Greater(t, simParaphrase, float64(0.7), "Similar sentences should have high similarity")
}
// TestEmbedBatch_MultipleTexts embeds three texts in one batch and checks that
// one EmbeddingDim-long vector comes back per input.
func TestEmbedBatch_MultipleTexts(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	inputs := []string{
		"First text about programming.",
		"Second text about databases.",
		"Third text about machine learning.",
	}

	vectors, err := service.EmbedBatch(inputs)
	require.NoError(t, err)
	assert.Len(t, vectors, len(inputs))

	for idx := range vectors {
		assert.Len(t, vectors[idx], EmbeddingDim, "Embedding %d should have correct dimension", idx)
	}
}
// TestEmbedBatch_EmptySlice checks that a batch call with an empty (non-nil)
// input slice succeeds and returns nil.
func TestEmbedBatch_EmptySlice(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	noInputs := make([]string, 0)
	vectors, err := service.EmbedBatch(noInputs)
	require.NoError(t, err)
	assert.Nil(t, vectors)
}
// TestEmbedBatch_WithEmptyTexts embeds a batch that alternates real text with
// empty strings. Non-empty inputs must come back as non-zero vectors and empty
// inputs as all-zero vectors, at the same positions as in the input.
func TestEmbedBatch_WithEmptyTexts(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	inputs := []string{
		"Valid text one.",
		"",
		"Valid text two.",
		"",
	}

	vectors, err := service.EmbedBatch(inputs)
	require.NoError(t, err)
	assert.Len(t, vectors, 4)

	// Walk the inputs in order and apply the check that matches each one.
	for idx, input := range inputs {
		vec := vectors[idx]

		if input == "" {
			// Empty text: every component must be zero.
			for _, component := range vec {
				assert.Equal(t, float32(0), component)
			}
			continue
		}

		// Non-empty text: the squared norm must be strictly positive.
		var sqNorm float32
		for _, component := range vec {
			sqNorm += component * component
		}
		assert.Greater(t, sqNorm, float32(0))
	}
}
// TestEmbedBatch_AllEmpty embeds a batch of three empty strings and checks that
// three EmbeddingDim-long, all-zero vectors are returned.
func TestEmbedBatch_AllEmpty(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	// make leaves every element at the zero value, i.e. the empty string.
	inputs := make([]string, 3)

	vectors, err := service.EmbedBatch(inputs)
	require.NoError(t, err)
	assert.Len(t, vectors, 3)

	for row := range vectors {
		vec := vectors[row]
		assert.Len(t, vec, EmbeddingDim, "Embedding %d should have correct dimension", row)
		for col := range vec {
			assert.Equal(t, float32(0), vec[col], "Embedding %d[%d] should be zero", row, col)
		}
	}
}
// TestEmbed_Concurrent calls Embed from several goroutines at once on a single
// service and checks that every call succeeds and returns an
// EmbeddingDim-long vector.
func TestEmbed_Concurrent(t *testing.T) {
	svc, err := NewService()
	require.NoError(t, err)
	defer svc.Close()

	const (
		numGoroutines = 10
		text          = "Test text for concurrent embedding test"
	)

	// Each goroutine writes only to its own index of these slices, so they
	// need no extra locking. Keeping one error slot per goroutine (instead of
	// a channel named "errors") also avoids shadowing the imported errors
	// package inside this function.
	var (
		wg         sync.WaitGroup
		embeddings = make([][]float32, numGoroutines)
		embedErrs  = make([]error, numGoroutines)
	)

	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			embeddings[idx], embedErrs[idx] = svc.Embed(text)
		}(i)
	}
	wg.Wait()

	// Assertions run on the test goroutine only, after all workers finished.
	for i := range embeddings {
		if !assert.NoError(t, embedErrs[i], "Concurrent embedding %d failed", i) {
			continue
		}
		// No nil guard here: a nil embedding returned without an error is a
		// failure and must not be skipped silently.
		assert.Len(t, embeddings[i], EmbeddingDim, "Embedding %d should have correct dimension", i)
	}
}
// TestEmbed_SpecialCharacters embeds inputs containing non-ASCII characters,
// control characters, quotes and source code, and checks that each one yields
// an EmbeddingDim-long vector without error.
func TestEmbed_SpecialCharacters(t *testing.T) {
	svc, err := NewService()
	require.NoError(t, err)
	defer svc.Close()

	// Subtests carry explicit names. Deriving the name from text[:20] slices
	// bytes, not runes: it cut the first multi-byte character of the unicode
	// case in half, embedded a raw newline in another name, and would panic
	// for any input shorter than 20 bytes.
	cases := []struct {
		name string
		text string
	}{
		{name: "unicode", text: "Text with unicode: 你好世界 🎉"},
		{name: "newlines", text: "Text with newlines:\nLine 1\nLine 2"},
		{name: "tabs", text: "Text with tabs:\tColumn1\tColumn2"},
		{name: "quotes", text: "Text with quotes: \"quoted\" and 'single'"},
		{name: "code", text: "Text with code: func main() { fmt.Println(\"hello\") }"},
	}

	for _, tc := range cases {
		// t.Run executes the closure synchronously (no t.Parallel), so using
		// the loop variable inside it is safe on every Go version.
		t.Run(tc.name, func(t *testing.T) {
			emb, err := svc.Embed(tc.text)
			require.NoError(t, err)
			assert.Len(t, emb, EmbeddingDim)
		})
	}
}
// TestEmbed_LongText embeds a text made of one sentence repeated 100 times and
// checks that the call succeeds with an EmbeddingDim-long vector.
func TestEmbed_LongText(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	const (
		sentence = "This is a sentence to make the text very long. "
		repeats  = 100
	)

	// Assemble the long input in a pre-sized byte buffer.
	// NOTE(review): the original comment expects the tokenizer to truncate
	// over-long input; that is not verified by this test.
	buf := make([]byte, 0, repeats*len(sentence))
	for n := 0; n < repeats; n++ {
		buf = append(buf, sentence...)
	}

	vec, err := service.Embed(string(buf))
	require.NoError(t, err)
	assert.Len(t, vec, EmbeddingDim)
}
// TestClose checks that closing a freshly created service reports no error.
func TestClose(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)

	closeErr := service.Close()
	require.NoError(t, closeErr)

	// What Embed does after Close is deliberately not asserted: per the
	// original note it is model-specific, and some models may keep working
	// once closed.
}
// TestEmbedBatch_SingleItem checks that a batch of exactly one text returns
// exactly one EmbeddingDim-long vector.
func TestEmbedBatch_SingleItem(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	inputs := []string{"Single text for batch embedding."}

	vectors, err := service.EmbedBatch(inputs)
	require.NoError(t, err)
	assert.Len(t, vectors, 1)

	only := vectors[0]
	assert.Len(t, only, EmbeddingDim)
}
// TestEmbed_Deterministic embeds the same text twice on one service and checks
// that both results agree component by component.
func TestEmbed_Deterministic(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	const input = "Test text for deterministic embedding."

	first, err := service.Embed(input)
	require.NoError(t, err)
	second, err := service.Embed(input)
	require.NoError(t, err)

	// Compare the first EmbeddingDim components of both runs pairwise.
	for idx := 0; idx != EmbeddingDim; idx++ {
		assert.Equal(t, first[idx], second[idx], "Embedding should be deterministic at index %d", idx)
	}
}
// --- Regression tests: context-aware mutex (Fix #45) ---
// asBGE returns the service's model as a *bgeModel so tests can reach its
// mutex directly. It fails the calling test immediately when the model has a
// different concrete type. Access to the unexported field works because the
// tests live in the same package.
func asBGE(t *testing.T, svc *Service) *bgeModel {
	t.Helper()

	bge, isBGE := svc.model.(*bgeModel)
	require.True(t, isBGE, "model is not *bgeModel — test invariant broken")
	return bge
}
// holdMutex starts a goroutine that locks m.mu and keeps it locked until the
// returned release func is called.
//
// ready is closed once the lock is actually held. release closes an internal
// channel, so it must be called exactly once; a second call panics.
func holdMutex(m *bgeModel) (ready <-chan struct{}, release func()) {
	locked := make(chan struct{})
	unlock := make(chan struct{})

	go func() {
		m.mu.Lock()
		defer m.mu.Unlock()

		close(locked) // tell the caller the lock is now held
		<-unlock      // park until release is invoked
	}()

	release = func() { close(unlock) }
	return locked, release
}
// TestEmbedWithContext_CancelWhileWaitingForMutex is the core regression test.
// While another goroutine holds the model mutex, a call whose context times out
// must come back promptly with a context error rather than waiting for the
// mutex to be released.
func TestEmbedWithContext_CancelWhileWaitingForMutex(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	model := asBGE(t, service)

	// Keep the mutex locked for the whole call to simulate a stuck ONNX call.
	held, unlock := holdMutex(model)
	<-held // do not proceed until the lock is really held
	defer unlock()

	const timeout = 100 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	begin := time.Now()
	_, err = service.EmbedWithContext(ctx, "test text")
	took := time.Since(begin)

	// The call has to fail, and with a context error specifically.
	require.Error(t, err)
	isCtxErr := errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
	assert.True(t, isCtxErr, "expected context error, got: %v", err)

	// The call has to return quickly; 2× the timeout leaves slack for CI.
	assert.Less(t, took, 200*time.Millisecond,
		"EmbedWithContext blocked too long: %v", took)
}
// TestEmbedWithContext_SuccessWhenUncontended checks the ordinary path: with
// nobody else holding the mutex and a background context, EmbedWithContext
// returns a non-zero EmbeddingDim-long vector.
func TestEmbedWithContext_SuccessWhenUncontended(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	ctx := context.Background()
	vec, err := service.EmbedWithContext(ctx, "hello world")
	require.NoError(t, err)
	assert.Len(t, vec, EmbeddingDim)

	// A strictly positive squared norm rules out the all-zero vector.
	var sqNorm float32
	for i := range vec {
		sqNorm += vec[i] * vec[i]
	}
	assert.Greater(t, sqNorm, float32(0), "embedding should not be all zeros")
}
// TestEmbedBatchWithContext_CancelDuringBatch is the batch counterpart of the
// single-text cancellation test: with the model mutex held elsewhere, a batch
// call whose context times out must return promptly with a context error.
func TestEmbedBatchWithContext_CancelDuringBatch(t *testing.T) {
	service, err := NewService()
	require.NoError(t, err)
	defer service.Close()

	model := asBGE(t, service)

	// Hold the mutex for the duration of the batch call.
	held, unlock := holdMutex(model)
	<-held
	defer unlock()

	const timeout = 100 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	inputs := []string{"a", "b", "c"}

	begin := time.Now()
	_, err = service.EmbedBatchWithContext(ctx, inputs)
	took := time.Since(begin)

	require.Error(t, err)
	isCtxErr := errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
	assert.True(t, isCtxErr, "expected context error, got: %v", err)

	// Twice the context timeout is the upper bound for returning.
	assert.Less(t, took, 200*time.Millisecond,
		"EmbedBatchWithContext blocked too long: %v", took)
}
// TestEmbedWithContext_CleanupAfterCancel verifies that a call which gave up
// waiting for the mutex (because its context expired) does not leave the mutex
// locked: once the original holder lets go, a later call must succeed.
func TestEmbedWithContext_CleanupAfterCancel(t *testing.T) {
	svc, err := NewService()
	require.NoError(t, err)
	defer svc.Close()

	m := asBGE(t, svc)

	// --- first call: context expires while mutex is held ---
	ready, release := holdMutex(m)
	<-ready

	// release closes a channel and may therefore run only once. Wrapping it in
	// a sync.Once lets the test call it mid-way AND defer it, so the holder
	// goroutine is always told to unlock even when a require below aborts the
	// test early. Being deferred after svc.Close, it runs before svc.Close.
	var releaseOnce sync.Once
	releaseHeld := func() { releaseOnce.Do(release) }
	defer releaseHeld()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	_, firstErr := svc.EmbedWithContext(ctx, "should fail")
	require.Error(t, firstErr)
	assert.True(t,
		errors.Is(firstErr, context.DeadlineExceeded) || errors.Is(firstErr, context.Canceled),
		"expected context error on first call, got: %v", firstErr)

	// Release the held mutex so any cleanup left behind by the cancelled call
	// can finish.
	releaseHeld()

	// Give that cleanup a moment to acquire-and-release the mutex.
	// 50 ms is generous; it only has to lock+unlock with no contention.
	time.Sleep(50 * time.Millisecond)

	// --- second call: mutex should be free, no deadlock ---
	// The call is bounded by a generous timeout so that a regression (mutex
	// never released) fails this test with a context error instead of hanging
	// the whole test binary until the global go test timeout.
	secondCtx, secondCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer secondCancel()

	emb, secondErr := svc.EmbedWithContext(secondCtx, "should work")
	require.NoError(t, secondErr, "second call should succeed after cleanup goroutine released mutex")
	assert.Len(t, emb, EmbeddingDim)
}
// cosineSimilarity returns the cosine of the angle between a and b, computed
// in float64. It returns 0 when the slices differ in length or when either
// one has zero norm (which includes empty and nil slices).
func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) {
		return 0
	}

	var dot, sqA, sqB float64
	for i, av := range a {
		x, y := float64(av), float64(b[i])
		dot += x * y
		sqA += x * x
		sqB += y * y
	}

	// A zero norm would make the division below undefined.
	if sqA == 0 || sqB == 0 {
		return 0
	}

	denom := math.Sqrt(sqA) * math.Sqrt(sqB)
	return dot / denom
}