diff --git a/.gitattributes b/.gitattributes index 56fb246..d1036ef 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ -internal/embedding/assets/*.onnx filter=lfs diff=lfs merge=lfs -text +# Models are now downloaded at runtime from Hugging Face. +# LFS tracking removed. See internal/models/download.go diff --git a/.github/workflows/autoupdate.yaml b/.github/workflows/autoupdate.yaml index d6f7e3e..027afad 100644 --- a/.github/workflows/autoupdate.yaml +++ b/.github/workflows/autoupdate.yaml @@ -17,5 +17,5 @@ jobs: with: go-version: ">=1.24" release-workflow: "release.yaml" - lfs: true + lfs: false secrets: inherit diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3cbc3dd..c0b3ba9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,6 +20,6 @@ jobs: uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main with: go-version: ">=1.24" - lfs: true + lfs: false build-tags: "fts5" secrets: inherit diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 891965f..dc60278 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -25,7 +25,7 @@ jobs: node-cache-dependency-path: "ui/package-lock.json" node-output-path: "ui/dist" node-embed-path: "internal/worker/static" - lfs: true + lfs: false secrets: inherit commit-marketplace: diff --git a/.gitignore b/.gitignore index 13126d5..4802e98 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,12 @@ logs/ *.db-wal *.db-shm +# Model files (fetched from Hugging Face by scripts/download-models.sh; never committed) +# Test model data (copies staged for CLAUDE_MNEMONIC_MODEL_DIR) +testdata/models/ +internal/embedding/assets/model.onnx +internal/reranking/assets/model.onnx + # goreleaser dist/ docs/dist diff --git a/Makefile b/Makefile index 5e981ee..c2a1d8e 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ all: build setup-libs: @./scripts/download-onnx-libs.sh all +# Download ONNX models from Hugging Face (for local dev, skips if present) 
+setup-models: + @./scripts/download-models.sh + # Update version in committed plugin metadata update-version: @if [ -f .claude-plugin/plugin.json ]; then \ @@ -182,12 +186,14 @@ uninstall: stop-worker @echo "Uninstallation complete!" # Run tests (with FTS5 support) -test: setup-libs - go test $(BUILD_TAGS) -v -race ./... +# CLAUDE_MNEMONIC_MODEL_DIR points to a common model directory for tests. +# The setup-models prerequisite downloads the models from Hugging Face. +test: setup-libs setup-models + CLAUDE_MNEMONIC_MODEL_DIR=$$(pwd)/testdata/models go test $(BUILD_TAGS) -v -race ./... # Run tests with coverage (with FTS5 support) -test-coverage: setup-libs - go test $(BUILD_TAGS) -v -race -coverprofile=coverage.out ./... +test-coverage: setup-libs setup-models + CLAUDE_MNEMONIC_MODEL_DIR=$$(pwd)/testdata/models go test $(BUILD_TAGS) -v -race -coverprofile=coverage.out ./... go tool cover -html=coverage.out -o coverage.html @go tool cover -func=coverage.out | tail -1 diff --git a/internal/embedding/assets.go b/internal/embedding/assets.go index a6e56c0..dd38728 100644 --- a/internal/embedding/assets.go +++ b/internal/embedding/assets.go @@ -1,14 +1,12 @@ -// Package embedding provides text embedding generation using all-MiniLM-L6-v2. +// Package embedding provides text embedding generation using bge-small-en-v1.5. package embedding import ( _ "embed" ) -// Model and tokenizer files - embedded for all platforms +// Tokenizer file - embedded for all platforms (small, not in LFS). +// The ONNX model is downloaded at runtime to the OS user cache directory (see internal/models). 
// -//go:embed assets/model.onnx -var modelData []byte - //go:embed assets/tokenizer.json var tokenizerData []byte diff --git a/internal/embedding/assets/model.onnx b/internal/embedding/assets/model.onnx deleted file mode 100644 index 33c23a7..0000000 --- a/internal/embedding/assets/model.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35 -size 133093490 diff --git a/internal/embedding/service.go b/internal/embedding/service.go index 0fec133..66b2fa9 100644 --- a/internal/embedding/service.go +++ b/internal/embedding/service.go @@ -12,6 +12,7 @@ import ( "strings" "sync" + "github.com/lukaszraczylo/claude-mnemonic/internal/models" "github.com/sugarme/tokenizer" "github.com/sugarme/tokenizer/pretrained" ort "github.com/yalue/onnxruntime_go" @@ -29,6 +30,13 @@ const ( BGEModelName = "bge-small-en-v1.5" // DefaultModelVersion is the default model to use DefaultModelVersion = BGEModelVersion + + // BGEAssetName is the local filename for the BGE model. + BGEAssetName = "bge-small-en-v1.5-model.onnx" + // BGEAssetURL is the Hugging Face URL for the BGE ONNX model. + BGEAssetURL = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx" + // BGEModelSHA256 is the expected SHA-256 of the BGE model file. + BGEModelSHA256 = "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35" ) // MaxSequenceLength is the maximum token sequence length for the model. @@ -60,6 +68,7 @@ var _ EmbeddingModel = (*bgeModel)(nil) var _ ONNXConfigurer = (*bgeModel)(nil) // newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model. +// The ONNX model is downloaded at runtime from Hugging Face and cached to disk. 
func newBGEModel() (EmbeddingModel, error) { // Extract ONNX runtime library to temp directory libDir, err := extractONNXLibrary() @@ -84,9 +93,15 @@ func newBGEModel() (EmbeddingModel, error) { return nil, fmt.Errorf("load tokenizer: %w", err) } - // Create ONNX session using model-specific configuration + // Download model on first use (cached in the OS user cache directory, see models.ModelsDir) + modelPath, err := models.EnsureModel(BGEAssetName, BGEAssetURL, BGEModelSHA256) + if err != nil { + return nil, fmt.Errorf("ensure embedding model: %w", err) + } + + // Create ONNX session from file path using model-specific configuration config := bgeONNXConfig - session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil) + session, err := ort.NewDynamicAdvancedSession(modelPath, config.InputNames, config.OutputNames, nil) if err != nil { return nil, fmt.Errorf("create ONNX session: %w", err) } diff --git a/internal/models/download.go b/internal/models/download.go new file mode 100644 index 0000000..c3140b5 --- /dev/null +++ b/internal/models/download.go @@ -0,0 +1,259 @@ +// Package models provides runtime model download and caching. +// Models are downloaded from Hugging Face on first use and cached +// to the OS cache directory (~/Library/Caches/claude-mnemonic/models/ on macOS, +// ~/.cache/claude-mnemonic/models/ on Linux, %LocalAppData%/claude-mnemonic/models/ on Windows). +// This replaces the previous approach of embedding models in the binary +// via go:embed and tracking them with Git LFS. +package models + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/rs/zerolog/log" +) + +const ( + // ModelVersion is incremented to force re-download of all cached models. 
+	ModelVersion = "1"
+
+	maxRetries      = 3
+	retryBaseDelay  = 2 * time.Second
+	progressEvery   = 3 * time.Second
+	downloadTimeout = 15 * time.Minute
+)
+
+// ModelsDir returns the path to the model cache directory.
+// It is rooted in the OS cache dir for cross-platform compatibility:
+//
+//	macOS:   ~/Library/Caches/claude-mnemonic/models
+//	Linux:   ~/.cache/claude-mnemonic/models
+//	Windows: %LocalAppData%/claude-mnemonic/models
+//
+// If the OS cache dir cannot be determined, ~/.cache is used instead. If the
+// home directory is unknown as well, the OS temp dir is used so the cache does
+// not end up in a path relative to the current working directory.
+func ModelsDir() string {
+	base, err := os.UserCacheDir()
+	if err != nil {
+		if home, homeErr := os.UserHomeDir(); homeErr == nil {
+			base = filepath.Join(home, ".cache")
+		} else {
+			// Joining onto an empty home would yield a relative ".cache" path.
+			base = os.TempDir()
+		}
+	}
+	return filepath.Join(base, "claude-mnemonic", "models")
+}
+
+// EnsureModel ensures the model file exists and returns its path.
+//
+// Resolution order:
+//  1. CLAUDE_MNEMONIC_MODEL_DIR env var — if set, load models from this directory directly.
+//     This is intended for development and CI where models are already on disk.
+//  2. User cache directory — if the model is already cached and its checksum matches,
+//     return it immediately.
+//  3. url — download the model from the given URL with retries and cache it.
+//
+// assetName is the local filename (e.g. "bge-small-en-v1.5-model.onnx").
+// expectedSHA256 is the hex-encoded SHA-256 of the file.
+func EnsureModel(assetName, url, expectedSHA256 string) (string, error) {
+	// An explicit model directory bypasses both the cache and the network.
+	if localDir := os.Getenv("CLAUDE_MNEMONIC_MODEL_DIR"); localDir != "" {
+		localPath := filepath.Join(localDir, assetName)
+		if _, err := os.Stat(localPath); err != nil {
+			if errors.Is(err, os.ErrNotExist) {
+				return "", fmt.Errorf("local model %s not found in %s", assetName, localDir)
+			}
+			return "", fmt.Errorf("stat local model %s: %w", assetName, err)
+		}
+		valid, err := verifyChecksum(localPath, expectedSHA256)
+		if err != nil {
+			// An unreadable file is not the same thing as a wrong checksum.
+			return "", fmt.Errorf("verify local model %s: %w", assetName, err)
+		}
+		if !valid {
+			return "", fmt.Errorf("local model %s checksum mismatch", assetName)
+		}
+		return localPath, nil
+	}
+
+	dir := ModelsDir()
+	versionFile := filepath.Join(dir, ".model_version")
+	modelPath := filepath.Join(dir, assetName)
+
+	if err := os.MkdirAll(dir, 0700); err != nil {
+		return "", fmt.Errorf("create model cache dir: %w", err)
+	}
+
+	// A missing or unreadable version file is treated like an empty cache.
+	currentVersion, _ := os.ReadFile(versionFile)
+
+	if string(currentVersion) == ModelVersion {
+		if fi, err := os.Stat(modelPath); err == nil && fi.Size() > 0 {
+			valid, err := verifyChecksum(modelPath, expectedSHA256)
+			if err == nil && valid {
+				return modelPath, nil
+			}
+			log.Warn().Err(err).Str("model", assetName).Msg("Cached model corrupted, re-downloading")
+			os.Remove(modelPath) // best effort; the download replaces the file anyway
+		}
+	} else if len(currentVersion) > 0 {
+		log.Info().Str("old", string(currentVersion)).Str("new", ModelVersion).Msg("Model version updated, clearing cache")
+		if err := os.RemoveAll(dir); err != nil {
+			log.Warn().Err(err).Str("cache", dir).Msg("Failed to clear model cache")
+		}
+		if err := os.MkdirAll(dir, 0700); err != nil {
+			return "", fmt.Errorf("create model cache dir: %w", err)
+		}
+	}
+
+	if err := downloadWithRetries(url, modelPath); err != nil {
+		os.Remove(modelPath)
+		return "", fmt.Errorf("download %s: %w", assetName, err)
+	}
+
+	valid, err := verifyChecksum(modelPath, expectedSHA256)
+	if err != nil {
+		os.Remove(modelPath)
+		return "", fmt.Errorf("verify downloaded model %s: %w", assetName, err)
+	}
+	if !valid {
+		os.Remove(modelPath)
+		return "", fmt.Errorf("downloaded model %s failed checksum verification — the file at %s may have changed", assetName, url)
+	}
+
+	if err := os.WriteFile(versionFile, []byte(ModelVersion), 0600); err != nil {
+		log.Warn().Err(err).Msg("Failed to write model version file")
+	}
+
+	log.Info().Str("model", assetName).Str("cache", dir).Msg("Model ready")
+	return modelPath, nil
+}
+
+// downloadWithRetries calls downloadFile up to maxRetries times, sleeping
+// retryBaseDelay, then twice that, and so on between attempts. It returns the
+// last error wrapped with the attempt count when every attempt fails.
+func downloadWithRetries(url, dest string) error {
+	var lastErr error
+	for attempt := 0; attempt < maxRetries; attempt++ {
+		if attempt > 0 {
+			delay := retryBaseDelay * time.Duration(1<<(attempt-1))
+			log.Warn().Int("attempt", attempt+1).Int("max", maxRetries).Dur("delay", delay).Msg("Retrying download")
+			time.Sleep(delay)
+		}
+		if err := downloadFile(url, dest); err != nil {
+			lastErr = err
+			continue
+		}
+		return nil
+	}
+	return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr)
+}
+
+// downloadFile fetches url into dest. The body is streamed to dest+".tmp" and
+// renamed over dest only after it has been fully written, synced and closed,
+// so dest never holds a partial download.
+func downloadFile(url, dest string) error {
+	client := &http.Client{Timeout: downloadTimeout}
+
+	resp, err := client.Get(url)
+	if err != nil {
+		return fmt.Errorf("HTTP GET: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// The client follows redirects itself, so anything other than 200 here
+	// (including a leftover 3xx) does not carry the model body.
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("HTTP %d", resp.StatusCode)
+	}
+
+	tmp := dest + ".tmp"
+	f, err := os.Create(tmp)
+	if err != nil {
+		return fmt.Errorf("create temp file: %w", err)
+	}
+	cleanup := func() { f.Close(); os.Remove(tmp) }
+
+	reader := &progressReader{
+		reader: resp.Body,
+		total:  resp.ContentLength,
+		name:   filepath.Base(dest),
+	}
+
+	if _, err := io.Copy(f, reader); err != nil {
+		cleanup()
+		return fmt.Errorf("write file: %w", err)
+	}
+
+	if err := f.Sync(); err != nil {
+		cleanup()
+		return fmt.Errorf("sync file: %w", err)
+	}
+	// A failed close can mean the data never reached the disk.
+	if err := f.Close(); err != nil {
+		os.Remove(tmp)
+		return fmt.Errorf("close file: %w", err)
+	}
+
+	if err := os.Rename(tmp, dest); err != nil {
+		os.Remove(tmp)
+		return fmt.Errorf("commit file: %w", err)
+	}
+
+	return nil
+}
+
+// progressReader wraps an io.Reader and logs progress for large downloads.
+type progressReader struct {
+	reader    io.Reader // underlying stream being copied
+	total     int64     // expected size in bytes; <= 0 means unknown
+	current   int64     // bytes read so far
+	name      string    // label used in log lines
+	lastLog   time.Time // when progress was last considered for logging
+	loggedPct int       // last percentage that was actually logged
+}
+
+// Read forwards to the wrapped reader and counts the bytes. At most once per
+// progressEvery, and always on io.EOF, it logs the running total; when the
+// size is known, a line is only written if the rounded percentage changed.
+func (pr *progressReader) Read(p []byte) (n int, err error) {
+	n, err = pr.reader.Read(p)
+	pr.current += int64(n)
+
+	if err != io.EOF && time.Since(pr.lastLog) < progressEvery {
+		return n, err
+	}
+
+	const mib = 1024 * 1024
+	done := float64(pr.current) / mib
+
+	switch {
+	case pr.total <= 0:
+		log.Info().
+			Str("model", pr.name).
+			Str("progress", fmt.Sprintf("%.0f MiB", done)).
+			Msg("Downloading (size unknown)")
+	default:
+		pct := int(math.Round(float64(pr.current) / float64(pr.total) * 100))
+		if pct != pr.loggedPct {
+			pr.loggedPct = pct
+			all := float64(pr.total) / mib
+			log.Info().
+				Str("model", pr.name).
+				Int("pct", pct).
+				Str("progress", fmt.Sprintf("%.0f/%.0f MiB", done, all)).
+				Msg("Downloading")
+		}
+	}
+
+	pr.lastLog = time.Now()
+	return n, err
+}
+
+// verifyChecksum reports whether the SHA-256 of the file at path, rendered as
+// lowercase hex, equals expectedHex. The error is non-nil only when the file
+// cannot be opened or read.
+func verifyChecksum(path, expectedHex string) (bool, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return false, err
+	}
+	defer file.Close()
+
+	hasher := sha256.New()
+	if _, err = io.Copy(hasher, file); err != nil {
+		return false, err
+	}
+
+	return hex.EncodeToString(hasher.Sum(nil)) == expectedHex, nil
+}
+
+// CleanStale removes any temporary download files that may have been left behind
+// by a previous interrupted download.
+func CleanStale() error {
+	dir := ModelsDir()
+	entries, err := os.ReadDir(dir)
+	if err != nil {
+		// No cache directory yet means there is nothing to clean.
+		if errors.Is(err, os.ErrNotExist) {
+			return nil
+		}
+		return err
+	}
+
+	for _, e := range entries {
+		if filepath.Ext(e.Name()) != ".tmp" {
+			continue
+		}
+		info, err := e.Info()
+		if err != nil {
+			// Entry could not be inspected (e.g. it vanished); skip it.
+			continue
+		}
+		// Only files untouched for over an hour are removed; younger ones
+		// presumably belong to a download that is still running.
+		if time.Since(info.ModTime()) <= time.Hour {
+			continue
+		}
+		tmpPath := filepath.Join(dir, e.Name())
+		if err := os.Remove(tmpPath); err != nil {
+			// Report the failure instead of claiming the file was cleaned.
+			log.Warn().Err(err).Str("file", tmpPath).Msg("Failed to remove stale temp file")
+			continue
+		}
+		log.Debug().Str("file", tmpPath).Msg("Cleaned stale temp file")
+	}
+	return nil
+}
diff --git a/internal/reranking/assets.go b/internal/reranking/assets.go
index 493d182..ee40fa3 100644
--- a/internal/reranking/assets.go
+++ b/internal/reranking/assets.go
@@ -5,10 +5,8 @@ import (
 	_ "embed"
 )
 
-// Cross-encoder model and tokenizer files - embedded for all platforms
+// Tokenizer file - embedded for all platforms (small, not in LFS).
+// The ONNX model is downloaded at runtime to the OS user cache directory (see internal/models).
 //
-//go:embed assets/model.onnx
-var crossEncoderModelData []byte
-
 //go:embed assets/tokenizer.json
 var crossEncoderTokenizerData []byte
diff --git a/internal/reranking/assets/model.onnx b/internal/reranking/assets/model.onnx
deleted file mode 100644
index 7c28d15..0000000
Binary files a/internal/reranking/assets/model.onnx and /dev/null differ
diff --git a/internal/reranking/service.go b/internal/reranking/service.go
index f41494d..f6a5bb4 100644
--- a/internal/reranking/service.go
+++ b/internal/reranking/service.go
@@ -9,6 +9,7 @@ import (
 	"sort"
 	"sync"
 
+	"github.com/lukaszraczylo/claude-mnemonic/internal/models"
 	"github.com/rs/zerolog/log"
 	"github.com/sugarme/tokenizer"
 	"github.com/sugarme/tokenizer/pretrained"
@@ -26,6 +27,13 @@ const (
 	DefaultCandidateLimit = 100
 	// DefaultResultLimit is the default number of results to return after reranking
 	DefaultResultLimit = 10
+
+	// CrossEncoderAssetName is the local filename for the cross-encoder model.
+ CrossEncoderAssetName = "ms-marco-MiniLM-L6-v2-model.onnx" + // CrossEncoderAssetURL is the Hugging Face URL for the cross-encoder ONNX model. + CrossEncoderAssetURL = "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx" + // CrossEncoderModelSHA256 is the expected SHA-256 of the cross-encoder model file. + CrossEncoderModelSHA256 = "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a" ) // Candidate represents a search result candidate for reranking. @@ -91,12 +99,18 @@ func NewService(cfg Config) (*Service, error) { Stride: 0, }) + // Download model on first use (cached in the OS user cache directory, see models.ModelsDir) + modelPath, err := models.EnsureModel(CrossEncoderAssetName, CrossEncoderAssetURL, CrossEncoderModelSHA256) + if err != nil { + return nil, fmt.Errorf("ensure cross-encoder model: %w", err) + } + // Cross-encoder outputs a single logit for relevance scoring inputNames := []string{"input_ids", "attention_mask", "token_type_ids"} outputNames := []string{"logits"} - session, err := ort.NewDynamicAdvancedSessionWithONNXData( - crossEncoderModelData, + session, err := ort.NewDynamicAdvancedSession( + modelPath, inputNames, outputNames, nil, diff --git a/scripts/download-models.sh b/scripts/download-models.sh new file mode 100755 index 0000000..b2a5b86 --- /dev/null +++ b/scripts/download-models.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Download ONNX models from Hugging Face for local development and CI. +# Usage: ./scripts/download-models.sh [--force] +# +# Downloads models to internal/*/assets/ and stages copies in testdata/models/ +# for Go tests (CLAUDE_MNEMONIC_MODEL_DIR points there). 
+
+set -e
+
+ASSETS_EMB="internal/embedding/assets"
+ASSETS_RERANK="internal/reranking/assets"
+TESTDATA="testdata/models"
+FORCE_DOWNLOAD=false
+
+for arg in "$@"; do
+    if [ "$arg" = "--force" ]; then
+        FORCE_DOWNLOAD=true
+    fi
+done
+
+# sha256_of prints the hex SHA-256 of the file given as $1.
+# shasum is preferred; sha256sum is the fallback for systems without it.
+sha256_of() {
+    if command -v shasum >/dev/null 2>&1; then
+        shasum -a 256 "$1" | awk '{print $1}'
+    else
+        sha256sum "$1" | awk '{print $1}'
+    fi
+}
+
+# download_if_needed URL DEST NAME EXPECTED_SHA
+# Skips the download when DEST already matches EXPECTED_SHA (unless --force).
+# Otherwise downloads to DEST.tmp, verifies the checksum, and only then moves
+# the file into place, so DEST never holds an unverified or partial model.
+download_if_needed() {
+    local url="$1"
+    local dest="$2"
+    local name="$3"
+    local expected_sha="$4"
+    local actual_sha
+
+    if [ "$FORCE_DOWNLOAD" = false ] && [ -f "$dest" ]; then
+        actual_sha=$(sha256_of "$dest")
+        if [ "$actual_sha" = "$expected_sha" ]; then
+            echo "[skip] $name"
+            return
+        fi
+        echo "[mismatch] $name checksum mismatch, re-downloading"
+    fi
+
+    echo "[download] $name ($(basename "$url"))"
+    mkdir -p "$(dirname "$dest")"
+    if ! curl -fsSL --retry 3 --retry-delay 2 "$url" -o "$dest.tmp"; then
+        rm -f "$dest.tmp"
+        return 1
+    fi
+
+    # Verify before installing: the upstream file may have changed or been truncated.
+    actual_sha=$(sha256_of "$dest.tmp")
+    if [ "$actual_sha" != "$expected_sha" ]; then
+        rm -f "$dest.tmp"
+        echo "[error] $name checksum mismatch after download (expected $expected_sha, got $actual_sha)" >&2
+        return 1
+    fi
+
+    mv "$dest.tmp" "$dest"
+    echo "[ok] $name"
+}
+
+echo "=== Downloading models from Hugging Face ==="
+
+# BGE-small-en-v1.5 embedding model (~127 MB)
+download_if_needed \
+    "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx" \
+    "${ASSETS_EMB}/model.onnx" \
+    "embedding (bge-small-en-v1.5)" \
+    "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
+
+# MS-MARCO MiniLM-L6-v2 cross-encoder model (~91 MB)
+download_if_needed \
+    "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx" \
+    "${ASSETS_RERANK}/model.onnx" \
+    "cross-encoder (ms-marco-MiniLM-L6-v2)" \
+    "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
+
+echo ""
+echo "=== Staging models for tests ==="
+mkdir -p "${TESTDATA}"
+cp "${ASSETS_EMB}/model.onnx" "${TESTDATA}/bge-small-en-v1.5-model.onnx"
+cp "${ASSETS_RERANK}/model.onnx" "${TESTDATA}/ms-marco-MiniLM-L6-v2-model.onnx"
+echo "[ok] Test models staged to ${TESTDATA}/"
+
+echo "Done!"
diff --git a/workflow-prepare.sh b/workflow-prepare.sh index bb94269..13a303a 100755 --- a/workflow-prepare.sh +++ b/workflow-prepare.sh @@ -10,46 +10,65 @@ HOST_OS=$(uname -s | tr '[:upper:]' '[:lower:]') # Determine target platform for ONNX library download # Use TARGET_GOOS/TARGET_GOARCH from CI matrix if available, otherwise auto-detect if [ -n "$TARGET_GOOS" ] && [ -n "$TARGET_GOARCH" ]; then - ONNX_PLATFORM="${TARGET_GOOS}-${TARGET_GOARCH}" - echo "Target platform from CI matrix: $ONNX_PLATFORM" + ONNX_PLATFORM="${TARGET_GOOS}-${TARGET_GOARCH}" + echo "Target platform from CI matrix: $ONNX_PLATFORM" else - ONNX_PLATFORM="auto" + ONNX_PLATFORM="auto" fi # On Windows, install SQLite development headers for CGO if [[ "$HOST_OS" == mingw* ]] || [[ "$HOST_OS" == msys* ]] || [[ "$HOST_OS" == cygwin* ]]; then - echo "Installing SQLite for Windows..." + echo "Installing SQLite for Windows..." - # Download SQLite amalgamation and set up for CGO - SQLITE_VERSION="3470200" - SQLITE_YEAR="2024" - SQLITE_DIR="/c/sqlite" - SQLITE_URL="https://www.sqlite.org/${SQLITE_YEAR}/sqlite-amalgamation-${SQLITE_VERSION}.zip" + # Download SQLite amalgamation and set up for CGO + SQLITE_VERSION="3470200" + SQLITE_YEAR="2024" + SQLITE_DIR="/c/sqlite" + SQLITE_URL="https://www.sqlite.org/${SQLITE_YEAR}/sqlite-amalgamation-${SQLITE_VERSION}.zip" - mkdir -p "$SQLITE_DIR" - curl -sSL "$SQLITE_URL" -o /tmp/sqlite.zip - unzip -q /tmp/sqlite.zip -d /tmp/ - cp /tmp/sqlite-amalgamation-${SQLITE_VERSION}/* "$SQLITE_DIR/" - rm -rf /tmp/sqlite.zip /tmp/sqlite-amalgamation-${SQLITE_VERSION} + mkdir -p "$SQLITE_DIR" + curl -sSL "$SQLITE_URL" -o /tmp/sqlite.zip + unzip -q /tmp/sqlite.zip -d /tmp/ + cp /tmp/sqlite-amalgamation-${SQLITE_VERSION}/* "$SQLITE_DIR/" + rm -rf /tmp/sqlite.zip /tmp/sqlite-amalgamation-${SQLITE_VERSION} - # Download Go modules first so we can patch sqlite-vec - echo "Downloading Go modules..." 
- go mod download + # Download Go modules first so we can patch sqlite-vec + echo "Downloading Go modules..." + go mod download - # Find the sqlite-vec module and copy sqlite3.h there - SQLITE_VEC_PATH=$(go list -m -f '{{.Dir}}' github.com/asg017/sqlite-vec-go-bindings 2>/dev/null || true) - if [ -n "$SQLITE_VEC_PATH" ] && [ -d "$SQLITE_VEC_PATH/cgo" ]; then - # Make module writable (it's read-only by default) - chmod -R u+w "$SQLITE_VEC_PATH" - cp "$SQLITE_DIR/sqlite3.h" "$SQLITE_VEC_PATH/cgo/" - cp "$SQLITE_DIR/sqlite3.c" "$SQLITE_VEC_PATH/cgo/" - echo "SQLite headers copied to $SQLITE_VEC_PATH/cgo/" - fi + # Find the sqlite-vec module and copy sqlite3.h there + SQLITE_VEC_PATH=$(go list -m -f '{{.Dir}}' github.com/asg017/sqlite-vec-go-bindings 2>/dev/null || true) + if [ -n "$SQLITE_VEC_PATH" ] && [ -d "$SQLITE_VEC_PATH/cgo" ]; then + # Make module writable (it's read-only by default) + chmod -R u+w "$SQLITE_VEC_PATH" + cp "$SQLITE_DIR/sqlite3.h" "$SQLITE_VEC_PATH/cgo/" + cp "$SQLITE_DIR/sqlite3.c" "$SQLITE_VEC_PATH/cgo/" + echo "SQLite headers copied to $SQLITE_VEC_PATH/cgo/" + fi - # Tell linker to allow multiple definitions (both go-sqlite3 and sqlite-vec embed SQLite) - echo "CGO_LDFLAGS=-Wl,--allow-multiple-definition" >> "$GITHUB_ENV" - echo "SQLite setup complete" + # Tell linker to allow multiple definitions (both go-sqlite3 and sqlite-vec embed SQLite) + echo "CGO_LDFLAGS=-Wl,--allow-multiple-definition" >>"$GITHUB_ENV" + echo "SQLite setup complete" fi # Download ONNX runtime libraries for target platform ./scripts/download-onnx-libs.sh "$ONNX_PLATFORM" + +# Download ONNX models from Hugging Face and stage for tests +# Non-fatal: if download fails, tests will attempt runtime download themselves +set +e +./scripts/download-models.sh +DOWNLOAD_EXIT=$? 
+set -e
+
+# DOWNLOAD_EXIT holds the exit status of download-models.sh captured above.
+if [ "$DOWNLOAD_EXIT" -eq 0 ]; then
+    # Export model directory so tests can find models without network access
+    if [ -n "$GITHUB_ENV" ]; then
+        echo "CLAUDE_MNEMONIC_MODEL_DIR=$GITHUB_WORKSPACE/testdata/models" >> "$GITHUB_ENV"
+    fi
+else
+    echo "Warning: Model download failed — tests will try to download models at runtime"
+fi