Skip to content

Commit f4d9d0f

Browse files
authored
Merge pull request #2043 from dgageot/board/you-re-the-performance-expert-you-give-a-411a43fe
perf: optimize BM25 scoring strategy
2 parents d9f5036 + 690d138 commit f4d9d0f

File tree

1 file changed

+60
-54
lines changed

1 file changed

+60
-54
lines changed

pkg/rag/strategy/bm25.go

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"math"
1010
"os"
1111
"path/filepath"
12+
"slices"
1213
"strings"
1314
"sync"
1415
"time"
@@ -99,6 +100,10 @@ type BM25Strategy struct {
99100
b float64 // length normalization parameter (typically 0.75)
100101
avgDocLength float64 // average document length
101102
docCount int // total number of documents
103+
104+
// Tokenization helpers (built once per strategy instance)
105+
replacer *strings.Replacer
106+
stopwords map[string]bool
102107
}
103108

104109
// newBM25Strategy creates a new BM25-based retrieval strategy
@@ -120,6 +125,18 @@ func newBM25Strategy(name string, db *bm25DB, events chan<- types.Event, k1, b f
120125
shouldIgnore: shouldIgnore,
121126
k1: k1,
122127
b: b,
128+
replacer: strings.NewReplacer(
129+
".", " ", ",", " ", "!", " ", "?", " ",
130+
";", " ", ":", " ", "(", " ", ")", " ",
131+
"[", " ", "]", " ", "{", " ", "}", " ",
132+
"\"", " ", "'", " ", "\n", " ", "\t", " ",
133+
),
134+
stopwords: map[string]bool{
135+
"the": true, "a": true, "an": true, "and": true, "or": true,
136+
"but": true, "in": true, "on": true, "at": true, "to": true,
137+
"for": true, "of": true, "as": true, "by": true, "is": true,
138+
"was": true, "are": true, "were": true, "be": true, "been": true,
139+
},
123140
}
124141
}
125142

@@ -247,11 +264,7 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
247264
return nil, errors.New("query contains no valid terms")
248265
}
249266

250-
// For BM25, we need to retrieve all documents and score them
251-
// In a production system, you'd use an inverted index for efficiency
252-
// For now, this is a simplified implementation
253-
254-
// Get all documents (in production, use inverted index to get only relevant docs)
267+
// Get all documents
255268
allDocs, err := s.getAllDocuments(ctx)
256269
if err != nil {
257270
return nil, fmt.Errorf("failed to retrieve documents: %w", err)
@@ -261,10 +274,33 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
261274
return []database.SearchResult{}, nil
262275
}
263276

264-
// Score each document using BM25
277+
// Pre-tokenize all documents once: build term frequency maps and lengths.
278+
docTermFreqs := make([]map[string]int, len(allDocs))
279+
docLengths := make([]float64, len(allDocs))
280+
for i, doc := range allDocs {
281+
tokens := s.tokenize(doc.Content)
282+
tf := make(map[string]int, len(tokens))
283+
for _, term := range tokens {
284+
tf[term]++
285+
}
286+
docTermFreqs[i] = tf
287+
docLengths[i] = float64(len(tokens))
288+
}
289+
290+
// Pre-compute document frequency for each query term.
291+
df := make(map[string]int, len(queryTerms))
292+
for _, term := range queryTerms {
293+
for _, tf := range docTermFreqs {
294+
if tf[term] > 0 {
295+
df[term]++
296+
}
297+
}
298+
}
299+
300+
// Score each document.
265301
scores := make([]database.SearchResult, 0, len(allDocs))
266-
for _, doc := range allDocs {
267-
score := s.calculateBM25Score(queryTerms, doc, allDocs)
302+
for i, doc := range allDocs {
303+
score := s.calculateBM25Score(queryTerms, docTermFreqs[i], docLengths[i], df)
268304
if score >= threshold {
269305
scores = append(scores, database.SearchResult{
270306
Document: doc,
@@ -273,14 +309,10 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int,
273309
}
274310
}
275311

276-
// Sort by score descending
277-
for i := 0; i < len(scores); i++ {
278-
for j := i + 1; j < len(scores); j++ {
279-
if scores[j].Similarity > scores[i].Similarity {
280-
scores[i], scores[j] = scores[j], scores[i]
281-
}
282-
}
283-
}
312+
// Sort by score descending.
313+
slices.SortFunc(scores, func(a, b database.SearchResult) int {
314+
return cmp.Compare(b.Similarity, a.Similarity)
315+
})
284316

285317
// Return top N results
286318
if len(scores) > numResults {
@@ -384,77 +416,51 @@ func (s *BM25Strategy) Close() error {
384416
// Helper methods
385417

386418
func (s *BM25Strategy) tokenize(text string) []string {
387-
// Simple tokenization: lowercase and split on whitespace/punctuation
388419
text = strings.ToLower(text)
389-
// Replace common punctuation with spaces
390-
replacer := strings.NewReplacer(
391-
".", " ", ",", " ", "!", " ", "?", " ",
392-
";", " ", ":", " ", "(", " ", ")", " ",
393-
"[", " ", "]", " ", "{", " ", "}", " ",
394-
"\"", " ", "'", " ", "\n", " ", "\t", " ",
395-
)
396-
text = replacer.Replace(text)
420+
text = s.replacer.Replace(text)
397421

398422
tokens := strings.Fields(text)
399423

400-
// Remove stopwords (simplified list)
401-
stopwords := map[string]bool{
402-
"the": true, "a": true, "an": true, "and": true, "or": true,
403-
"but": true, "in": true, "on": true, "at": true, "to": true,
404-
"for": true, "of": true, "as": true, "by": true, "is": true,
405-
"was": true, "are": true, "were": true, "be": true, "been": true,
406-
}
407-
408424
filtered := make([]string, 0, len(tokens))
409425
for _, token := range tokens {
410-
if len(token) > 2 && !stopwords[token] {
426+
if len(token) > 2 && !s.stopwords[token] {
411427
filtered = append(filtered, token)
412428
}
413429
}
414430

415431
return filtered
416432
}
417433

418-
func (s *BM25Strategy) calculateBM25Score(queryTerms []string, doc database.Document, allDocs []database.Document) float64 {
419-
docLength := float64(len(s.tokenize(doc.Content)))
434+
func (s *BM25Strategy) calculateBM25Score(queryTerms []string, docTermFreq map[string]int, docLength float64, df map[string]int) float64 {
420435
score := 0.0
421436

422-
docTerms := s.tokenize(doc.Content)
423-
docTermFreq := make(map[string]int)
424-
for _, term := range docTerms {
425-
docTermFreq[term]++
426-
}
427-
428437
for _, queryTerm := range queryTerms {
429438
// Term frequency in document
430439
tf := float64(docTermFreq[queryTerm])
431440
if tf == 0 {
432441
continue
433442
}
434443

435-
// Document frequency (number of documents containing the term)
436-
df := 0
437-
for _, d := range allDocs {
438-
if strings.Contains(strings.ToLower(d.Content), queryTerm) {
439-
df++
440-
}
441-
}
442-
443-
if df == 0 {
444+
// Document frequency (pre-computed)
445+
termDF := df[queryTerm]
446+
if termDF == 0 {
444447
continue
445448
}
446449

447450
// IDF calculation
448-
idf := math.Log((float64(s.docCount)-float64(df)+0.5)/(float64(df)+0.5) + 1.0)
451+
idf := math.Log((float64(s.docCount)-float64(termDF)+0.5)/(float64(termDF)+0.5) + 1.0)
449452

450453
// BM25 formula
451454
numerator := tf * (s.k1 + 1.0)
452-
denominator := tf + s.k1*(1.0-s.b+s.b*(docLength/s.avgDocLength))
455+
lengthRatio := 1.0
456+
if s.avgDocLength > 0 {
457+
lengthRatio = docLength / s.avgDocLength
458+
}
459+
denominator := tf + s.k1*(1.0-s.b+s.b*lengthRatio)
453460
score += idf * (numerator / denominator)
454461
}
455462

456463
// Normalize score to 0-1 range for consistency with vector similarity
457-
// This is a simple normalization; in production, you might use a different approach
458464
return math.Min(score/float64(len(queryTerms)), 1.0)
459465
}
460466

0 commit comments

Comments (0)