mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
refactor: replace Git LFS with runtime model download from Hugging Face
Remove ~170MB of model files from the repository (LFS + committed). Models are now downloaded at runtime from Hugging Face on first use and cached to the OS cache directory with progress reporting and retries. - Add internal/models/download.go: runtime downloader with retry, progress bar, checksums - Remove go:embed for ONNX models (keep tokenizers embedded) - Use file-based ONNX session loading instead of byte-slice - Add scripts/download-models.sh for dev/CI model setup - Update Makefile with setup-models target - Update workflow-prepare.sh to download models in CI - Set lfs: false in all CI workflows - SHA256: bge=828e14..., cross-encoder=5d3e70...
This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
||||
// Package embedding provides text embedding generation using bge-small-en-v1.5.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Model and tokenizer files - embedded for all platforms
|
||||
// Tokenizer file - embedded for all platforms (small, not in LFS).
|
||||
// The ONNX model is downloaded at runtime and cached in the OS cache directory (see models.ModelsDir).
|
||||
//
|
||||
//go:embed assets/model.onnx
|
||||
var modelData []byte
|
||||
|
||||
//go:embed assets/tokenizer.json
|
||||
var tokenizerData []byte
|
||||
|
||||
Binary file not shown.
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
ort "github.com/yalue/onnxruntime_go"
|
||||
@@ -29,6 +30,13 @@ const (
|
||||
BGEModelName = "bge-small-en-v1.5"
|
||||
// DefaultModelVersion is the default model to use
|
||||
DefaultModelVersion = BGEModelVersion
|
||||
|
||||
// BGEAssetName is the local filename for the BGE model.
|
||||
BGEAssetName = "bge-small-en-v1.5-model.onnx"
|
||||
// BGEAssetURL is the Hugging Face URL for the BGE ONNX model.
|
||||
BGEAssetURL = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/onnx/model.onnx"
|
||||
// BGEModelSHA256 is the expected SHA-256 of the BGE model file.
|
||||
BGEModelSHA256 = "828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35"
|
||||
)
|
||||
|
||||
// MaxSequenceLength is the maximum token sequence length for the model.
|
||||
@@ -60,6 +68,7 @@ var _ EmbeddingModel = (*bgeModel)(nil)
|
||||
var _ ONNXConfigurer = (*bgeModel)(nil)
|
||||
|
||||
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
|
||||
// The ONNX model is downloaded at runtime from Hugging Face and cached to disk.
|
||||
func newBGEModel() (EmbeddingModel, error) {
|
||||
// Extract ONNX runtime library to temp directory
|
||||
libDir, err := extractONNXLibrary()
|
||||
@@ -84,9 +93,15 @@ func newBGEModel() (EmbeddingModel, error) {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session using model-specific configuration
|
||||
// Download model on first use (cached in the OS cache directory; see models.ModelsDir)
|
||||
modelPath, err := models.EnsureModel(BGEAssetName, BGEAssetURL, BGEModelSHA256)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensure embedding model: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session from file path using model-specific configuration
|
||||
config := bgeONNXConfig
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
|
||||
session, err := ort.NewDynamicAdvancedSession(modelPath, config.InputNames, config.OutputNames, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ONNX session: %w", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
// Package models provides runtime model download and caching.
|
||||
// Models are downloaded from Hugging Face on first use and cached
|
||||
// to the OS cache directory (~/Library/Caches/claude-mnemonic/models/ on macOS,
|
||||
// ~/.cache/claude-mnemonic/models/ on Linux, %LocalAppData%/claude-mnemonic/models/ on Windows).
|
||||
// This replaces the previous approach of embedding models in the binary
|
||||
// via go:embed and tracking them with Git LFS.
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ModelVersion identifies the generation of the on-disk model cache.
// Bump it to invalidate every cached model and force a fresh download.
const ModelVersion = "1"

// Download tuning parameters.
const (
	// maxRetries is the total number of download attempts made per model.
	maxRetries = 3
	// retryBaseDelay is the pause before the first retry; it doubles for
	// every retry after that.
	retryBaseDelay = 2 * time.Second
	// progressEvery is the minimum interval between progress log checks.
	progressEvery = 3 * time.Second
	// downloadTimeout is the http.Client timeout applied to one download attempt.
	downloadTimeout = 15 * time.Minute
)
|
||||
|
||||
// ModelsDir returns the path to the model cache directory.
//
// The directory lives under the per-user cache location reported by
// os.UserCacheDir:
//
//	macOS:   ~/Library/Caches/claude-mnemonic/models
//	Linux:   $XDG_CACHE_HOME/claude-mnemonic/models or ~/.cache/claude-mnemonic/models
//	Windows: %LocalAppData%/claude-mnemonic/models
//
// If the cache location cannot be determined, <home>/.cache is used instead.
// If the home directory is unknown as well, the system temp directory is used,
// so the result is never a path relative to the current working directory
// (callers create and recursively delete this directory).
//
// ModelsDir only computes the path; it does not create the directory.
func ModelsDir() string {
	base, err := os.UserCacheDir()
	if err != nil {
		home, homeErr := os.UserHomeDir()
		if homeErr != nil || home == "" {
			// Last resort: an absolute location that exists on every platform.
			base = os.TempDir()
		} else {
			base = filepath.Join(home, ".cache")
		}
	}
	return filepath.Join(base, "claude-mnemonic", "models")
}
|
||||
|
||||
// EnsureModel ensures the model file exists and returns its path.
|
||||
//
|
||||
// Resolution order:
|
||||
// 1. CLAUDE_MNEMONIC_MODEL_DIR env var — if set, load models from this directory directly.
|
||||
// This is intended for development and CI where models are already on disk.
|
||||
// 2. User cache directory — if the model is already cached and its checksum matches,
|
||||
// return it immediately.
|
||||
// 3. url — download the model from the given URL with retries and cache it.
|
||||
//
|
||||
// assetName is the local filename (e.g. "bge-small-en-v1.5-model.onnx").
|
||||
// expectedSHA256 is the hex-encoded SHA-256 of the file.
|
||||
func EnsureModel(assetName, url, expectedSHA256 string) (string, error) {
|
||||
if localDir := os.Getenv("CLAUDE_MNEMONIC_MODEL_DIR"); localDir != "" {
|
||||
localPath := filepath.Join(localDir, assetName)
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
if valid, err := verifyChecksum(localPath, expectedSHA256); err != nil || !valid {
|
||||
return "", fmt.Errorf("local model %s checksum mismatch", assetName)
|
||||
}
|
||||
return localPath, nil
|
||||
}
|
||||
return "", fmt.Errorf("local model %s not found in %s", assetName, localDir)
|
||||
}
|
||||
|
||||
dir := ModelsDir()
|
||||
versionFile := filepath.Join(dir, ".model_version")
|
||||
modelPath := filepath.Join(dir, assetName)
|
||||
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create model cache dir: %w", err)
|
||||
}
|
||||
|
||||
currentVersion, _ := os.ReadFile(versionFile)
|
||||
|
||||
if string(currentVersion) == ModelVersion {
|
||||
if fi, err := os.Stat(modelPath); err == nil && fi.Size() > 0 {
|
||||
if valid, _ := verifyChecksum(modelPath, expectedSHA256); valid {
|
||||
return modelPath, nil
|
||||
}
|
||||
log.Warn().Str("model", assetName).Msg("Cached model corrupted, re-downloading")
|
||||
os.Remove(modelPath)
|
||||
}
|
||||
} else if len(currentVersion) > 0 {
|
||||
log.Info().Str("old", string(currentVersion)).Str("new", ModelVersion).Msg("Model version updated, clearing cache")
|
||||
os.RemoveAll(dir)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create model cache dir: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := downloadWithRetries(url, modelPath); err != nil {
|
||||
os.Remove(modelPath)
|
||||
return "", fmt.Errorf("download %s: %w", assetName, err)
|
||||
}
|
||||
|
||||
if valid, err := verifyChecksum(modelPath, expectedSHA256); err != nil || !valid {
|
||||
os.Remove(modelPath)
|
||||
return "", fmt.Errorf("downloaded model %s failed checksum verification — the file at %s may have changed", assetName, url)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(versionFile, []byte(ModelVersion), 0600); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to write model version file")
|
||||
}
|
||||
|
||||
log.Info().Str("model", assetName).Str("cache", dir).Msg("Model ready")
|
||||
return modelPath, nil
|
||||
}
|
||||
|
||||
func downloadWithRetries(url, dest string) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
|
||||
log.Warn().Int("attempt", attempt+1).Int("max", maxRetries).Dur("delay", delay).Msg("Retrying download")
|
||||
time.Sleep(delay)
|
||||
}
|
||||
if err := downloadFile(url, dest); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func downloadFile(url, dest string) error {
|
||||
client := &http.Client{Timeout: downloadTimeout}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP GET: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmp := dest + ".tmp"
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
cleanup := func() { f.Close(); os.Remove(tmp) }
|
||||
|
||||
name := filepath.Base(dest)
|
||||
reader := &progressReader{
|
||||
reader: resp.Body,
|
||||
total: resp.ContentLength,
|
||||
name: name,
|
||||
}
|
||||
|
||||
if _, err := io.Copy(f, reader); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
|
||||
if err := f.Sync(); err != nil {
|
||||
cleanup()
|
||||
return fmt.Errorf("sync file: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if err := os.Rename(tmp, dest); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("commit file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// progressReader decorates an io.Reader: it counts the bytes that pass
// through it and periodically logs how far a download has progressed.
type progressReader struct {
	reader    io.Reader // underlying stream being read
	total     int64     // expected size in bytes; values <= 0 mean the size is unknown
	current   int64     // bytes read so far
	name      string    // label written to the "model" log field
	lastLog   time.Time // when the progress throttle last let a check through
	loggedPct int       // percentage reported by the most recent progress log
}
|
||||
|
||||
func (pr *progressReader) Read(p []byte) (n int, err error) {
|
||||
n, err = pr.reader.Read(p)
|
||||
pr.current += int64(n)
|
||||
|
||||
if time.Since(pr.lastLog) >= progressEvery || err == io.EOF {
|
||||
if pr.total > 0 {
|
||||
pct := int(math.Round(float64(pr.current) / float64(pr.total) * 100))
|
||||
if pct != pr.loggedPct {
|
||||
pr.loggedPct = pct
|
||||
mibCurrent := float64(pr.current) / (1024 * 1024)
|
||||
mibTotal := float64(pr.total) / (1024 * 1024)
|
||||
log.Info().
|
||||
Str("model", pr.name).
|
||||
Int("pct", pct).
|
||||
Str("progress", fmt.Sprintf("%.0f/%.0f MiB", mibCurrent, mibTotal)).
|
||||
Msg("Downloading")
|
||||
}
|
||||
} else {
|
||||
mibCurrent := float64(pr.current) / (1024 * 1024)
|
||||
log.Info().
|
||||
Str("model", pr.name).
|
||||
Str("progress", fmt.Sprintf("%.0f MiB", mibCurrent)).
|
||||
Msg("Downloading (size unknown)")
|
||||
}
|
||||
pr.lastLog = time.Now()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// verifyChecksum reports whether the SHA-256 digest of the file at path equals
// expectedHex.
//
// expectedHex is a hex-encoded digest; upper- and lower-case digits are both
// accepted, because the comparison is done on the decoded bytes.
//
// It returns (true, nil) on a match and (false, nil) on a mismatch. It returns
// (false, err) when expectedHex is not valid hex or when the file cannot be
// opened or read.
func verifyChecksum(path, expectedHex string) (bool, error) {
	want, err := hex.DecodeString(expectedHex)
	if err != nil {
		return false, fmt.Errorf("decode expected checksum: %w", err)
	}

	f, err := os.Open(path)
	if err != nil {
		return false, err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return false, err
	}

	// Comparing the raw digests makes the check independent of hex casing.
	return string(h.Sum(nil)) == string(want), nil
}
|
||||
|
||||
// CleanStale removes any temporary download files that may have been left behind
|
||||
// by a previous interrupted download.
|
||||
func CleanStale() error {
|
||||
dir := ModelsDir()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".tmp" {
|
||||
tmpPath := filepath.Join(dir, e.Name())
|
||||
if info, err := e.Info(); err == nil {
|
||||
if time.Since(info.ModTime()) > time.Hour {
|
||||
os.Remove(tmpPath)
|
||||
log.Debug().Str("file", tmpPath).Msg("Cleaned stale temp file")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Cross-encoder model and tokenizer files - embedded for all platforms
|
||||
// Tokenizer file - embedded for all platforms (small, not in LFS).
|
||||
// The ONNX model is downloaded at runtime and cached in the OS cache directory (see models.ModelsDir).
|
||||
//
|
||||
//go:embed assets/model.onnx
|
||||
var crossEncoderModelData []byte
|
||||
|
||||
//go:embed assets/tokenizer.json
|
||||
var crossEncoderTokenizerData []byte
|
||||
|
||||
Binary file not shown.
@@ -9,6 +9,7 @@ import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
@@ -26,6 +27,13 @@ const (
|
||||
DefaultCandidateLimit = 100
|
||||
// DefaultResultLimit is the default number of results to return after reranking
|
||||
DefaultResultLimit = 10
|
||||
|
||||
// CrossEncoderAssetName is the local filename for the cross-encoder model.
|
||||
CrossEncoderAssetName = "ms-marco-MiniLM-L6-v2-model.onnx"
|
||||
// CrossEncoderAssetURL is the Hugging Face URL for the cross-encoder ONNX model.
|
||||
CrossEncoderAssetURL = "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx"
|
||||
// CrossEncoderModelSHA256 is the expected SHA-256 of the cross-encoder model file.
|
||||
CrossEncoderModelSHA256 = "5d3e70fd0c9ff14b9b5169a51e957b7a9c74897afd0a35ce4bd318150c1d4d4a"
|
||||
)
|
||||
|
||||
// Candidate represents a search result candidate for reranking.
|
||||
@@ -91,12 +99,18 @@ func NewService(cfg Config) (*Service, error) {
|
||||
Stride: 0,
|
||||
})
|
||||
|
||||
// Download model on first use (cached in the OS cache directory; see models.ModelsDir)
|
||||
modelPath, err := models.EnsureModel(CrossEncoderAssetName, CrossEncoderAssetURL, CrossEncoderModelSHA256)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensure cross-encoder model: %w", err)
|
||||
}
|
||||
|
||||
// Cross-encoder outputs a single logit for relevance scoring
|
||||
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
||||
outputNames := []string{"logits"}
|
||||
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(
|
||||
crossEncoderModelData,
|
||||
session, err := ort.NewDynamicAdvancedSession(
|
||||
modelPath,
|
||||
inputNames,
|
||||
outputNames,
|
||||
nil,
|
||||
|
||||
Reference in New Issue
Block a user