mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-11 00:09:28 +00:00
refactor: replace Git LFS with runtime model download from Hugging Face
Remove ~170MB of model files from the repository (LFS + committed). Models are now downloaded at runtime from Hugging Face on first use and cached to the OS cache directory with progress reporting and retries. - Add internal/models/download.go: runtime downloader with retry, progress bar, checksums - Remove go:embed for ONNX models (keep tokenizers embedded) - Use file-based ONNX session loading instead of byte-slice - Add scripts/download-models.sh for dev/CI model setup - Update Makefile with setup-models target - Update workflow-prepare.sh to download models in CI - Set lfs: false in all CI workflows - SHA256: bge=828e14..., cross-encoder=5d3e70...
This commit is contained in:
+2
-1
@@ -1 +1,2 @@
|
||||
internal/embedding/assets/*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
# Models are now downloaded at runtime from Hugging Face.
|
||||
# LFS tracking removed. See internal/models/download.go
|
||||
|
||||
@@ -17,5 +17,5 @@ jobs:
|
||||
with:
|
||||
go-version: ">=1.24"
|
||||
release-workflow: "release.yaml"
|
||||
lfs: true
|
||||
lfs: false
|
||||
secrets: inherit
|
||||
|
||||
@@ -20,6 +20,6 @@ jobs:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: ">=1.24"
|
||||
lfs: true
|
||||
lfs: false
|
||||
build-tags: "fts5"
|
||||
secrets: inherit
|
||||
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
node-cache-dependency-path: "ui/package-lock.json"
|
||||
node-output-path: "ui/dist"
|
||||
node-embed-path: "internal/worker/static"
|
||||
lfs: true
|
||||
lfs: false
|
||||
secrets: inherit
|
||||
|
||||
commit-marketplace:
|
||||
|
||||
@@ -81,6 +81,12 @@ logs/
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
|
||||
# Model files (downloaded at runtime from Hugging Face)
|
||||
# Test model data (downloaded from Hugging Face)
|
||||
testdata/models/
|
||||
internal/embedding/assets/model.onnx
|
||||
internal/reranking/assets/model.onnx
|
||||
|
||||
# goreleaser
|
||||
dist/
|
||||
docs/dist
|
||||
|
||||
@@ -23,6 +23,10 @@ all: build
|
||||
setup-libs:
|
||||
@./scripts/download-onnx-libs.sh all
|
||||
|
||||
# Download ONNX models from Hugging Face (for local dev, skips if present)
|
||||
setup-models:
|
||||
@./scripts/download-models.sh
|
||||
|
||||
# Update version in committed plugin metadata
|
||||
update-version:
|
||||
@if [ -f .claude-plugin/plugin.json ]; then \
|
||||
@@ -182,12 +186,14 @@ uninstall: stop-worker
|
||||
@echo "Uninstallation complete!"
|
||||
|
||||
# Run tests (with FTS5 support)
|
||||
test: setup-libs
|
||||
go test $(BUILD_TAGS) -v -race ./...
|
||||
# CLAUDE_MNEMONIC_MODEL_DIR points to a common model directory for tests.
|
||||
# Run setup-models first to download models from Hugging Face.
|
||||
test: setup-libs setup-models
|
||||
CLAUDE_MNEMONIC_MODEL_DIR=$$(pwd)/testdata/models go test $(BUILD_TAGS) -v -race ./...
|
||||
|
||||
# Run tests with coverage (with FTS5 support)
|
||||
test-coverage: setup-libs
|
||||
go test $(BUILD_TAGS) -v -race -coverprofile=coverage.out ./...
|
||||
test-coverage: setup-libs setup-models
|
||||
CLAUDE_MNEMONIC_MODEL_DIR=$$(pwd)/testdata/models go test $(BUILD_TAGS) -v -race -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@go tool cover -func=coverage.out | tail -1
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
||||
// Package embedding provides text embedding generation using bge-small-en-v1.5.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Model and tokenizer files - embedded for all platforms
|
||||
// Tokenizer file - embedded for all platforms (small, not in LFS).
|
||||
// The ONNX model is downloaded at runtime to the OS cache directory (see models.ModelsDir).
|
||||
//
|
||||
//go:embed assets/model.onnx
|
||||
var modelData []byte
|
||||
|
||||
//go:embed assets/tokenizer.json
|
||||
var tokenizerData []byte
|
||||
|
||||
Binary file not shown.
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
ort "github.com/yalue/onnxruntime_go"
|
||||
@@ -29,6 +30,13 @@ const (
|
||||
BGEModelName = "bge-small-en-v1.5"
|
||||
// DefaultModelVersion is the default model to use
|
||||
DefaultModelVersion = BGEModelVersion
|
||||
|
||||
// BGEAssetName is the local filename for the BGE model.
|
||||
BGEAssetName = "bge-small-en-v1.5-model.onnx"
|
||||
// BGEAssetURL is the Hugging Face URL for the BGE ONNX model.
|
||||
BGEAssetURL = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx"
|
||||
// BGEModelSHA256 is the expected SHA-256 of the BGE model file.
|
||||
BGEModelSHA256 = "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
|
||||
)
|
||||
|
||||
// MaxSequenceLength is the maximum token sequence length for the model.
|
||||
@@ -60,6 +68,7 @@ var _ EmbeddingModel = (*bgeModel)(nil)
|
||||
var _ ONNXConfigurer = (*bgeModel)(nil)
|
||||
|
||||
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
|
||||
// The ONNX model is downloaded at runtime from Hugging Face and cached to disk.
|
||||
func newBGEModel() (EmbeddingModel, error) {
|
||||
// Extract ONNX runtime library to temp directory
|
||||
libDir, err := extractONNXLibrary()
|
||||
@@ -84,9 +93,15 @@ func newBGEModel() (EmbeddingModel, error) {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session using model-specific configuration
|
||||
// Download model on first use (cached in the OS cache directory, see models.ModelsDir)
|
||||
modelPath, err := models.EnsureModel(BGEAssetName, BGEAssetURL, BGEModelSHA256)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensure embedding model: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session from file path using model-specific configuration
|
||||
config := bgeONNXConfig
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
|
||||
session, err := ort.NewDynamicAdvancedSession(modelPath, config.InputNames, config.OutputNames, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ONNX session: %w", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
// Package models provides runtime model download and caching.
|
||||
// Models are downloaded from Hugging Face on first use and cached
|
||||
// to the OS cache directory (~/Library/Caches/claude-mnemonic/models/ on macOS,
|
||||
// ~/.cache/claude-mnemonic/models/ on Linux, %LocalAppData%/claude-mnemonic/models/ on Windows).
|
||||
// This replaces the previous approach of embedding models in the binary
|
||||
// via go:embed and tracking them with Git LFS.
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// ModelVersion is incremented to force re-download of all cached models.
|
||||
ModelVersion = "1"
|
||||
|
||||
maxRetries = 3
|
||||
retryBaseDelay = 2 * time.Second
|
||||
progressEvery = 3 * time.Second
|
||||
downloadTimeout = 15 * time.Minute
|
||||
)
|
||||
|
||||
// ModelsDir returns the path to the model cache directory.
// It is rooted at the OS cache dir for cross-platform compatibility:
//
//	macOS:   ~/Library/Caches/claude-mnemonic/models
//	Linux:   ~/.cache/claude-mnemonic/models
//	Windows: %LocalAppData%/claude-mnemonic/models
//
// When the OS cache dir cannot be determined it falls back to ~/.cache, and
// when the home directory is unknown as well it uses the system temp dir, so
// the result is never a path relative to the current working directory.
func ModelsDir() string {
	base, err := os.UserCacheDir()
	if err != nil {
		if home, homeErr := os.UserHomeDir(); homeErr == nil && home != "" {
			base = filepath.Join(home, ".cache")
		} else {
			// Neither a cache dir nor a home dir is available (e.g. a
			// stripped-down service environment); joining onto an empty
			// home would silently produce a CWD-relative cache.
			base = os.TempDir()
		}
	}
	return filepath.Join(base, "claude-mnemonic", "models")
}
|
||||
|
||||
// EnsureModel ensures the model file exists and returns its path.
|
||||
//
|
||||
// Resolution order:
|
||||
// 1. CLAUDE_MNEMONIC_MODEL_DIR env var — if set, load models from this directory directly.
|
||||
// This is intended for development and CI where models are already on disk.
|
||||
// 2. User cache directory — if the model is already cached and its checksum matches,
|
||||
// return it immediately.
|
||||
// 3. url — download the model from the given URL with retries and cache it.
|
||||
//
|
||||
// assetName is the local filename (e.g. "bge-small-en-v1.5-model.onnx").
|
||||
// expectedSHA256 is the hex-encoded SHA-256 of the file.
|
||||
func EnsureModel(assetName, url, expectedSHA256 string) (string, error) {
|
||||
if localDir := os.Getenv("CLAUDE_MNEMONIC_MODEL_DIR"); localDir != "" {
|
||||
localPath := filepath.Join(localDir, assetName)
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
if valid, err := verifyChecksum(localPath, expectedSHA256); err != nil || !valid {
|
||||
return "", fmt.Errorf("local model %s checksum mismatch", assetName)
|
||||
}
|
||||
return localPath, nil
|
||||
}
|
||||
return "", fmt.Errorf("local model %s not found in %s", assetName, localDir)
|
||||
}
|
||||
|
||||
dir := ModelsDir()
|
||||
versionFile := filepath.Join(dir, ".model_version")
|
||||
modelPath := filepath.Join(dir, assetName)
|
||||
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create model cache dir: %w", err)
|
||||
}
|
||||
|
||||
currentVersion, _ := os.ReadFile(versionFile)
|
||||
|
||||
if string(currentVersion) == ModelVersion {
|
||||
if fi, err := os.Stat(modelPath); err == nil && fi.Size() > 0 {
|
||||
if valid, _ := verifyChecksum(modelPath, expectedSHA256); valid {
|
||||
return modelPath, nil
|
||||
}
|
||||
log.Warn().Str("model", assetName).Msg("Cached model corrupted, re-downloading")
|
||||
os.Remove(modelPath)
|
||||
}
|
||||
} else if len(currentVersion) > 0 {
|
||||
log.Info().Str("old", string(currentVersion)).Str("new", ModelVersion).Msg("Model version updated, clearing cache")
|
||||
os.RemoveAll(dir)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create model cache dir: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := downloadWithRetries(url, modelPath); err != nil {
|
||||
os.Remove(modelPath)
|
||||
return "", fmt.Errorf("download %s: %w", assetName, err)
|
||||
}
|
||||
|
||||
if valid, err := verifyChecksum(modelPath, expectedSHA256); err != nil || !valid {
|
||||
os.Remove(modelPath)
|
||||
return "", fmt.Errorf("downloaded model %s failed checksum verification — the file at %s may have changed", assetName, url)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(versionFile, []byte(ModelVersion), 0600); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to write model version file")
|
||||
}
|
||||
|
||||
log.Info().Str("model", assetName).Str("cache", dir).Msg("Model ready")
|
||||
return modelPath, nil
|
||||
}
|
||||
|
||||
func downloadWithRetries(url, dest string) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
|
||||
log.Warn().Int("attempt", attempt+1).Int("max", maxRetries).Dur("delay", delay).Msg("Retrying download")
|
||||
time.Sleep(delay)
|
||||
}
|
||||
if err := downloadFile(url, dest); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func downloadFile(url, dest string) error {
|
||||
client := &http.Client{Timeout: downloadTimeout}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP GET: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmp := dest + ".tmp"
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
cleanup := func() { f.Close(); os.Remove(tmp) }
|
||||
|
||||
name := filepath.Base(dest)
|
||||
reader := &progressReader{
|
||||
reader: resp.Body,
|
||||
total: resp.ContentLength,
|
||||
name: name,
|
||||
}
|
||||
|
||||
if _, err := io.Copy(f, reader); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
|
||||
if err := f.Sync(); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("sync file: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if err := os.Rename(tmp, dest); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("commit file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// progressReader wraps an io.Reader and logs progress for large downloads.
|
||||
type progressReader struct {
|
||||
reader io.Reader
|
||||
total int64
|
||||
current int64
|
||||
name string
|
||||
lastLog time.Time
|
||||
loggedPct int
|
||||
}
|
||||
|
||||
func (pr *progressReader) Read(p []byte) (n int, err error) {
|
||||
n, err = pr.reader.Read(p)
|
||||
pr.current += int64(n)
|
||||
|
||||
if time.Since(pr.lastLog) >= progressEvery || err == io.EOF {
|
||||
if pr.total > 0 {
|
||||
pct := int(math.Round(float64(pr.current) / float64(pr.total) * 100))
|
||||
if pct != pr.loggedPct {
|
||||
pr.loggedPct = pct
|
||||
mibCurrent := float64(pr.current) / (1024 * 1024)
|
||||
mibTotal := float64(pr.total) / (1024 * 1024)
|
||||
log.Info().
|
||||
Str("model", pr.name).
|
||||
Int("pct", pct).
|
||||
Str("progress", fmt.Sprintf("%.0f/%.0f MiB", mibCurrent, mibTotal)).
|
||||
Msg("Downloading")
|
||||
}
|
||||
} else {
|
||||
mibCurrent := float64(pr.current) / (1024 * 1024)
|
||||
log.Info().
|
||||
Str("model", pr.name).
|
||||
Str("progress", fmt.Sprintf("%.0f MiB", mibCurrent)).
|
||||
Msg("Downloading (size unknown)")
|
||||
}
|
||||
pr.lastLog = time.Now()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// verifyChecksum reports whether the SHA-256 digest of the file at path equals
// expectedHex.
//
// expectedHex is decoded before the comparison, so upper- and lower-case hex
// digits are both accepted. An error is returned when expectedHex is not a
// valid hex-encoded SHA-256 digest, or when the file cannot be opened or read.
// A well-formed digest that simply differs yields (false, nil).
func verifyChecksum(path, expectedHex string) (bool, error) {
	// Validate the expectation first: it is cheap and avoids hashing a large
	// model only to compare it against something that can never match.
	want, err := hex.DecodeString(expectedHex)
	if err != nil {
		return false, fmt.Errorf("decode expected checksum: %w", err)
	}
	if len(want) != sha256.Size {
		return false, fmt.Errorf("expected checksum has %d bytes, want %d", len(want), sha256.Size)
	}

	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return false, err
	}

	return string(h.Sum(nil)) == string(want), nil
}
|
||||
|
||||
// CleanStale removes any temporary download files that may have been left behind
|
||||
// by a previous interrupted download.
|
||||
func CleanStale() error {
|
||||
dir := ModelsDir()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".tmp" {
|
||||
tmpPath := filepath.Join(dir, e.Name())
|
||||
if info, err := e.Info(); err == nil {
|
||||
if time.Since(info.ModTime()) > time.Hour {
|
||||
os.Remove(tmpPath)
|
||||
log.Debug().Str("file", tmpPath).Msg("Cleaned stale temp file")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Cross-encoder model and tokenizer files - embedded for all platforms
|
||||
// Tokenizer file - embedded for all platforms (small, not in LFS).
|
||||
// The ONNX model is downloaded at runtime to the OS cache directory (see models.ModelsDir).
|
||||
//
|
||||
//go:embed assets/model.onnx
|
||||
var crossEncoderModelData []byte
|
||||
|
||||
//go:embed assets/tokenizer.json
|
||||
var crossEncoderTokenizerData []byte
|
||||
|
||||
Binary file not shown.
@@ -9,6 +9,7 @@ import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
@@ -26,6 +27,13 @@ const (
|
||||
DefaultCandidateLimit = 100
|
||||
// DefaultResultLimit is the default number of results to return after reranking
|
||||
DefaultResultLimit = 10
|
||||
|
||||
// CrossEncoderAssetName is the local filename for the cross-encoder model.
|
||||
CrossEncoderAssetName = "ms-marco-MiniLM-L6-v2-model.onnx"
|
||||
// CrossEncoderAssetURL is the Hugging Face URL for the cross-encoder ONNX model.
|
||||
CrossEncoderAssetURL = "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx"
|
||||
// CrossEncoderModelSHA256 is the expected SHA-256 of the cross-encoder model file.
|
||||
CrossEncoderModelSHA256 = "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
|
||||
)
|
||||
|
||||
// Candidate represents a search result candidate for reranking.
|
||||
@@ -91,12 +99,18 @@ func NewService(cfg Config) (*Service, error) {
|
||||
Stride: 0,
|
||||
})
|
||||
|
||||
// Download model on first use (cached in the OS cache directory, see models.ModelsDir)
|
||||
modelPath, err := models.EnsureModel(CrossEncoderAssetName, CrossEncoderAssetURL, CrossEncoderModelSHA256)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensure cross-encoder model: %w", err)
|
||||
}
|
||||
|
||||
// Cross-encoder outputs a single logit for relevance scoring
|
||||
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
||||
outputNames := []string{"logits"}
|
||||
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(
|
||||
crossEncoderModelData,
|
||||
session, err := ort.NewDynamicAdvancedSession(
|
||||
modelPath,
|
||||
inputNames,
|
||||
outputNames,
|
||||
nil,
|
||||
|
||||
Executable
+65
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
# Download ONNX models from Hugging Face for local development and CI.
|
||||
# Usage: ./scripts/download-models.sh [--force]
|
||||
#
|
||||
# Downloads models to internal/*/assets/ and stages copies in testdata/models/
|
||||
# for Go tests (CLAUDE_MNEMONIC_MODEL_DIR points there).
|
||||
|
||||
set -e
|
||||
|
||||
ASSETS_EMB="internal/embedding/assets"
|
||||
ASSETS_RERANK="internal/reranking/assets"
|
||||
TESTDATA="testdata/models"
|
||||
FORCE_DOWNLOAD=false
|
||||
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" = "--force" ]; then
|
||||
FORCE_DOWNLOAD=true
|
||||
fi
|
||||
done
|
||||
|
||||
# download_if_needed URL DEST NAME EXPECTED_SHA
#
# Ensures DEST holds the file published at URL with SHA-256 EXPECTED_SHA.
# An existing DEST with a matching checksum is kept unless --force was given.
# Otherwise the file is fetched to DEST.part, verified against EXPECTED_SHA and
# only then moved onto DEST, so a failed or tampered download never replaces
# DEST and never leaves a partial file behind. Returns non-zero on failure.
download_if_needed() {
  local url="$1"
  local dest="$2"
  local name="$3"
  local expected_sha="$4"
  local actual_sha
  local tmp="${dest}.part"

  if [ "$FORCE_DOWNLOAD" = false ] && [ -f "$dest" ]; then
    actual_sha=$(shasum -a 256 "$dest" | awk '{print $1}')
    if [ "$actual_sha" = "$expected_sha" ]; then
      echo "[skip] $name"
      return
    fi
    echo "[mismatch] $name checksum mismatch, re-downloading"
  fi

  echo "[download] $name ($(basename "$url"))"
  mkdir -p "$(dirname "$dest")"
  if ! curl -fsSL --retry 3 --retry-delay 2 "$url" -o "$tmp"; then
    rm -f "$tmp"
    echo "[error] $name download failed" >&2
    return 1
  fi

  # Verify what was actually received before it becomes the model file.
  actual_sha=$(shasum -a 256 "$tmp" | awk '{print $1}')
  if [ "$actual_sha" != "$expected_sha" ]; then
    rm -f "$tmp"
    echo "[error] $name checksum mismatch after download (expected $expected_sha, got $actual_sha)" >&2
    return 1
  fi

  mv "$tmp" "$dest"
  echo "[ok] $name"
}
|
||||
|
||||
echo "=== Downloading models from Hugging Face ==="
|
||||
|
||||
# BGE-small-en-v1.5 embedding model (~127 MB)
|
||||
download_if_needed \
|
||||
"https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx" \
|
||||
"${ASSETS_EMB}/model.onnx" \
|
||||
"embedding (bge-small-en-v1.5)" \
|
||||
"828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
|
||||
|
||||
# MS-MARCO MiniLM-L6-v2 cross-encoder model (~91 MB)
|
||||
download_if_needed \
|
||||
"https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx" \
|
||||
"${ASSETS_RERANK}/model.onnx" \
|
||||
"cross-encoder (ms-marco-MiniLM-L6-v2)" \
|
||||
"5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
|
||||
|
||||
echo ""
|
||||
echo "=== Staging models for tests ==="
|
||||
mkdir -p "${TESTDATA}"
|
||||
cp "${ASSETS_EMB}/model.onnx" "${TESTDATA}/bge-small-en-v1.5-model.onnx"
|
||||
cp "${ASSETS_RERANK}/model.onnx" "${TESTDATA}/ms-marco-MiniLM-L6-v2-model.onnx"
|
||||
echo "[ok] Test models staged to ${TESTDATA}/"
|
||||
|
||||
echo "Done!"
|
||||
+48
-29
@@ -10,46 +10,65 @@ HOST_OS=$(uname -s | tr '[:upper:]' '[:lower:]')
|
||||
# Determine target platform for ONNX library download
|
||||
# Use TARGET_GOOS/TARGET_GOARCH from CI matrix if available, otherwise auto-detect
|
||||
if [ -n "$TARGET_GOOS" ] && [ -n "$TARGET_GOARCH" ]; then
|
||||
ONNX_PLATFORM="${TARGET_GOOS}-${TARGET_GOARCH}"
|
||||
echo "Target platform from CI matrix: $ONNX_PLATFORM"
|
||||
ONNX_PLATFORM="${TARGET_GOOS}-${TARGET_GOARCH}"
|
||||
echo "Target platform from CI matrix: $ONNX_PLATFORM"
|
||||
else
|
||||
ONNX_PLATFORM="auto"
|
||||
ONNX_PLATFORM="auto"
|
||||
fi
|
||||
|
||||
# On Windows, install SQLite development headers for CGO
|
||||
if [[ "$HOST_OS" == mingw* ]] || [[ "$HOST_OS" == msys* ]] || [[ "$HOST_OS" == cygwin* ]]; then
|
||||
echo "Installing SQLite for Windows..."
|
||||
echo "Installing SQLite for Windows..."
|
||||
|
||||
# Download SQLite amalgamation and set up for CGO
|
||||
SQLITE_VERSION="3470200"
|
||||
SQLITE_YEAR="2024"
|
||||
SQLITE_DIR="/c/sqlite"
|
||||
SQLITE_URL="https://www.sqlite.org/${SQLITE_YEAR}/sqlite-amalgamation-${SQLITE_VERSION}.zip"
|
||||
# Download SQLite amalgamation and set up for CGO
|
||||
SQLITE_VERSION="3470200"
|
||||
SQLITE_YEAR="2024"
|
||||
SQLITE_DIR="/c/sqlite"
|
||||
SQLITE_URL="https://www.sqlite.org/${SQLITE_YEAR}/sqlite-amalgamation-${SQLITE_VERSION}.zip"
|
||||
|
||||
mkdir -p "$SQLITE_DIR"
|
||||
curl -sSL "$SQLITE_URL" -o /tmp/sqlite.zip
|
||||
unzip -q /tmp/sqlite.zip -d /tmp/
|
||||
cp /tmp/sqlite-amalgamation-${SQLITE_VERSION}/* "$SQLITE_DIR/"
|
||||
rm -rf /tmp/sqlite.zip /tmp/sqlite-amalgamation-${SQLITE_VERSION}
|
||||
mkdir -p "$SQLITE_DIR"
|
||||
curl -sSL "$SQLITE_URL" -o /tmp/sqlite.zip
|
||||
unzip -q /tmp/sqlite.zip -d /tmp/
|
||||
cp /tmp/sqlite-amalgamation-${SQLITE_VERSION}/* "$SQLITE_DIR/"
|
||||
rm -rf /tmp/sqlite.zip /tmp/sqlite-amalgamation-${SQLITE_VERSION}
|
||||
|
||||
# Download Go modules first so we can patch sqlite-vec
|
||||
echo "Downloading Go modules..."
|
||||
go mod download
|
||||
# Download Go modules first so we can patch sqlite-vec
|
||||
echo "Downloading Go modules..."
|
||||
go mod download
|
||||
|
||||
# Find the sqlite-vec module and copy sqlite3.h there
|
||||
SQLITE_VEC_PATH=$(go list -m -f '{{.Dir}}' github.com/asg017/sqlite-vec-go-bindings 2>/dev/null || true)
|
||||
if [ -n "$SQLITE_VEC_PATH" ] && [ -d "$SQLITE_VEC_PATH/cgo" ]; then
|
||||
# Make module writable (it's read-only by default)
|
||||
chmod -R u+w "$SQLITE_VEC_PATH"
|
||||
cp "$SQLITE_DIR/sqlite3.h" "$SQLITE_VEC_PATH/cgo/"
|
||||
cp "$SQLITE_DIR/sqlite3.c" "$SQLITE_VEC_PATH/cgo/"
|
||||
echo "SQLite headers copied to $SQLITE_VEC_PATH/cgo/"
|
||||
fi
|
||||
# Find the sqlite-vec module and copy sqlite3.h there
|
||||
SQLITE_VEC_PATH=$(go list -m -f '{{.Dir}}' github.com/asg017/sqlite-vec-go-bindings 2>/dev/null || true)
|
||||
if [ -n "$SQLITE_VEC_PATH" ] && [ -d "$SQLITE_VEC_PATH/cgo" ]; then
|
||||
# Make module writable (it's read-only by default)
|
||||
chmod -R u+w "$SQLITE_VEC_PATH"
|
||||
cp "$SQLITE_DIR/sqlite3.h" "$SQLITE_VEC_PATH/cgo/"
|
||||
cp "$SQLITE_DIR/sqlite3.c" "$SQLITE_VEC_PATH/cgo/"
|
||||
echo "SQLite headers copied to $SQLITE_VEC_PATH/cgo/"
|
||||
fi
|
||||
|
||||
# Tell linker to allow multiple definitions (both go-sqlite3 and sqlite-vec embed SQLite)
|
||||
echo "CGO_LDFLAGS=-Wl,--allow-multiple-definition" >> "$GITHUB_ENV"
|
||||
echo "SQLite setup complete"
|
||||
# Tell linker to allow multiple definitions (both go-sqlite3 and sqlite-vec embed SQLite)
|
||||
echo "CGO_LDFLAGS=-Wl,--allow-multiple-definition" >>"$GITHUB_ENV"
|
||||
echo "SQLite setup complete"
|
||||
fi
|
||||
|
||||
# Download ONNX runtime libraries for target platform
|
||||
./scripts/download-onnx-libs.sh "$ONNX_PLATFORM"
|
||||
|
||||
# Download ONNX models from Hugging Face and stage for tests
# Non-fatal: if download fails, tests will attempt runtime download themselves.
# errexit is switched off around the call so a failing download does not abort
# this script; its exit status is inspected below instead.
set +e
./scripts/download-models.sh
DOWNLOAD_EXIT=$?
set -e

if [ "$DOWNLOAD_EXIT" -eq 0 ]; then
  # Export model directory so tests can find models without network access
  if [ -n "$GITHUB_ENV" ]; then
    echo "CLAUDE_MNEMONIC_MODEL_DIR=$GITHUB_WORKSPACE/testdata/models" >>"$GITHUB_ENV"
  fi
else
  echo "Warning: Model download failed — tests will try to download models at runtime"
fi
|
||||
|
||||
Reference in New Issue
Block a user