mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
4f4b4ac70f
- [x] Add language-specific chunkers with AST parsing (Go, Python, TypeScript) - [x] Implement chunking manager to dispatch files to appropriate chunkers - [x] Integrate code chunks into vector sync for semantic search - [x] Add tree-sitter dependency for Python/TypeScript parsing - [x] Reorder struct fields for consistency across codebase - [x] Rename error variables to follow Go conventions (err → unmarshalErr, etc.) - [x] Add code chunk metadata to vector documents (language, symbol name, line ranges) - [x] Update worker service to initialize chunking pipeline with all three languages
381 lines
10 KiB
Go
381 lines
10 KiB
Go
// Package reranking provides cross-encoder reranking for search results.
|
|
// Uses MS-MARCO MiniLM L6 v2 cross-encoder model for relevance scoring.
|
|
package reranking
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"math"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/sugarme/tokenizer"
|
|
"github.com/sugarme/tokenizer/pretrained"
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
const (
	// ModelName is the human-readable identifier of the cross-encoder model.
	ModelName = "ms-marco-MiniLM-L6-v2"

	// ModelVersion is a short tag used to identify the model version.
	ModelVersion = "msmarco-v2"

	// MaxSequenceLength caps the combined query+document length, in tokens.
	// It is used both as the tokenizer truncation limit and as the upper
	// bound on the padded sequence length fed to the model.
	MaxSequenceLength = 512

	// DefaultCandidateLimit is the default number of candidates to rerank.
	DefaultCandidateLimit = 100

	// DefaultResultLimit is the default number of results returned after
	// reranking; it is applied whenever a caller passes a non-positive limit.
	DefaultResultLimit = 10
)
|
|
|
|
// Candidate is a single search hit handed to the reranker.
type Candidate struct {
	// Metadata holds caller-defined attributes. The map is passed through
	// unchanged (by reference) to the corresponding RerankResult.
	Metadata map[string]any
	// RerankInfo holds named float values attached to the candidate.
	// NOTE(review): nothing in this file reads it — confirm its consumers.
	RerankInfo map[string]float64
	// ID identifies the candidate and is copied to the result.
	ID string
	// Content is the document text scored against the query.
	Content string
	// Score is the candidate's original (pre-rerank) score.
	Score float64
}
|
|
|
|
// RerankResult is one candidate after cross-encoder scoring and re-sorting.
type RerankResult struct {
	// Metadata is the Metadata map of the originating Candidate.
	Metadata map[string]any
	// ID is the ID of the originating Candidate.
	ID string
	// Content is the document text of the originating Candidate.
	Content string
	// OriginalScore is the candidate's score before reranking.
	OriginalScore float64
	// RerankScore is the cross-encoder logit squashed into 0-1 by a sigmoid.
	RerankScore float64
	// CombinedScore is the value results are ordered by: either a weighted
	// blend of RerankScore and OriginalScore, or RerankScore alone.
	CombinedScore float64
	// OriginalRank is the 1-based position in the input slice.
	OriginalRank int
	// RerankRank is the 1-based position after sorting.
	RerankRank int
	// RankImprovement is OriginalRank - RerankRank; positive values mean the
	// candidate moved closer to the top.
	RankImprovement int
}
|
|
|
|
// Service provides cross-encoder reranking functionality.
|
|
type Service struct {
|
|
tk *tokenizer.Tokenizer
|
|
session *ort.DynamicAdvancedSession
|
|
mu sync.Mutex
|
|
|
|
// Weight for combining scores: combined = alpha*rerank + (1-alpha)*original
|
|
// Default 0.7 favors cross-encoder score
|
|
Alpha float64
|
|
}
|
|
|
|
// Config carries the tunable settings for a reranking Service.
type Config struct {
	// Alpha is the weight used when blending scores (0.0-1.0). Values closer
	// to 1 favor the cross-encoder score; values closer to 0 favor the
	// original bi-encoder score.
	Alpha float64
}
|
|
|
|
// DefaultConfig returns sensible defaults for reranking.
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
Alpha: 0.7, // Favor cross-encoder by default
|
|
}
|
|
}
|
|
|
|
// NewService creates a new cross-encoder reranking service.
|
|
// Note: ONNX runtime must be initialized before calling this (via embedding.NewService).
|
|
func NewService(cfg Config) (*Service, error) {
|
|
// Load tokenizer from embedded data
|
|
tk, err := pretrained.FromReader(bytes.NewReader(crossEncoderTokenizerData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load cross-encoder tokenizer: %w", err)
|
|
}
|
|
|
|
// Configure tokenizer for sequence classification (pairs)
|
|
tk.WithTruncation(&tokenizer.TruncationParams{
|
|
MaxLength: MaxSequenceLength,
|
|
Strategy: tokenizer.LongestFirst,
|
|
Stride: 0,
|
|
})
|
|
|
|
// Cross-encoder outputs a single logit for relevance scoring
|
|
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
|
outputNames := []string{"logits"}
|
|
|
|
session, err := ort.NewDynamicAdvancedSessionWithONNXData(
|
|
crossEncoderModelData,
|
|
inputNames,
|
|
outputNames,
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create cross-encoder ONNX session: %w", err)
|
|
}
|
|
|
|
alpha := cfg.Alpha
|
|
if alpha <= 0 || alpha > 1 {
|
|
alpha = 0.7
|
|
}
|
|
|
|
return &Service{
|
|
tk: tk,
|
|
session: session,
|
|
Alpha: alpha,
|
|
}, nil
|
|
}
|
|
|
|
// Rerank reranks candidates using the cross-encoder model.
|
|
// Takes a query and list of candidates, returns reranked results.
|
|
func (s *Service) Rerank(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
|
|
if len(candidates) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
if limit <= 0 {
|
|
limit = DefaultResultLimit
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Score all query-document pairs
|
|
scores, err := s.scoreAll(query, candidates)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("score candidates: %w", err)
|
|
}
|
|
|
|
// Build results with combined scores
|
|
results := make([]RerankResult, len(candidates))
|
|
for i, c := range candidates {
|
|
// Normalize cross-encoder score to 0-1 range using sigmoid
|
|
normalizedRerank := sigmoid(scores[i])
|
|
|
|
results[i] = RerankResult{
|
|
ID: c.ID,
|
|
Content: c.Content,
|
|
OriginalScore: c.Score,
|
|
RerankScore: normalizedRerank,
|
|
CombinedScore: s.Alpha*normalizedRerank + (1-s.Alpha)*c.Score,
|
|
Metadata: c.Metadata,
|
|
OriginalRank: i + 1,
|
|
}
|
|
}
|
|
|
|
// Sort by combined score (descending)
|
|
sort.Slice(results, func(i, j int) bool {
|
|
return results[i].CombinedScore > results[j].CombinedScore
|
|
})
|
|
|
|
// Assign rerank positions and calculate improvement
|
|
for i := range results {
|
|
results[i].RerankRank = i + 1
|
|
results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
|
|
}
|
|
|
|
// Apply limit
|
|
if len(results) > limit {
|
|
results = results[:limit]
|
|
}
|
|
|
|
log.Debug().
|
|
Int("candidates", len(candidates)).
|
|
Int("returned", len(results)).
|
|
Float64("alpha", s.Alpha).
|
|
Msg("Cross-encoder reranking completed")
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// RerankByScore reranks candidates and returns sorted by pure cross-encoder score.
|
|
// Useful when you want to completely replace bi-encoder ranking.
|
|
func (s *Service) RerankByScore(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
|
|
if len(candidates) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
if limit <= 0 {
|
|
limit = DefaultResultLimit
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
scores, err := s.scoreAll(query, candidates)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("score candidates: %w", err)
|
|
}
|
|
|
|
results := make([]RerankResult, len(candidates))
|
|
for i, c := range candidates {
|
|
normalizedRerank := sigmoid(scores[i])
|
|
results[i] = RerankResult{
|
|
ID: c.ID,
|
|
Content: c.Content,
|
|
OriginalScore: c.Score,
|
|
RerankScore: normalizedRerank,
|
|
CombinedScore: normalizedRerank, // Use pure rerank score
|
|
Metadata: c.Metadata,
|
|
OriginalRank: i + 1,
|
|
}
|
|
}
|
|
|
|
// Sort by rerank score only
|
|
sort.Slice(results, func(i, j int) bool {
|
|
return results[i].RerankScore > results[j].RerankScore
|
|
})
|
|
|
|
for i := range results {
|
|
results[i].RerankRank = i + 1
|
|
results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
|
|
}
|
|
|
|
if len(results) > limit {
|
|
results = results[:limit]
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// scoreAll scores all query-document pairs using the cross-encoder.
|
|
// Returns raw logits (before sigmoid normalization).
|
|
func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, error) {
|
|
batchSize := len(candidates)
|
|
|
|
// Tokenize all query-document pairs
|
|
pairs := make([]tokenizer.EncodeInput, batchSize)
|
|
for i, c := range candidates {
|
|
// Cross-encoder takes query and document as a pair
|
|
pairs[i] = tokenizer.NewDualEncodeInput(
|
|
tokenizer.NewRawInputSequence(query),
|
|
tokenizer.NewRawInputSequence(c.Content),
|
|
)
|
|
}
|
|
|
|
encodings, err := s.tk.EncodeBatch(pairs, true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("tokenize pairs: %w", err)
|
|
}
|
|
|
|
// Find max sequence length
|
|
seqLength := 0
|
|
for _, enc := range encodings {
|
|
if len(enc.Ids) > seqLength {
|
|
seqLength = len(enc.Ids)
|
|
}
|
|
}
|
|
if seqLength > MaxSequenceLength {
|
|
seqLength = MaxSequenceLength
|
|
}
|
|
|
|
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
|
|
|
// Create input tensors
|
|
inputIdsData := make([]int64, batchSize*seqLength)
|
|
attentionMaskData := make([]int64, batchSize*seqLength)
|
|
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
|
|
|
for b := 0; b < batchSize; b++ {
|
|
copyLen := len(encodings[b].Ids)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
|
|
}
|
|
|
|
copyLen = len(encodings[b].AttentionMask)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
|
|
}
|
|
|
|
copyLen = len(encodings[b].TypeIds)
|
|
if copyLen > seqLength {
|
|
copyLen = seqLength
|
|
}
|
|
for i := 0; i < copyLen; i++ {
|
|
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
|
|
}
|
|
}
|
|
|
|
inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
|
}
|
|
defer func() { _ = inputIdsTensor.Destroy() }()
|
|
|
|
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
|
}
|
|
defer func() { _ = attentionMaskTensor.Destroy() }()
|
|
|
|
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
|
}
|
|
defer func() { _ = tokenTypeIdsTensor.Destroy() }()
|
|
|
|
// Cross-encoder outputs [batch, 1] logits
|
|
outputShape := ort.NewShape(int64(batchSize), 1)
|
|
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create output tensor: %w", err)
|
|
}
|
|
defer func() { _ = outputTensor.Destroy() }()
|
|
|
|
// Run inference
|
|
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
|
outputTensors := []ort.Value{outputTensor}
|
|
|
|
if err := s.session.Run(inputTensors, outputTensors); err != nil {
|
|
return nil, fmt.Errorf("run cross-encoder inference: %w", err)
|
|
}
|
|
|
|
// Extract scores
|
|
flatOutput := outputTensor.GetData()
|
|
scores := make([]float64, batchSize)
|
|
for i := 0; i < batchSize; i++ {
|
|
scores[i] = float64(flatOutput[i])
|
|
}
|
|
|
|
return scores, nil
|
|
}
|
|
|
|
// Score scores a single query-document pair.
|
|
// Returns the raw cross-encoder logit and normalized score.
|
|
func (s *Service) Score(query, document string) (rawScore, normalizedScore float64, err error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
scores, err := s.scoreAll(query, []Candidate{{Content: document}})
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
rawScore = scores[0]
|
|
normalizedScore = sigmoid(rawScore)
|
|
return rawScore, normalizedScore, nil
|
|
}
|
|
|
|
// Close releases model resources.
|
|
func (s *Service) Close() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.session != nil {
|
|
if err := s.session.Destroy(); err != nil {
|
|
return fmt.Errorf("destroy cross-encoder session: %w", err)
|
|
}
|
|
s.session = nil
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// sigmoid maps x into the 0-1 range with the logistic function
// 1 / (1 + e^-x). Inputs above 20 return exactly 1 and inputs below -20
// return exactly 0; everything else (including ±20 themselves) is computed.
func sigmoid(x float64) float64 {
	switch {
	case x > 20:
		return 1.0
	case x < -20:
		return 0.0
	default:
		return 1.0 / (1.0 + math.Exp(-x))
	}
}
|