Files
claude-mnemonic/internal/reranking/service.go
lukaszraczylo 4f4b4ac70f feat(chunking): add AST-aware code chunking for Go, Python, TypeScript
- [x] Add language-specific chunkers with AST parsing (Go, Python, TypeScript)
- [x] Implement chunking manager to dispatch files to appropriate chunkers
- [x] Integrate code chunks into vector sync for semantic search
- [x] Add tree-sitter dependency for Python/TypeScript parsing
- [x] Reorder struct fields for consistency across codebase
- [x] Rename error variables to follow Go conventions (err → unmarshalErr, etc.)
- [x] Add code chunk metadata to vector documents (language, symbol name, line ranges)
- [x] Update worker service to initialize chunking pipeline with all three languages
2026-01-07 13:19:58 +00:00

381 lines
10 KiB
Go

// Package reranking provides cross-encoder reranking for search results.
// Uses MS-MARCO MiniLM L6 v2 cross-encoder model for relevance scoring.
package reranking
import (
"bytes"
"fmt"
"math"
"sort"
"sync"
"github.com/rs/zerolog/log"
"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/pretrained"
ort "github.com/yalue/onnxruntime_go"
)
// Model identity and sizing defaults for the cross-encoder reranker.
const (
// ModelName is the human-readable name of the cross-encoder model.
ModelName = "ms-marco-MiniLM-L6-v2"
// ModelVersion is the short version string used to identify the model.
ModelVersion = "msmarco-v2"
// MaxSequenceLength is the maximum combined query+document token length.
// NewService passes it to the tokenizer as the truncation limit, and
// scoreAll caps the input tensor width at this value.
MaxSequenceLength = 512
// DefaultCandidateLimit is the default number of candidates to rerank.
// NOTE(review): not referenced in this file; presumably consumed by callers.
DefaultCandidateLimit = 100
// DefaultResultLimit is the number of results returned after reranking
// when the caller passes a limit <= 0.
DefaultResultLimit = 10
)
// Candidate represents a search result candidate for reranking.
type Candidate struct {
// Metadata is copied unchanged into the resulting RerankResult.
Metadata map[string]any
// RerankInfo is not read by the reranking code in this file.
// NOTE(review): presumably filled in by callers with per-stage scores — confirm.
RerankInfo map[string]float64
// ID identifies the candidate and is copied unchanged into the result.
ID string
// Content is the document text that is paired with the query and scored
// by the cross-encoder.
Content string
// Score is the candidate's incoming score. It becomes OriginalScore in the
// result and is blended with the rerank score by Service.Rerank.
// NOTE(review): presumably already in the 0-1 range so it is comparable
// with the sigmoid-normalized rerank score — confirm against callers.
Score float64
}
// RerankResult represents a reranked search result.
type RerankResult struct {
// Metadata is the candidate's Metadata, passed through unchanged.
Metadata map[string]any
// ID is the candidate's ID, passed through unchanged.
ID string
// Content is the candidate's Content, passed through unchanged.
Content string
// OriginalScore is the candidate's incoming Score.
OriginalScore float64
// RerankScore is the cross-encoder logit after sigmoid normalization.
RerankScore float64
// CombinedScore is the value the results are ordered by:
// Alpha*RerankScore + (1-Alpha)*OriginalScore for Rerank, or just
// RerankScore for RerankByScore.
CombinedScore float64
// OriginalRank is the candidate's 1-based position in the input slice.
OriginalRank int
// RerankRank is the 1-based position after sorting, assigned before the
// result limit is applied.
RerankRank int
// RankImprovement is OriginalRank - RerankRank; positive values mean the
// candidate moved toward the top.
RankImprovement int
}
// Service provides cross-encoder reranking functionality.
// Create it with NewService and release it with Close.
type Service struct {
// tk tokenizes query/document pairs for the model.
tk *tokenizer.Tokenizer
// session is the ONNX session used for inference; Close sets it to nil.
session *ort.DynamicAdvancedSession
// mu is held by Rerank, RerankByScore, Score and Close for their whole
// duration, so tokenization and inference run one call at a time.
mu sync.Mutex
// Weight for combining scores: combined = alpha*rerank + (1-alpha)*original
// Default 0.7 favors cross-encoder score
Alpha float64
}
// Config holds configuration for the reranking service.
type Config struct {
// Alpha is the weight for combining scores (0.0-1.0)
// Higher values favor cross-encoder scores, lower values favor bi-encoder scores
// NewService replaces values <= 0 or > 1 with 0.7, so the zero value of
// Config selects the default and an Alpha of exactly 0 cannot be configured.
Alpha float64
}
// DefaultConfig returns the configuration to use when the caller has no
// preference: Alpha is 0.7, which weights the cross-encoder score above the
// candidate's original score.
func DefaultConfig() Config {
	var cfg Config
	cfg.Alpha = 0.7
	return cfg
}
// NewService creates a new cross-encoder reranking service.
//
// It loads the tokenizer and the ONNX model from the data embedded in this
// package and configures the tokenizer to truncate query/document pairs to
// MaxSequenceLength tokens.
//
// cfg.Alpha must lie in (0, 1]; any other value, including NaN and the zero
// value, is replaced by DefaultConfig().Alpha.
//
// Note: ONNX runtime must be initialized before calling this (via embedding.NewService).
//
// It returns an error if the tokenizer cannot be loaded or the ONNX session
// cannot be created.
func NewService(cfg Config) (*Service, error) {
	// Load tokenizer from embedded data.
	tk, err := pretrained.FromReader(bytes.NewReader(crossEncoderTokenizerData))
	if err != nil {
		return nil, fmt.Errorf("load cross-encoder tokenizer: %w", err)
	}

	// Configure tokenizer for sequence classification (pairs).
	tk.WithTruncation(&tokenizer.TruncationParams{
		MaxLength: MaxSequenceLength,
		Strategy:  tokenizer.LongestFirst,
		Stride:    0,
	})

	// Cross-encoder outputs a single logit for relevance scoring.
	inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
	outputNames := []string{"logits"}
	session, err := ort.NewDynamicAdvancedSessionWithONNXData(
		crossEncoderModelData,
		inputNames,
		outputNames,
		nil,
	)
	if err != nil {
		return nil, fmt.Errorf("create cross-encoder ONNX session: %w", err)
	}

	// NaN compares false against everything, so it has to be rejected
	// explicitly; otherwise every combined score would become NaN.
	alpha := cfg.Alpha
	if math.IsNaN(alpha) || alpha <= 0 || alpha > 1 {
		alpha = DefaultConfig().Alpha
	}

	return &Service{
		tk:      tk,
		session: session,
		Alpha:   alpha,
	}, nil
}
// Rerank reranks candidates using the cross-encoder model.
//
// Every candidate's Content is scored against query. The raw logit is
// normalized with sigmoid and blended with the candidate's incoming score:
//
//	combined = Alpha*rerank + (1-Alpha)*original
//
// Results are ordered by combined score, highest first. Candidates with equal
// combined scores keep their input order, so the original ranking breaks ties.
// OriginalRank and RerankRank are 1-based; RerankRank is assigned before the
// limit is applied.
//
// A limit <= 0 falls back to DefaultResultLimit. An empty candidate list
// yields (nil, nil). Scoring failures are returned wrapped.
func (s *Service) Rerank(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Score all query-document pairs.
	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	// Build results with combined scores.
	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		// Normalize cross-encoder score to 0-1 range using sigmoid.
		rerank := sigmoid(scores[i])
		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   rerank,
			CombinedScore: s.Alpha*rerank + (1-s.Alpha)*c.Score,
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Stable sort (descending): sort.Slice may reorder equal elements, which
	// would make the ranking of tied candidates arbitrary.
	sort.SliceStable(results, func(a, b int) bool {
		return results[a].CombinedScore > results[b].CombinedScore
	})

	// Assign rerank positions and calculate improvement.
	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	// Apply limit.
	if len(results) > limit {
		results = results[:limit]
	}

	log.Debug().
		Int("candidates", len(candidates)).
		Int("returned", len(results)).
		Float64("alpha", s.Alpha).
		Msg("Cross-encoder reranking completed")

	return results, nil
}
// RerankByScore reranks candidates and returns them sorted by pure
// cross-encoder score. Useful when you want to completely replace the
// bi-encoder ranking.
//
// RerankScore and CombinedScore are both the sigmoid-normalized logit; the
// candidate's incoming score is only reported as OriginalScore. Candidates
// with equal rerank scores keep their input order. OriginalRank and
// RerankRank are 1-based; RerankRank is assigned before the limit is applied.
//
// A limit <= 0 falls back to DefaultResultLimit. An empty candidate list
// yields (nil, nil). Scoring failures are returned wrapped.
func (s *Service) RerankByScore(query string, candidates []Candidate, limit int) ([]RerankResult, error) {
	if len(candidates) == 0 {
		return nil, nil
	}
	if limit <= 0 {
		limit = DefaultResultLimit
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	scores, err := s.scoreAll(query, candidates)
	if err != nil {
		return nil, fmt.Errorf("score candidates: %w", err)
	}

	results := make([]RerankResult, len(candidates))
	for i, c := range candidates {
		rerank := sigmoid(scores[i])
		results[i] = RerankResult{
			ID:            c.ID,
			Content:       c.Content,
			OriginalScore: c.Score,
			RerankScore:   rerank,
			// Use pure rerank score.
			CombinedScore: rerank,
			Metadata:      c.Metadata,
			OriginalRank:  i + 1,
		}
	}

	// Stable sort by rerank score only (descending). sigmoid saturates to
	// exactly 0 or 1 for extreme logits, so ties are possible; a stable sort
	// keeps tied candidates in their input order.
	sort.SliceStable(results, func(a, b int) bool {
		return results[a].RerankScore > results[b].RerankScore
	})

	for i := range results {
		results[i].RerankRank = i + 1
		results[i].RankImprovement = results[i].OriginalRank - results[i].RerankRank
	}

	if len(results) > limit {
		results = results[:limit]
	}

	return results, nil
}
// scoreAll scores all query-document pairs using the cross-encoder.
//
// All candidates are tokenized as (query, Content) pairs and run through the
// model as a single batch. Rows shorter than the longest encoding are left
// zero-filled in input_ids, attention_mask and token_type_ids.
//
// It returns one raw logit per candidate (before sigmoid normalization), in
// input order. Every caller in this file holds s.mu while calling it.
//
// An error is returned if the service has been closed, or if tokenization,
// tensor creation or inference fails.
func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, error) {
	// Close sets session to nil; report that instead of panicking on a nil
	// session further down.
	if s.session == nil {
		return nil, fmt.Errorf("cross-encoder session is closed")
	}

	batchSize := len(candidates)

	// Tokenize all query-document pairs; the cross-encoder takes query and
	// document together as one input.
	pairs := make([]tokenizer.EncodeInput, batchSize)
	for i := range candidates {
		pairs[i] = tokenizer.NewDualEncodeInput(
			tokenizer.NewRawInputSequence(query),
			tokenizer.NewRawInputSequence(candidates[i].Content),
		)
	}

	encodings, err := s.tk.EncodeBatch(pairs, true)
	if err != nil {
		return nil, fmt.Errorf("tokenize pairs: %w", err)
	}

	// The tensor width is the longest encoding, capped at MaxSequenceLength.
	seqLength := 0
	for _, enc := range encodings {
		if len(enc.Ids) > seqLength {
			seqLength = len(enc.Ids)
		}
	}
	if seqLength > MaxSequenceLength {
		seqLength = MaxSequenceLength
	}

	inputShape := ort.NewShape(int64(batchSize), int64(seqLength))

	// Flat row-major buffers; positions past an encoding's length stay zero.
	inputIdsData := make([]int64, batchSize*seqLength)
	attentionMaskData := make([]int64, batchSize*seqLength)
	tokenTypeIdsData := make([]int64, batchSize*seqLength)

	for b := 0; b < batchSize; b++ {
		enc := encodings[b]
		row := b * seqLength

		ids := enc.Ids
		if len(ids) > seqLength {
			ids = ids[:seqLength]
		}
		for i, v := range ids {
			inputIdsData[row+i] = int64(v)
		}

		mask := enc.AttentionMask
		if len(mask) > seqLength {
			mask = mask[:seqLength]
		}
		for i, v := range mask {
			attentionMaskData[row+i] = int64(v)
		}

		typeIDs := enc.TypeIds
		if len(typeIDs) > seqLength {
			typeIDs = typeIDs[:seqLength]
		}
		for i, v := range typeIDs {
			tokenTypeIdsData[row+i] = int64(v)
		}
	}

	// Destroy errors are ignored below: cleanup is best-effort and must not
	// mask the scoring result.
	inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
	if err != nil {
		return nil, fmt.Errorf("create input_ids tensor: %w", err)
	}
	defer func() { _ = inputIdsTensor.Destroy() }()

	attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
	if err != nil {
		return nil, fmt.Errorf("create attention_mask tensor: %w", err)
	}
	defer func() { _ = attentionMaskTensor.Destroy() }()

	tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
	if err != nil {
		return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
	}
	defer func() { _ = tokenTypeIdsTensor.Destroy() }()

	// Cross-encoder outputs [batch, 1] logits.
	outputShape := ort.NewShape(int64(batchSize), 1)
	outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
	if err != nil {
		return nil, fmt.Errorf("create output tensor: %w", err)
	}
	defer func() { _ = outputTensor.Destroy() }()

	// Run inference.
	inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
	outputTensors := []ort.Value{outputTensor}
	if err := s.session.Run(inputTensors, outputTensors); err != nil {
		return nil, fmt.Errorf("run cross-encoder inference: %w", err)
	}

	// Extract scores, widening float32 logits to float64.
	flatOutput := outputTensor.GetData()
	scores := make([]float64, batchSize)
	for i := 0; i < batchSize; i++ {
		scores[i] = float64(flatOutput[i])
	}

	return scores, nil
}
// Score evaluates one query-document pair with the cross-encoder.
//
// rawScore is the model's logit and normalizedScore is that logit passed
// through sigmoid. On failure both scores are 0 and the scoring error is
// returned unwrapped.
func (s *Service) Score(query, document string) (rawScore, normalizedScore float64, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Reuse the batch path with a single-element batch.
	single := []Candidate{{Content: document}}
	logits, scoreErr := s.scoreAll(query, single)
	if scoreErr != nil {
		return 0, 0, scoreErr
	}

	logit := logits[0]
	return logit, sigmoid(logit), nil
}
// Close destroys the ONNX session held by the service.
//
// Calling Close again after a successful close is a no-op. If destroying the
// session fails, the error is returned wrapped and the session is kept.
func (s *Service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Already closed (or never opened): nothing to release.
	if s.session == nil {
		return nil
	}

	if err := s.session.Destroy(); err != nil {
		return fmt.Errorf("destroy cross-encoder session: %w", err)
	}
	s.session = nil
	return nil
}
// sigmoid maps a raw score onto the 0-1 range with the logistic function
// 1 / (1 + e^-x). Inputs above 20 return exactly 1 and inputs below -20
// return exactly 0.
func sigmoid(x float64) float64 {
	switch {
	case x > 20:
		return 1.0
	case x < -20:
		return 0.0
	default:
		return 1.0 / (1.0 + math.Exp(-x))
	}
}