mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
1a4fea5c17
Remove ~170MB of model files from the repository (LFS + committed). Models are now downloaded at runtime from Hugging Face on first use and cached to the OS cache directory with progress reporting and retries. - Add internal/models/download.go: runtime downloader with retry, progress bar, checksums - Remove go:embed for ONNX models (keep tokenizers embedded) - Use file-based ONNX session loading instead of byte-slice - Add scripts/download-models.sh for dev/CI model setup - Update Makefile with setup-models target - Update workflow-prepare.sh to download models in CI - Set lfs: false in all CI workflows - SHA256: bge=828e14..., cross-encoder=5d3e70...
640 lines
18 KiB
Go
640 lines
18 KiB
Go
// Package embedding provides text embedding generation with swappable models.
|
|
package embedding
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
|
|
"github.com/sugarme/tokenizer"
|
|
"github.com/sugarme/tokenizer/pretrained"
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
// EmbeddingDim is the dimension of embeddings produced by the current model.
|
|
// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
|
|
const EmbeddingDim = 384 // also supplies bgeONNXConfig.HiddenSize and the length of the zero vectors returned for empty text
|
|
|
|
// Identity of the bundled embedding model.
const (
	// BGEModelVersion is the short version tag reported by the model's Version method.
	BGEModelVersion = "bge-v1.5"
	// BGEModelName is the human-readable name reported by the model's Name method.
	BGEModelName = "bge-small-en-v1.5"
	// DefaultModelVersion is the version NewService and NewServiceWithModel("") select.
	DefaultModelVersion = BGEModelVersion
)

// Download metadata handed to models.EnsureModel when the model is created.
const (
	// BGEAssetName is the local file name of the BGE ONNX model.
	BGEAssetName = "bge-small-en-v1.5-model.onnx"
	// BGEAssetURL is the Hugging Face location of the BGE ONNX model.
	BGEAssetURL = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx"
	// BGEModelSHA256 is the expected SHA-256 (hex) of the BGE model file.
	BGEModelSHA256 = "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
)
|
|
|
|
// MaxSequenceLength is the maximum token sequence length for the model.
|
|
const MaxSequenceLength = 512 // computeBatch clips the padded sequence length to this value
|
|
|
|
// bgeONNXConfig defines the ONNX configuration for BGE models.
|
|
// BGE outputs last_hidden_state and requires mean pooling.
|
|
var bgeONNXConfig = ONNXConfig{
|
|
InputNames: []string{"input_ids", "attention_mask", "token_type_ids"},
|
|
OutputNames: []string{"last_hidden_state"},
|
|
Pooling: PoolingMean,
|
|
HiddenSize: EmbeddingDim,
|
|
}
|
|
|
|
// bgeModel is the ONNX-based embedding model implementation.
|
|
// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
|
|
type bgeModel struct {
|
|
tk *tokenizer.Tokenizer
|
|
session *ort.DynamicAdvancedSession
|
|
libDir string
|
|
config ONNXConfig
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// Compile-time check that bgeModel implements EmbeddingModel
|
|
var _ EmbeddingModel = (*bgeModel)(nil)
|
|
|
|
// Compile-time check that bgeModel implements ONNXConfigurer
|
|
var _ ONNXConfigurer = (*bgeModel)(nil)
|
|
|
|
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
|
|
// The ONNX model is downloaded at runtime from GitHub Releases and cached to disk.
|
|
func newBGEModel() (EmbeddingModel, error) {
|
|
// Extract ONNX runtime library to temp directory
|
|
libDir, err := extractONNXLibrary()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("extract ONNX library: %w", err)
|
|
}
|
|
|
|
// Set the library path
|
|
libPath := filepath.Join(libDir, onnxRuntimeLibName)
|
|
ort.SetSharedLibraryPath(libPath)
|
|
|
|
// Initialize ONNX runtime (idempotent - ignore "already initialized" since
|
|
// the ONNX environment is process-global and shared with the reranking service)
|
|
err = ort.InitializeEnvironment()
|
|
if err != nil && !strings.Contains(err.Error(), "already been initialized") {
|
|
return nil, fmt.Errorf("initialize ONNX runtime: %w", err)
|
|
}
|
|
|
|
// Load tokenizer from embedded data
|
|
tk, err := pretrained.FromReader(bytes.NewReader(tokenizerData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load tokenizer: %w", err)
|
|
}
|
|
|
|
// Download model on first use (cached to ~/.claude-mnemonic/models/)
|
|
modelPath, err := models.EnsureModel(BGEAssetName, BGEAssetURL, BGEModelSHA256)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ensure embedding model: %w", err)
|
|
}
|
|
|
|
// Create ONNX session from file path using model-specific configuration
|
|
config := bgeONNXConfig
|
|
session, err := ort.NewDynamicAdvancedSession(modelPath, config.InputNames, config.OutputNames, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create ONNX session: %w", err)
|
|
}
|
|
|
|
return &bgeModel{
|
|
tk: tk,
|
|
session: session,
|
|
libDir: libDir,
|
|
config: config,
|
|
}, nil
|
|
}
|
|
|
|
// ONNXConfig returns the model's ONNX configuration.
|
|
func (m *bgeModel) ONNXConfig() ONNXConfig {
|
|
return m.config
|
|
}
|
|
|
|
// Name returns the human-readable model name.
|
|
func (m *bgeModel) Name() string {
|
|
return BGEModelName
|
|
}
|
|
|
|
// Version returns the short version string for storage.
|
|
func (m *bgeModel) Version() string {
|
|
return BGEModelVersion
|
|
}
|
|
|
|
// Dimensions returns the embedding vector size.
|
|
func (m *bgeModel) Dimensions() int {
|
|
return EmbeddingDim
|
|
}
|
|
|
|
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
|
|
// Uses content hash to avoid re-extracting if already present.
|
|
func extractONNXLibrary() (string, error) {
|
|
// Create a hash of the library content for cache key
|
|
hash := sha256.Sum256(onnxRuntimeLib)
|
|
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes
|
|
|
|
// Create cache directory
|
|
cacheDir := filepath.Join(os.TempDir(), "claude-mnemonic-onnx", hashStr)
|
|
libPath := filepath.Join(cacheDir, onnxRuntimeLibName)
|
|
|
|
// Check if already extracted
|
|
if _, err := os.Stat(libPath); err == nil {
|
|
return cacheDir, nil
|
|
}
|
|
|
|
// Create directory
|
|
// #nosec G301 -- Cache directory needs 0755 for user access
|
|
if err := os.MkdirAll(cacheDir, 0755); err != nil {
|
|
return "", fmt.Errorf("create cache dir: %w", err)
|
|
}
|
|
|
|
// Write main library
|
|
// #nosec G306 -- Shared library needs executable permission (0755) for dynamic linker
|
|
if err := os.WriteFile(libPath, onnxRuntimeLib, 0755); err != nil {
|
|
return "", fmt.Errorf("write library: %w", err)
|
|
}
|
|
|
|
// Write providers library if present (Linux only)
|
|
if len(onnxRuntimeProvidersLib) > 0 && onnxRuntimeProvidersLibName != "" {
|
|
providersPath := filepath.Join(cacheDir, onnxRuntimeProvidersLibName)
|
|
// #nosec G306 -- Shared library needs executable permission (0755) for dynamic linker
|
|
if err := os.WriteFile(providersPath, onnxRuntimeProvidersLib, 0755); err != nil {
|
|
return "", fmt.Errorf("write providers library: %w", err)
|
|
}
|
|
}
|
|
|
|
return cacheDir, nil
|
|
}
|
|
|
|
// Embed generates an embedding for a single text.
|
|
// Returns a 384-dimensional float32 vector.
|
|
func (m *bgeModel) Embed(text string) ([]float32, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if text == "" {
|
|
return make([]float32, EmbeddingDim), nil
|
|
}
|
|
|
|
results, err := m.computeBatch([]string{text})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(results) == 0 {
|
|
return make([]float32, EmbeddingDim), nil
|
|
}
|
|
return results[0], nil
|
|
}
|
|
|
|
// EmbedBatch generates embeddings for multiple texts.
|
|
// Returns slice of 384-dimensional float32 vectors.
|
|
func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
|
|
if len(texts) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Filter out empty texts and track indices
|
|
nonEmpty := make([]string, 0, len(texts))
|
|
indices := make([]int, 0, len(texts))
|
|
for i, t := range texts {
|
|
if t != "" {
|
|
nonEmpty = append(nonEmpty, t)
|
|
indices = append(indices, i)
|
|
}
|
|
}
|
|
|
|
// If all texts are empty, return zero vectors
|
|
if len(nonEmpty) == 0 {
|
|
results := make([][]float32, len(texts))
|
|
for i := range results {
|
|
results[i] = make([]float32, EmbeddingDim)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
// Compute embeddings for non-empty texts
|
|
embeddings, err := m.computeBatch(nonEmpty)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
|
}
|
|
|
|
// Build result with zero vectors for empty texts
|
|
results := make([][]float32, len(texts))
|
|
for i := range results {
|
|
results[i] = make([]float32, EmbeddingDim)
|
|
}
|
|
for i, idx := range indices {
|
|
results[idx] = embeddings[i]
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// acquireMutex attempts to acquire the model mutex, respecting context cancellation.
|
|
// On success the caller MUST call the returned unlock function.
|
|
// If ctx is cancelled while waiting, returns ctx.Err() and no unlock is needed.
|
|
func (m *bgeModel) acquireMutex(ctx context.Context) (unlock func(), err error) {
|
|
acquired := make(chan struct{})
|
|
go func() {
|
|
m.mu.Lock()
|
|
close(acquired)
|
|
}()
|
|
|
|
select {
|
|
case <-acquired:
|
|
// Got the lock normally.
|
|
return m.mu.Unlock, nil
|
|
case <-ctx.Done():
|
|
// Context cancelled while waiting. The goroutine above will eventually
|
|
// acquire the mutex — we must ensure it gets unlocked.
|
|
go func() {
|
|
<-acquired
|
|
m.mu.Unlock()
|
|
}()
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
// EmbedWithContext generates an embedding for a single text with context-aware mutex acquisition.
|
|
// If ctx is cancelled while waiting for the model lock, returns immediately with ctx.Err().
|
|
func (m *bgeModel) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
|
if text == "" {
|
|
return make([]float32, EmbeddingDim), nil
|
|
}
|
|
|
|
unlock, err := m.acquireMutex(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
|
}
|
|
defer unlock()
|
|
|
|
results, err := m.computeBatch([]string{text})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(results) == 0 {
|
|
return make([]float32, EmbeddingDim), nil
|
|
}
|
|
return results[0], nil
|
|
}
|
|
|
|
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware mutex acquisition.
|
|
func (m *bgeModel) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
|
if len(texts) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// Filter out empty texts and track indices
|
|
nonEmpty := make([]string, 0, len(texts))
|
|
indices := make([]int, 0, len(texts))
|
|
for i, t := range texts {
|
|
if t != "" {
|
|
nonEmpty = append(nonEmpty, t)
|
|
indices = append(indices, i)
|
|
}
|
|
}
|
|
|
|
// If all texts are empty, return zero vectors
|
|
if len(nonEmpty) == 0 {
|
|
results := make([][]float32, len(texts))
|
|
for i := range results {
|
|
results[i] = make([]float32, EmbeddingDim)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
unlock, err := m.acquireMutex(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
|
}
|
|
defer unlock()
|
|
|
|
// Compute embeddings for non-empty texts
|
|
embeddings, err := m.computeBatch(nonEmpty)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
|
}
|
|
|
|
// Build result with zero vectors for empty texts
|
|
results := make([][]float32, len(texts))
|
|
for i := range results {
|
|
results[i] = make([]float32, EmbeddingDim)
|
|
}
|
|
for i, idx := range indices {
|
|
results[idx] = embeddings[i]
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
|
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
|
if len(sentences) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// Tokenize all sentences
|
|
inputBatch := make([]tokenizer.EncodeInput, len(sentences))
|
|
for i, sent := range sentences {
|
|
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
|
|
}
|
|
|
|
encodings, err := m.tk.EncodeBatch(inputBatch, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("tokenize: %w", err)
|
|
}
|
|
|
|
batchSize := len(encodings)
|
|
hiddenSize := m.config.HiddenSize
|
|
|
|
// Find max sequence length across all encodings (tokenizer may not pad uniformly)
|
|
// Also enforce MaxSequenceLength to prevent model errors
|
|
seqLength := 0
|
|
for _, enc := range encodings {
|
|
if len(enc.Ids) > seqLength {
|
|
seqLength = len(enc.Ids)
|
|
}
|
|
}
|
|
// Truncate to max model sequence length
|
|
if seqLength > MaxSequenceLength {
|
|
seqLength = MaxSequenceLength
|
|
}
|
|
|
|
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
|
|
|
// Create input tensors (pre-filled with zeros for padding)
|
|
inputIdsData := make([]int64, batchSize*seqLength)
|
|
attentionMaskData := make([]int64, batchSize*seqLength)
|
|
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
|
|
|
for b := 0; b < batchSize; b++ {
|
|
// Copy actual token data (rest remains 0 as padding)
|
|
// Truncate to seqLength to handle long inputs
|
|
copyLen := len(encodings[b].Ids)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
|
|
}
|
|
copyLen = len(encodings[b].AttentionMask)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
|
|
}
|
|
copyLen = len(encodings[b].TypeIds)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
|
|
}
|
|
}
|
|
|
|
inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
|
}
|
|
defer func() { _ = inputIdsTensor.Destroy() }()
|
|
|
|
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
|
}
|
|
defer func() { _ = attentionMaskTensor.Destroy() }()
|
|
|
|
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
|
}
|
|
defer func() { _ = tokenTypeIdsTensor.Destroy() }()
|
|
|
|
// Create output tensor based on pooling strategy
|
|
var outputShape ort.Shape
|
|
|
|
switch m.config.Pooling {
|
|
case PoolingNone:
|
|
// Direct sentence embedding output: [batch, hidden]
|
|
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
|
case PoolingMean, PoolingCLS:
|
|
// Token-level output: [batch, seq_len, hidden]
|
|
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
|
|
default:
|
|
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
|
}
|
|
|
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create output tensor: %w", err)
|
|
}
|
|
defer func() { _ = outputTensor.Destroy() }()
|
|
|
|
// Run inference
|
|
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
|
outputTensors := []ort.Value{outputTensor}
|
|
|
|
if err := m.session.Run(inputTensors, outputTensors); err != nil {
|
|
return nil, fmt.Errorf("run inference: %w", err)
|
|
}
|
|
|
|
// Extract and pool results based on strategy
|
|
flatOutput := outputTensor.GetData()
|
|
|
|
switch m.config.Pooling {
|
|
case PoolingNone:
|
|
// Direct output, no pooling needed
|
|
expectedSize := batchSize * hiddenSize
|
|
if len(flatOutput) != expectedSize {
|
|
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
|
}
|
|
results := make([][]float32, batchSize)
|
|
for i := 0; i < batchSize; i++ {
|
|
start := i * hiddenSize
|
|
end := start + hiddenSize
|
|
results[i] = make([]float32, hiddenSize)
|
|
copy(results[i], flatOutput[start:end])
|
|
}
|
|
return results, nil
|
|
|
|
case PoolingMean:
|
|
// Mean pooling over tokens (weighted by attention mask)
|
|
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
|
|
|
|
case PoolingCLS:
|
|
// CLS token pooling (first token of each sequence)
|
|
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
|
|
}
|
|
}
|
|
|
|
// meanPooling averages token embeddings per sentence, weighting each token by
// its attention-mask value.
//
// embeddings is read as a flattened [batch, seq_len, hidden] array and
// attentionMask as a flattened [batch, seq_len] array. The result has one
// vector of length hiddenSize per batch entry. Tokens whose mask value is not
// positive contribute nothing to the sum; a sentence whose mask values sum to
// zero or less is left as an all-zero vector.
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
	pooled := make([][]float32, batchSize)

	for b := range pooled {
		vec := make([]float32, hiddenSize)
		maskRow := attentionMask[b*seqLen : (b+1)*seqLen]

		var weightTotal float32
		for s, raw := range maskRow {
			weight := float32(raw)
			weightTotal += weight
			if weight <= 0 {
				continue
			}

			// Add this token's embedding, scaled by its mask value.
			tokenStart := (b*seqLen + s) * hiddenSize
			token := embeddings[tokenStart : tokenStart+hiddenSize]
			for h, value := range token {
				vec[h] += value * weight
			}
		}

		// Divide by the mask total, skipping the division when it is not positive.
		if weightTotal > 0 {
			for h := range vec {
				vec[h] /= weightTotal
			}
		}

		pooled[b] = vec
	}

	return pooled
}
|
|
|
|
// clsPooling returns, for each sentence, a copy of the embedding at token
// position 0 (the [CLS] token).
//
// embeddings is read as a flattened [batch, seq_len, hidden] array; the
// result has one vector of length hiddenSize per batch entry.
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
	// Distance between the first tokens of two consecutive sentences.
	rowStride := seqLen * hiddenSize

	out := make([][]float32, batchSize)
	for b := range out {
		first := b * rowStride
		cls := make([]float32, hiddenSize)
		copy(cls, embeddings[first:first+hiddenSize])
		out[b] = cls
	}
	return out
}
|
|
|
|
// Close releases model resources.
|
|
func (m *bgeModel) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var errs []error
|
|
|
|
if m.session != nil {
|
|
if err := m.session.Destroy(); err != nil {
|
|
errs = append(errs, fmt.Errorf("destroy session: %w", err))
|
|
}
|
|
m.session = nil
|
|
}
|
|
|
|
if err := ort.DestroyEnvironment(); err != nil {
|
|
errs = append(errs, fmt.Errorf("destroy environment: %w", err))
|
|
}
|
|
|
|
// Optionally clean up extracted library (leave for caching)
|
|
// os.RemoveAll(m.libDir)
|
|
|
|
if len(errs) > 0 {
|
|
return errs[0]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Register the BGE model with the default registry at init time
|
|
func init() {
|
|
RegisterModel(ModelMetadata{
|
|
Name: BGEModelName,
|
|
Version: BGEModelVersion,
|
|
Dimensions: EmbeddingDim,
|
|
Description: "High-quality semantic search model",
|
|
Default: true,
|
|
}, newBGEModel)
|
|
}
|
|
|
|
// Service provides thread-safe text embedding generation with model abstraction.
|
|
type Service struct {
|
|
model EmbeddingModel
|
|
}
|
|
|
|
// NewService creates a new embedding service using the default model.
|
|
func NewService() (*Service, error) {
|
|
return NewServiceWithModel(DefaultModelVersion)
|
|
}
|
|
|
|
// NewServiceWithModel creates a new embedding service using the specified model.
|
|
func NewServiceWithModel(version string) (*Service, error) {
|
|
if version == "" {
|
|
version = DefaultModelVersion
|
|
}
|
|
|
|
model, err := GetModel(version)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get model %s: %w", version, err)
|
|
}
|
|
|
|
return &Service{model: model}, nil
|
|
}
|
|
|
|
// Name returns the human-readable model name.
|
|
func (s *Service) Name() string {
|
|
return s.model.Name()
|
|
}
|
|
|
|
// Version returns the short version string for storage.
|
|
func (s *Service) Version() string {
|
|
return s.model.Version()
|
|
}
|
|
|
|
// Dimensions returns the embedding vector size.
|
|
func (s *Service) Dimensions() int {
|
|
return s.model.Dimensions()
|
|
}
|
|
|
|
// Embed generates an embedding for a single text.
|
|
func (s *Service) Embed(text string) ([]float32, error) {
|
|
return s.model.Embed(text)
|
|
}
|
|
|
|
// EmbedBatch generates embeddings for multiple texts.
|
|
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
|
return s.model.EmbedBatch(texts)
|
|
}
|
|
|
|
// EmbedWithContext generates an embedding with context-aware cancellation.
|
|
// If ctx is cancelled while waiting for the model lock, returns immediately.
|
|
func (s *Service) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
|
return s.model.EmbedWithContext(ctx, text)
|
|
}
|
|
|
|
// EmbedBatchWithContext generates embeddings with context-aware cancellation.
|
|
func (s *Service) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
|
return s.model.EmbedBatchWithContext(ctx, texts)
|
|
}
|
|
|
|
// Close releases model resources.
|
|
func (s *Service) Close() error {
|
|
return s.model.Close()
|
|
}
|