refactor: replace Git LFS with runtime model download from Hugging Face

Remove ~170MB of model files from the repository (LFS + committed).
Models are now downloaded at runtime from Hugging Face on first use
and cached to the OS cache directory with progress reporting and retries.

- Add internal/models/download.go: runtime downloader with retry, progress bar, checksums
- Remove go:embed for ONNX models (keep tokenizers embedded)
- Use file-based ONNX session loading instead of byte-slice
- Add scripts/download-models.sh for dev/CI model setup
- Update Makefile with setup-models target
- Update workflow-prepare.sh to download models in CI
- Set lfs: false in all CI workflows
- SHA256: bge=828e14..., cross-encoder=5d3e70...
This commit is contained in:
2026-05-26 17:52:55 +01:00
parent c8b462aaec
commit 1a4fea5c17
15 changed files with 431 additions and 53 deletions
+3 -5
View File
@@ -1,14 +1,12 @@
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
// Package embedding provides text embedding generation using bge-small-en-v1.5.
package embedding
import (
_ "embed"
)
// Model and tokenizer files - embedded for all platforms
// Tokenizer file - embedded for all platforms (small, not in LFS).
// The ONNX model is downloaded at runtime and cached in the OS cache directory (see internal/models).
//
//go:embed assets/model.onnx
var modelData []byte
//go:embed assets/tokenizer.json
var tokenizerData []byte
Binary file not shown.
+17 -2
View File
@@ -12,6 +12,7 @@ import (
"strings"
"sync"
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
ort "github.com/yalue/onnxruntime_go"
@@ -29,6 +30,13 @@ const (
BGEModelName = "bge-small-en-v1.5"
// DefaultModelVersion is the default model to use
DefaultModelVersion = BGEModelVersion
// BGEAssetName is the local filename for the BGE model.
BGEAssetName = "bge-small-en-v1.5-model.onnx"
// BGEAssetURL is the Hugging Face URL for the BGE ONNX model.
BGEAssetURL = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx"
// BGEModelSHA256 is the expected SHA-256 of the BGE model file.
BGEModelSHA256 = "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
)
// MaxSequenceLength is the maximum token sequence length for the model.
@@ -60,6 +68,7 @@ var _ EmbeddingModel = (*bgeModel)(nil)
var _ ONNXConfigurer = (*bgeModel)(nil)
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
// The ONNX model is downloaded at runtime from Hugging Face and cached to disk.
func newBGEModel() (EmbeddingModel, error) {
// Extract ONNX runtime library to temp directory
libDir, err := extractONNXLibrary()
@@ -84,9 +93,15 @@ func newBGEModel() (EmbeddingModel, error) {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
// Create ONNX session using model-specific configuration
// Download model on first use (cached in the OS cache directory, see models.ModelsDir)
modelPath, err := models.EnsureModel(BGEAssetName, BGEAssetURL, BGEModelSHA256)
if err != nil {
return nil, fmt.Errorf("ensure embedding model: %w", err)
}
// Create ONNX session from file path using model-specific configuration
config := bgeONNXConfig
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
session, err := ort.NewDynamicAdvancedSession(modelPath, config.InputNames, config.OutputNames, nil)
if err != nil {
return nil, fmt.Errorf("create ONNX session: %w", err)
}
+259
View File
@@ -0,0 +1,259 @@
// Package models provides runtime model download and caching.
// Models are downloaded from Hugging Face on first use and cached
// to the OS cache directory (~/Library/Caches/claude-mnemonic/models/ on macOS,
// ~/.cache/claude-mnemonic/models/ on Linux, %LocalAppData%/claude-mnemonic/models/ on Windows).
// This replaces the previous approach of embedding models in the binary
// via go:embed and tracking them with Git LFS.
package models
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"time"
"github.com/rs/zerolog/log"
)
// ModelVersion is the cache generation marker. Incrementing it forces every
// cached model to be discarded and downloaded again.
const ModelVersion = "1"

// Download tuning values.
const (
	// downloadTimeout bounds a single HTTP download attempt end to end.
	downloadTimeout = 15 * time.Minute
	// progressEvery is the minimum interval between progress log lines.
	progressEvery = 3 * time.Second
	// retryBaseDelay is the wait before the first retry; it doubles on each
	// subsequent retry.
	retryBaseDelay = 2 * time.Second
	// maxRetries is the total number of download attempts, including the first.
	maxRetries = 3
)
// ModelsDir returns the path to the model cache directory.
// It is rooted at the OS cache dir for cross-platform compatibility:
//
//	macOS:   ~/Library/Caches/claude-mnemonic/models
//	Linux:   ~/.cache/claude-mnemonic/models
//	Windows: %LocalAppData%/claude-mnemonic/models
//
// If the OS cache dir cannot be determined, ~/.cache is used instead. If the
// home directory cannot be determined either, the system temp dir is used so
// that the result never silently becomes a path relative to the current
// working directory.
func ModelsDir() string {
	base, err := os.UserCacheDir()
	if err != nil {
		if home, homeErr := os.UserHomeDir(); homeErr == nil && home != "" {
			base = filepath.Join(home, ".cache")
		} else {
			// No usable home: a relative ".cache" would depend on wherever the
			// process happens to be started from.
			base = os.TempDir()
		}
	}
	return filepath.Join(base, "claude-mnemonic", "models")
}
// EnsureModel ensures the model file exists locally and returns its path.
//
// Resolution order:
//  1. CLAUDE_MNEMONIC_MODEL_DIR env var — if set, the model is loaded from this
//     directory directly and nothing is downloaded. The file must exist and
//     match expectedSHA256. Intended for development and CI where models are
//     already on disk.
//  2. Model cache directory (see ModelsDir) — if the cache was written by the
//     current ModelVersion and the cached file's checksum matches, it is
//     returned immediately.
//  3. url — the model is downloaded with retries, verified, and cached.
//
// assetName is the local filename (e.g. "bge-small-en-v1.5-model.onnx").
// expectedSHA256 is the hex-encoded SHA-256 of the file.
//
// Errors caused by the filesystem or the network are wrapped with %w so that
// callers can tell them apart from a genuine checksum mismatch.
func EnsureModel(assetName, url, expectedSHA256 string) (string, error) {
	if localDir := os.Getenv("CLAUDE_MNEMONIC_MODEL_DIR"); localDir != "" {
		localPath := filepath.Join(localDir, assetName)
		if _, err := os.Stat(localPath); err != nil {
			if errors.Is(err, os.ErrNotExist) {
				return "", fmt.Errorf("local model %s not found in %s", assetName, localDir)
			}
			return "", fmt.Errorf("stat local model %s: %w", assetName, err)
		}
		valid, err := verifyChecksum(localPath, expectedSHA256)
		if err != nil {
			return "", fmt.Errorf("verify local model %s: %w", assetName, err)
		}
		if !valid {
			return "", fmt.Errorf("local model %s checksum mismatch", assetName)
		}
		return localPath, nil
	}

	dir := ModelsDir()
	versionFile := filepath.Join(dir, ".model_version")
	modelPath := filepath.Join(dir, assetName)

	if err := os.MkdirAll(dir, 0700); err != nil {
		return "", fmt.Errorf("create model cache dir: %w", err)
	}

	// A missing or unreadable version file is the normal first-run state and is
	// treated as "no cached version".
	currentVersion, _ := os.ReadFile(versionFile)
	switch {
	case string(currentVersion) == ModelVersion:
		if fi, err := os.Stat(modelPath); err == nil && fi.Size() > 0 {
			valid, err := verifyChecksum(modelPath, expectedSHA256)
			if err == nil && valid {
				return modelPath, nil
			}
			// Either the digest differs or the file could not be read; both are
			// recovered from by downloading a fresh copy.
			log.Warn().Err(err).Str("model", assetName).Msg("Cached model corrupted, re-downloading")
			_ = os.Remove(modelPath) // best effort; the download replaces it anyway
		}
	case len(currentVersion) > 0:
		log.Info().Str("old", string(currentVersion)).Str("new", ModelVersion).Msg("Model version updated, clearing cache")
		if err := os.RemoveAll(dir); err != nil {
			// Not fatal: stale files are still checksum-verified before use.
			log.Warn().Err(err).Str("cache", dir).Msg("Failed to clear model cache")
		}
		if err := os.MkdirAll(dir, 0700); err != nil {
			return "", fmt.Errorf("create model cache dir: %w", err)
		}
	}

	if err := downloadWithRetries(url, modelPath); err != nil {
		_ = os.Remove(modelPath)
		return "", fmt.Errorf("download %s: %w", assetName, err)
	}

	valid, err := verifyChecksum(modelPath, expectedSHA256)
	if err != nil {
		_ = os.Remove(modelPath)
		return "", fmt.Errorf("verify downloaded model %s: %w", assetName, err)
	}
	if !valid {
		_ = os.Remove(modelPath)
		return "", fmt.Errorf("downloaded model %s failed checksum verification — the file at %s may have changed", assetName, url)
	}

	// The model itself is usable even if the marker cannot be written; the only
	// cost is a re-download on the next run.
	if err := os.WriteFile(versionFile, []byte(ModelVersion), 0600); err != nil {
		log.Warn().Err(err).Msg("Failed to write model version file")
	}

	log.Info().Str("model", assetName).Str("cache", dir).Msg("Model ready")
	return modelPath, nil
}
func downloadWithRetries(url, dest string) error {
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
log.Warn().Int("attempt", attempt+1).Int("max", maxRetries).Dur("delay", delay).Msg("Retrying download")
time.Sleep(delay)
}
if err := downloadFile(url, dest); err != nil {
lastErr = err
continue
}
return nil
}
return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr)
}
// downloadFile performs one HTTP GET of url and stores the body at dest.
//
// The body is streamed into dest+".tmp", synced, closed, and only then renamed
// over dest, so dest never holds a partially written file. On any failure the
// temporary file is removed and dest is left untouched.
//
// Only a final 200 OK response is accepted. The client follows redirects
// itself, so any other status (including 204, 206 or an unfollowed 3xx) means
// the body is not the requested file.
func downloadFile(url, dest string) error {
	client := &http.Client{Timeout: downloadTimeout}
	resp, err := client.Get(url)
	if err != nil {
		return fmt.Errorf("HTTP GET: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	tmp := dest + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return fmt.Errorf("create temp file: %w", err)
	}
	// discard abandons the partial download.
	discard := func() {
		_ = f.Close()
		_ = os.Remove(tmp)
	}

	body := &progressReader{
		reader: resp.Body,
		total:  resp.ContentLength,
		name:   filepath.Base(dest),
	}
	if _, err := io.Copy(f, body); err != nil {
		discard()
		return fmt.Errorf("write file: %w", err)
	}
	if err := f.Sync(); err != nil {
		discard()
		return fmt.Errorf("sync file: %w", err)
	}
	// A failed close can mean the data never reached the disk, so it must not
	// be followed by the rename.
	if err := f.Close(); err != nil {
		_ = os.Remove(tmp)
		return fmt.Errorf("close file: %w", err)
	}

	if err := os.Rename(tmp, dest); err != nil {
		_ = os.Remove(tmp)
		return fmt.Errorf("commit file: %w", err)
	}
	return nil
}
// progressReader wraps an io.Reader and logs progress for large downloads.
// It counts the bytes passing through Read and emits rate-limited log lines.
type progressReader struct {
	// reader is the underlying source being read.
	reader io.Reader
	// total is the expected number of bytes; values <= 0 are treated as
	// "size unknown" and disable percentage reporting.
	total int64
	// current is the number of bytes read so far.
	current int64
	// name is the value of the "model" field in progress log lines.
	name string
	// lastLog is when progress was last evaluated for logging. The zero value
	// makes the first Read eligible immediately.
	lastLog time.Time
	// loggedPct is the most recently logged percentage, used to suppress
	// repeated lines with an unchanged value.
	loggedPct int
}
// Read implements io.Reader. It forwards to the wrapped reader, adds the bytes
// read to the running total, and reports progress when at least progressEvery
// has passed since the previous report or when the wrapped reader returns
// io.EOF. With a known total, a line is emitted only if the rounded percentage
// differs from the last one logged; with an unknown total, only the byte count
// is logged.
func (pr *progressReader) Read(p []byte) (n int, err error) {
	n, err = pr.reader.Read(p)
	pr.current += int64(n)

	due := time.Since(pr.lastLog) >= progressEvery
	if !due && err != io.EOF {
		return n, err
	}

	const mib = 1024 * 1024
	doneMiB := float64(pr.current) / mib

	if pr.total <= 0 {
		log.Info().
			Str("model", pr.name).
			Str("progress", fmt.Sprintf("%.0f MiB", doneMiB)).
			Msg("Downloading (size unknown)")
	} else if pct := int(math.Round(float64(pr.current) / float64(pr.total) * 100)); pct != pr.loggedPct {
		pr.loggedPct = pct
		totalMiB := float64(pr.total) / mib
		log.Info().
			Str("model", pr.name).
			Int("pct", pct).
			Str("progress", fmt.Sprintf("%.0f/%.0f MiB", doneMiB, totalMiB)).
			Msg("Downloading")
	}

	// The timer restarts even when the line was suppressed as a duplicate.
	pr.lastLog = time.Now()
	return n, err
}
// verifyChecksum reports whether the SHA-256 digest of the file at path equals
// expectedHex.
//
// The digests are compared by value, not by spelling, so upper- and mixed-case
// hex is accepted. An expectedHex that is not valid hex can never match and
// yields (false, nil). A non-nil error means the file could not be opened or
// read.
func verifyChecksum(path, expectedHex string) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return false, err
	}

	want, err := hex.DecodeString(expectedHex)
	if err != nil {
		// Malformed expectation: a mismatch, not an I/O failure.
		return false, nil
	}
	return string(want) == string(h.Sum(nil)), nil
}
// CleanStale removes temporary download files (entries ending in ".tmp") from
// the model cache directory that were last modified more than an hour ago,
// i.e. leftovers of a previously interrupted download.
//
// A missing cache directory is not an error. Removal is best effort: a file
// that cannot be removed is logged and skipped, and does not fail the call.
// The returned error is non-nil only when the directory cannot be listed.
func CleanStale() error {
	dir := ModelsDir()
	entries, err := os.ReadDir(dir)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}
		return err
	}

	for _, e := range entries {
		if filepath.Ext(e.Name()) != ".tmp" {
			continue
		}
		info, err := e.Info()
		if err != nil {
			// The entry vanished or is unreadable; nothing to clean.
			continue
		}
		// Young temp files may belong to a download that is still running.
		if time.Since(info.ModTime()) <= time.Hour {
			continue
		}
		tmpPath := filepath.Join(dir, e.Name())
		if err := os.Remove(tmpPath); err != nil {
			log.Debug().Err(err).Str("file", tmpPath).Msg("Failed to clean stale temp file")
			continue
		}
		log.Debug().Str("file", tmpPath).Msg("Cleaned stale temp file")
	}
	return nil
}
+2 -4
View File
@@ -5,10 +5,8 @@ import (
_ "embed"
)
// Cross-encoder model and tokenizer files - embedded for all platforms
// Tokenizer file - embedded for all platforms (small, not in LFS).
// The ONNX model is downloaded at runtime and cached in the OS cache directory (see internal/models).
//
//go:embed assets/model.onnx
var crossEncoderModelData []byte
//go:embed assets/tokenizer.json
var crossEncoderTokenizerData []byte
Binary file not shown.
+16 -2
View File
@@ -9,6 +9,7 @@ import (
"sort"
"sync"
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
"github.com/rs/zerolog/log"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
@@ -26,6 +27,13 @@ const (
DefaultCandidateLimit = 100
// DefaultResultLimit is the default number of results to return after reranking
DefaultResultLimit = 10
// CrossEncoderAssetName is the local filename for the cross-encoder model.
CrossEncoderAssetName = "ms-marco-MiniLM-L6-v2-model.onnx"
// CrossEncoderAssetURL is the Hugging Face URL for the cross-encoder ONNX model.
CrossEncoderAssetURL = "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx"
// CrossEncoderModelSHA256 is the expected SHA-256 of the cross-encoder model file.
CrossEncoderModelSHA256 = "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
)
// Candidate represents a search result candidate for reranking.
@@ -91,12 +99,18 @@ func NewService(cfg Config) (*Service, error) {
Stride: 0,
})
// Download model on first use (cached in the OS cache directory, see models.ModelsDir)
modelPath, err := models.EnsureModel(CrossEncoderAssetName, CrossEncoderAssetURL, CrossEncoderModelSHA256)
if err != nil {
return nil, fmt.Errorf("ensure cross-encoder model: %w", err)
}
// Cross-encoder outputs a single logit for relevance scoring
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
outputNames := []string{"logits"}
session, err := ort.NewDynamicAdvancedSessionWithONNXData(
crossEncoderModelData,
session, err := ort.NewDynamicAdvancedSession(
modelPath,
inputNames,
outputNames,
nil,