Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,452 @@
//! BM25 (Best Matching 25) scoring implementation
//!
//! Provides proper BM25 scoring with:
//! - Document length normalization
//! - IDF weighting across corpus
//! - Term frequency saturation
//! - Configurable k1 and b parameters
//!
//! Unlike PostgreSQL's ts_rank, this is a proper BM25 implementation.
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Default BM25 `k1` parameter: controls term-frequency saturation.
/// Higher values let repeated terms keep contributing to the score.
pub const DEFAULT_K1: f32 = 1.2;
/// Default BM25 `b` parameter: controls document-length normalization
/// (0 = no normalization, 1 = full normalization).
pub const DEFAULT_B: f32 = 0.75;
/// Corpus statistics for BM25 scoring.
///
/// The default value represents an empty corpus (all fields zero), which is
/// exactly what `#[derive(Default)]` produces — the previous hand-written
/// `Default` impl was field-for-field identical and has been removed.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorpusStats {
    /// Average document length in the corpus
    pub avg_doc_length: f32,
    /// Total number of documents
    pub doc_count: u64,
    /// Total number of terms across all documents
    pub total_terms: u64,
    /// Last update timestamp (Unix epoch seconds)
    pub last_update: i64,
}
/// BM25 tuning parameters.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BM25Config {
    /// Term frequency saturation parameter (default: 1.2).
    /// Higher values give more weight to repeated occurrences of a term.
    pub k1: f32,
    /// Length normalization parameter (default: 0.75).
    /// 0 = no length normalization, 1 = full normalization.
    pub b: f32,
}
impl Default for BM25Config {
    /// Standard literature defaults: k1 = 1.2, b = 0.75.
    fn default() -> Self {
        BM25Config {
            b: DEFAULT_B,
            k1: DEFAULT_K1,
        }
    }
}
impl BM25Config {
    /// Build a configuration, forcing `k1 >= 0` and clamping `b` into `[0, 1]`.
    pub fn new(k1: f32, b: f32) -> Self {
        let k1 = k1.max(0.0);
        let b = b.clamp(0.0, 1.0);
        Self { k1, b }
    }
}
/// Per-document term frequency table.
#[derive(Debug, Clone)]
pub struct TermFrequencies {
    /// Term -> occurrence count within the document
    pub frequencies: HashMap<String, u32>,
    /// Total number of terms in the document (sum of all counts)
    pub doc_length: u32,
}
impl TermFrequencies {
    /// Build from a term -> count map; the document length is derived as the
    /// sum of all counts.
    pub fn new(frequencies: HashMap<String, u32>) -> Self {
        let mut doc_length: u32 = 0;
        for count in frequencies.values() {
            doc_length += count;
        }
        Self {
            frequencies,
            doc_length,
        }
    }
    /// Occurrence count of `term`, or `None` when the term is absent.
    pub fn get(&self, term: &str) -> Option<u32> {
        match self.frequencies.get(term) {
            Some(&freq) => Some(freq),
            None => None,
        }
    }
}
/// Borrowed view over a document's term statistics, used during scoring.
pub struct Document<'a> {
    /// Term frequencies in the document
    pub term_freqs: &'a TermFrequencies,
}
impl<'a> Document<'a> {
    /// Wrap a borrowed `TermFrequencies` table.
    pub fn new(term_freqs: &'a TermFrequencies) -> Self {
        Document { term_freqs }
    }
    /// Occurrence count of `term` in this document, if present.
    pub fn term_freq(&self, term: &str) -> Option<u32> {
        self.term_freqs.frequencies.get(term).copied()
    }
    /// Total number of terms in the document.
    pub fn term_count(&self) -> u32 {
        self.term_freqs.doc_length
    }
}
/// BM25 scorer with corpus statistics and IDF caching
///
/// The caches sit behind `parking_lot::RwLock` so scoring methods can take
/// `&self` while still memoizing IDF values. NOTE(review): the `Arc` wrappers
/// suggest the caches are meant to be shared across owners, but no cloning is
/// visible in this file — confirm whether `Arc` is actually needed.
pub struct BM25Scorer {
    /// Configuration parameters
    config: BM25Config,
    /// Corpus statistics
    corpus_stats: CorpusStats,
    /// Cached IDF values (term -> IDF score), lazily filled by `idf()`
    idf_cache: Arc<RwLock<HashMap<String, f32>>>,
    /// Document frequency cache (term -> doc count containing term),
    /// populated externally via `set_doc_freq()`
    df_cache: Arc<RwLock<HashMap<String, u64>>>,
}
impl BM25Scorer {
    /// Create a new BM25 scorer with default config (k1 = 1.2, b = 0.75).
    pub fn new(corpus_stats: CorpusStats) -> Self {
        Self::with_config(corpus_stats, BM25Config::default())
    }
    /// Create a new BM25 scorer with custom config; caches start empty.
    pub fn with_config(corpus_stats: CorpusStats, config: BM25Config) -> Self {
        Self {
            config,
            corpus_stats,
            idf_cache: Arc::new(RwLock::new(HashMap::new())),
            df_cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Update corpus statistics.
    ///
    /// NOTE(review): only the IDF cache is cleared here; per-term document
    /// frequencies in `df_cache` are kept and must be refreshed separately
    /// via `set_doc_freq` if they changed — confirm this is intentional.
    pub fn update_corpus_stats(&mut self, stats: CorpusStats) {
        self.corpus_stats = stats;
        // Clear caches when stats change (IDF depends on doc_count)
        self.idf_cache.write().clear();
    }
    /// Set document frequency for a term (used during index building).
    pub fn set_doc_freq(&self, term: &str, doc_freq: u64) {
        self.df_cache.write().insert(term.to_string(), doc_freq);
        // Invalidate IDF cache for this term so the next idf() recomputes it
        self.idf_cache.write().remove(term);
    }
    /// Compute IDF (Inverse Document Frequency) for a term
    ///
    /// Uses the BM25 IDF formula:
    /// IDF(t) = ln((N - df(t) + 0.5) / (df(t) + 0.5) + 1)
    ///
    /// where:
    /// - N = total documents in corpus
    /// - df(t) = number of documents containing term t
    ///
    /// Results are memoized in `idf_cache`. Each lock guard is dropped at the
    /// end of its statement, so the read locks below never overlap the final
    /// write lock.
    pub fn idf(&self, term: &str) -> f32 {
        // Check cache first
        if let Some(&cached) = self.idf_cache.read().get(term) {
            return cached;
        }
        // Get document frequency; an unknown term counts as df = 0
        let df = self.df_cache.read().get(term).copied().unwrap_or(0);
        // Compute IDF using BM25 formula
        let n = self.corpus_stats.doc_count as f32;
        let df_f = df as f32;
        // Prevent division by zero and handle edge cases
        let idf = if df == 0 {
            // Term not in corpus - use max IDF (the df = 0 limit of the formula)
            (n + 0.5).ln()
        } else {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        };
        // Cache the result
        self.idf_cache.write().insert(term.to_string(), idf);
        idf
    }
    /// Compute IDF with known document frequency (bypasses both caches).
    pub fn idf_with_df(&self, doc_freq: u64) -> f32 {
        let n = self.corpus_stats.doc_count as f32;
        let df = doc_freq as f32;
        if doc_freq == 0 {
            (n + 0.5).ln()
        } else {
            ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
        }
    }
    /// Score a document for a query
    ///
    /// BM25 formula:
    /// score(D, Q) = sum over t in Q of: IDF(t) * (tf(t,D) * (k1 + 1)) / (tf(t,D) + k1 * (1 - b + b * |D|/avgdl))
    ///
    /// where:
    /// - tf(t,D) = term frequency of t in document D
    /// - |D| = document length
    /// - avgdl = average document length
    /// - k1 = term saturation parameter
    /// - b = length normalization parameter
    ///
    /// Query terms absent from the document contribute nothing (filter_map).
    pub fn score(&self, doc: &Document, query_terms: &[String]) -> f32 {
        let doc_len = doc.term_count() as f32;
        // Guard against avg_doc_length == 0 on an empty corpus
        let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
        // Length normalization factor (shared by every term of this document)
        let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
        query_terms
            .iter()
            .filter_map(|term| {
                let tf = doc.term_freq(term)? as f32;
                let idf = self.idf(term);
                // BM25 term score
                let numerator = tf * (self.config.k1 + 1.0);
                let denominator = tf + self.config.k1 * len_norm;
                Some(idf * numerator / denominator)
            })
            .sum()
    }
    /// Score a document with pre-computed term frequencies and document frequencies
    ///
    /// This is the optimized version for batch scoring where IDF values are known;
    /// it never touches the caches.
    pub fn score_with_freqs(
        &self,
        term_freqs: &[(String, u32, u64)], // (term, tf in doc, df in corpus)
        doc_length: u32,
    ) -> f32 {
        let doc_len = doc_length as f32;
        let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
        let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
        term_freqs
            .iter()
            .map(|(_, tf, df)| {
                let tf = *tf as f32;
                let idf = self.idf_with_df(*df);
                let numerator = tf * (self.config.k1 + 1.0);
                let denominator = tf + self.config.k1 * len_norm;
                idf * numerator / denominator
            })
            .sum()
    }
    /// Batch score multiple documents for the same query; output order
    /// matches input order.
    pub fn score_batch<'a>(
        &self,
        docs: impl Iterator<Item = &'a Document<'a>>,
        query_terms: &[String],
    ) -> Vec<f32> {
        docs.map(|doc| self.score(doc, query_terms)).collect()
    }
    /// Get current configuration
    pub fn config(&self) -> &BM25Config {
        &self.config
    }
    /// Get corpus statistics
    pub fn corpus_stats(&self) -> &CorpusStats {
        &self.corpus_stats
    }
    /// Clear both the IDF cache and the document-frequency cache.
    pub fn clear_cache(&self) {
        self.idf_cache.write().clear();
        self.df_cache.write().clear();
    }
}
/// Simple query tokenizer for BM25
///
/// Lowercases, splits on whitespace, strips leading/trailing
/// non-alphanumeric characters, and drops tokens that are one character or
/// shorter *after* stripping.
///
/// Fix: the previous version applied the length filter before stripping
/// punctuation, so punctuation-padded single characters (e.g. "a.") slipped
/// through even though the stated intent was to skip single characters.
///
/// Note: In production, this should use PostgreSQL's text search configuration
/// for proper stemming, stopword removal, etc.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|s| s.len() > 1) // Skip single characters (and empty tokens)
        .map(str::to_string)
        .collect()
}
/// Parse a tsvector-style string into term -> frequency counts.
///
/// Expected shape: `'term1':1,2,3 'term2':4,5` — the number of listed
/// positions becomes the frequency. A bare quoted term without a position
/// list (`'term'`) counts once. Tokens not matching either shape are skipped.
pub fn parse_tsvector(tsvector_str: &str) -> HashMap<String, u32> {
    let mut frequencies = HashMap::new();
    for token in tsvector_str.split_whitespace() {
        match token.find("':") {
            Some(end) => {
                // Quoted term followed by a comma-separated position list.
                if token.starts_with('\'') {
                    let term = token[1..end].to_string();
                    let position_list = &token[end + 2..];
                    let count = position_list.split(',').count() as u32;
                    frequencies.insert(term, count.max(1));
                }
            }
            None => {
                // A quoted term with no position list counts as one occurrence.
                if token.starts_with('\'') && token.ends_with('\'') {
                    let term = token[1..token.len() - 1].to_string();
                    frequencies.insert(term, 1);
                }
            }
        }
    }
    frequencies
}
// Unit tests for BM25 scoring: IDF behavior, end-to-end scoring,
// length normalization, tokenization, tsvector parsing, and config clamping.
#[cfg(test)]
mod tests {
    use super::*;
    /// Build a scorer over a synthetic corpus of 1000 docs averaging 100 terms.
    fn create_test_scorer() -> BM25Scorer {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        BM25Scorer::new(stats)
    }
    #[test]
    fn test_idf_common_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("the", 900); // Very common term (in 90% of docs)
        let idf = scorer.idf("the");
        assert!(idf > 0.0, "IDF should be positive");
        assert!(idf < 1.0, "IDF for common term should be low");
    }
    #[test]
    fn test_idf_rare_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("xyzzy", 5); // Rare term (in 0.5% of docs)
        let idf = scorer.idf("xyzzy");
        assert!(idf > 4.0, "IDF for rare term should be high");
    }
    #[test]
    fn test_idf_unknown_term() {
        // A term never registered via set_doc_freq takes the df = 0 branch.
        let scorer = create_test_scorer();
        let idf = scorer.idf("unknown_term_xyz");
        assert!(idf > 5.0, "IDF for unknown term should be maximum");
    }
    #[test]
    fn test_bm25_score() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("database", 100);
        scorer.set_doc_freq("query", 50);
        let mut freqs = HashMap::new();
        freqs.insert("database".to_string(), 3);
        freqs.insert("query".to_string(), 2);
        freqs.insert("other".to_string(), 5); // not queried; should not contribute
        let term_freqs = TermFrequencies::new(freqs);
        let doc = Document::new(&term_freqs);
        let query_terms = vec!["database".to_string(), "query".to_string()];
        let score = scorer.score(&doc, &query_terms);
        assert!(score > 0.0, "Score should be positive");
    }
    #[test]
    fn test_length_normalization() {
        // Same term frequency in a short vs long doc: the short doc must win.
        let scorer = create_test_scorer();
        scorer.set_doc_freq("test", 100);
        // Short document (50 terms)
        let mut short_freqs = HashMap::new();
        short_freqs.insert("test".to_string(), 2);
        for i in 0..48 {
            short_freqs.insert(format!("filler{}", i), 1);
        }
        let short_tf = TermFrequencies::new(short_freqs);
        let short_doc = Document::new(&short_tf);
        // Long document (200 terms)
        let mut long_freqs = HashMap::new();
        long_freqs.insert("test".to_string(), 2);
        for i in 0..198 {
            long_freqs.insert(format!("filler{}", i), 1);
        }
        let long_tf = TermFrequencies::new(long_freqs);
        let long_doc = Document::new(&long_tf);
        let query_terms = vec!["test".to_string()];
        let short_score = scorer.score(&short_doc, &query_terms);
        let long_score = scorer.score(&long_doc, &query_terms);
        // Short doc should score higher (same tf, less length penalty)
        assert!(
            short_score > long_score,
            "Short doc should score higher than long doc with same TF"
        );
    }
    #[test]
    fn test_tokenize_query() {
        let tokens = tokenize_query("Hello World! Database Query.");
        assert_eq!(tokens, vec!["hello", "world", "database", "query"]);
    }
    #[test]
    fn test_parse_tsvector() {
        // Frequency is the count of listed positions.
        let tsvector = "'database':1,3,5 'query':2,4";
        let freqs = parse_tsvector(tsvector);
        assert_eq!(freqs.get("database"), Some(&3));
        assert_eq!(freqs.get("query"), Some(&2));
    }
    #[test]
    fn test_config_clamping() {
        let config = BM25Config::new(-1.0, 1.5);
        assert_eq!(config.k1, 0.0, "k1 should be clamped to 0");
        assert_eq!(config.b, 1.0, "b should be clamped to 1");
    }
}

View File

@@ -0,0 +1,522 @@
//! Hybrid Query Executor
//!
//! Executes vector and keyword search branches, optionally in parallel,
//! and fuses results using the configured algorithm.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::bm25::{tokenize_query, BM25Scorer, CorpusStats, Document, TermFrequencies};
use super::fusion::{
fuse_results, learned_fusion, DocId, FusedResult, FusionConfig, FusionMethod, FusionModel,
};
/// Hybrid search query combining a keyword string and a query embedding.
#[derive(Debug, Clone)]
pub struct HybridQuery {
    /// Query text for keyword search
    pub text: String,
    /// Query embedding for vector search
    pub embedding: Vec<f32>,
    /// Number of final results to return
    pub k: usize,
    /// Number of results to prefetch from each branch
    pub prefetch_k: usize,
    /// Fusion configuration
    pub fusion_config: FusionConfig,
    /// Optional filter expression
    pub filter: Option<String>,
}
impl HybridQuery {
    /// Build a query; prefetch defaults to `10 * k` candidates per branch.
    pub fn new(text: String, embedding: Vec<f32>, k: usize) -> Self {
        let prefetch_k = k * 10;
        Self {
            text,
            embedding,
            k,
            prefetch_k,
            fusion_config: FusionConfig::default(),
            filter: None,
        }
    }
    /// Choose the fusion algorithm.
    pub fn with_fusion(mut self, method: FusionMethod) -> Self {
        self.fusion_config.method = method;
        self
    }
    /// Weight for linear fusion (0 = all keyword, 1 = all vector).
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.fusion_config.alpha = alpha;
        self
    }
    /// RRF rank-damping constant.
    pub fn with_rrf_k(mut self, rrf_k: usize) -> Self {
        self.fusion_config.rrf_k = rrf_k;
        self
    }
    /// Override the per-branch prefetch size.
    pub fn with_prefetch(mut self, prefetch_k: usize) -> Self {
        self.prefetch_k = prefetch_k;
        self
    }
    /// Attach a filter expression.
    pub fn with_filter(mut self, filter: String) -> Self {
        self.filter = Some(filter);
        self
    }
}
/// Execution strategy for hybrid search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HybridStrategy {
    /// Execute both branches fully
    Full,
    /// Pre-filter using keyword/metadata, then vector search on subset
    PreFilter,
    /// Execute hybrid search, then apply filter
    PostFilter,
    /// Vector search only (degraded mode)
    VectorOnly,
    /// Keyword search only (degraded mode)
    KeywordOnly,
}
impl Default for HybridStrategy {
    /// By default both branches run in full.
    fn default() -> Self {
        Self::Full
    }
}
/// Choose execution strategy based on query characteristics.
///
/// Highly selective filters favor pre-filtering; moderately selective filters
/// on a large corpus favor post-filtering; everything else (including an
/// unknown selectivity or no filter at all) runs both branches in full.
pub fn choose_strategy(
    filter_selectivity: Option<f32>,
    corpus_size: u64,
    has_filter: bool,
) -> HybridStrategy {
    if !has_filter {
        return HybridStrategy::Full;
    }
    if let Some(selectivity) = filter_selectivity {
        if selectivity < 0.01 {
            // Very selective filter: restrict the candidate set up front.
            HybridStrategy::PreFilter
        } else if selectivity < 0.1 && corpus_size > 1_000_000 {
            // Moderately selective on a large corpus: filter after fusion.
            HybridStrategy::PostFilter
        } else {
            HybridStrategy::Full
        }
    } else {
        // Selectivity unknown: don't gamble, run both branches.
        HybridStrategy::Full
    }
}
/// Result set returned by a single search branch (vector or keyword).
#[derive(Debug, Clone)]
pub struct BranchResults {
    /// Document IDs and scores, sorted best-first by the branch
    pub results: Vec<(DocId, f32)>,
    /// Execution time in milliseconds
    pub latency_ms: f64,
    /// Number of candidates evaluated
    pub candidates_evaluated: usize,
}
impl BranchResults {
    /// A branch that returned nothing and did no measurable work.
    pub fn empty() -> Self {
        BranchResults {
            results: Vec::new(),
            latency_ms: 0.0,
            candidates_evaluated: 0,
        }
    }
}
/// Hybrid search result with detailed scores
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridResult {
/// Document ID
pub id: DocId,
/// Final hybrid score
pub hybrid_score: f32,
/// Vector similarity score (1 - distance for cosine)
pub vector_score: Option<f32>,
/// BM25 keyword score
pub keyword_score: Option<f32>,
/// Rank in vector results (None if not present)
pub vector_rank: Option<usize>,
/// Rank in keyword results (None if not present)
pub keyword_rank: Option<usize>,
}
impl From<FusedResult> for HybridResult {
fn from(fused: FusedResult) -> Self {
Self {
id: fused.doc_id,
hybrid_score: fused.hybrid_score,
vector_score: fused.vector_score,
keyword_score: fused.keyword_score,
vector_rank: None,
keyword_rank: None,
}
}
}
/// Execution statistics for hybrid search
///
/// All latencies are wall-clock milliseconds measured by the executor around
/// each phase of `HybridExecutor::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
    /// Total execution time in milliseconds
    pub total_latency_ms: f64,
    /// Vector branch latency
    pub vector_latency_ms: f64,
    /// Keyword branch latency
    pub keyword_latency_ms: f64,
    /// Fusion latency
    pub fusion_latency_ms: f64,
    /// Strategy used (e.g. "full")
    pub strategy: String,
    /// Number of vector candidates
    pub vector_candidates: usize,
    /// Number of keyword candidates
    pub keyword_candidates: usize,
    /// Final result count
    pub result_count: usize,
}
/// Hybrid Query Executor
///
/// Coordinates vector and keyword search branches, handles parallel execution,
/// and manages score fusion.
pub struct HybridExecutor {
    /// BM25 scorer for keyword scoring
    bm25_scorer: BM25Scorer,
    /// Fusion model for learned fusion
    fusion_model: FusionModel,
    /// Whether to use parallel execution.
    /// NOTE(review): the synchronous `execute` runs branches sequentially in
    /// either case; true parallelism lives in the async module — confirm
    /// whether this flag should gate anything here.
    parallel_enabled: bool,
    /// Default prefetch multiplier.
    /// NOTE(review): not read by any method in this file (queries carry their
    /// own `prefetch_k`) — confirm intended use.
    prefetch_multiplier: usize,
}
impl HybridExecutor {
/// Create a new hybrid executor
pub fn new(corpus_stats: CorpusStats) -> Self {
Self {
bm25_scorer: BM25Scorer::new(corpus_stats),
fusion_model: FusionModel::default(),
parallel_enabled: true,
prefetch_multiplier: 10,
}
}
/// Update corpus statistics
pub fn update_corpus_stats(&mut self, stats: CorpusStats) {
self.bm25_scorer.update_corpus_stats(stats);
}
/// Set whether to use parallel execution
pub fn set_parallel(&mut self, enabled: bool) {
self.parallel_enabled = enabled;
}
/// Set prefetch multiplier
pub fn set_prefetch_multiplier(&mut self, multiplier: usize) {
self.prefetch_multiplier = multiplier;
}
/// Execute hybrid search
///
/// This is the main entry point for hybrid search. In the PostgreSQL extension,
/// this would call into the actual vector index and tsvector search.
pub fn execute(
&self,
query: &HybridQuery,
vector_search_fn: impl FnOnce(&[f32], usize) -> BranchResults,
keyword_search_fn: impl FnOnce(&str, usize) -> BranchResults,
) -> (Vec<HybridResult>, ExecutionStats) {
let start = std::time::Instant::now();
// Execute both branches
let (vector_results, keyword_results) = if self.parallel_enabled {
// In async context, use tokio::join!
// For sync, execute sequentially (rayon could parallelize)
let v_start = std::time::Instant::now();
let vector = vector_search_fn(&query.embedding, query.prefetch_k);
let v_elapsed = v_start.elapsed().as_secs_f64() * 1000.0;
let k_start = std::time::Instant::now();
let keyword = keyword_search_fn(&query.text, query.prefetch_k);
let k_elapsed = k_start.elapsed().as_secs_f64() * 1000.0;
(
BranchResults {
latency_ms: v_elapsed,
..vector
},
BranchResults {
latency_ms: k_elapsed,
..keyword
},
)
} else {
let v_start = std::time::Instant::now();
let vector = vector_search_fn(&query.embedding, query.prefetch_k);
let v_elapsed = v_start.elapsed().as_secs_f64() * 1000.0;
let k_start = std::time::Instant::now();
let keyword = keyword_search_fn(&query.text, query.prefetch_k);
let k_elapsed = k_start.elapsed().as_secs_f64() * 1000.0;
(
BranchResults {
latency_ms: v_elapsed,
..vector
},
BranchResults {
latency_ms: k_elapsed,
..keyword
},
)
};
// Fuse results
let fusion_start = std::time::Instant::now();
let fused = self.fuse(&query, &vector_results.results, &keyword_results.results);
let fusion_elapsed = fusion_start.elapsed().as_secs_f64() * 1000.0;
// Add rank information
let vector_ranks: HashMap<DocId, usize> = vector_results
.results
.iter()
.enumerate()
.map(|(i, (id, _))| (*id, i))
.collect();
let keyword_ranks: HashMap<DocId, usize> = keyword_results
.results
.iter()
.enumerate()
.map(|(i, (id, _))| (*id, i))
.collect();
let results: Vec<HybridResult> = fused
.into_iter()
.take(query.k)
.map(|f| {
let mut result = HybridResult::from(f);
result.vector_rank = vector_ranks.get(&result.id).copied();
result.keyword_rank = keyword_ranks.get(&result.id).copied();
result
})
.collect();
let total_elapsed = start.elapsed().as_secs_f64() * 1000.0;
let stats = ExecutionStats {
total_latency_ms: total_elapsed,
vector_latency_ms: vector_results.latency_ms,
keyword_latency_ms: keyword_results.latency_ms,
fusion_latency_ms: fusion_elapsed,
strategy: "full".to_string(),
vector_candidates: vector_results.candidates_evaluated,
keyword_candidates: keyword_results.candidates_evaluated,
result_count: results.len(),
};
(results, stats)
}
/// Fuse vector and keyword results
fn fuse(
&self,
query: &HybridQuery,
vector_results: &[(DocId, f32)],
keyword_results: &[(DocId, f32)],
) -> Vec<FusedResult> {
match query.fusion_config.method {
FusionMethod::Learned => {
// Use learned fusion with query features
let query_terms = tokenize_query(&query.text);
let avg_idf = self.compute_avg_idf(&query_terms);
learned_fusion(
&query.embedding,
&query_terms,
vector_results,
keyword_results,
&self.fusion_model,
avg_idf,
query.prefetch_k,
)
}
_ => {
// Use standard fusion
fuse_results(
vector_results,
keyword_results,
&query.fusion_config,
query.prefetch_k,
)
}
}
}
/// Compute average IDF for query terms
fn compute_avg_idf(&self, terms: &[String]) -> f32 {
if terms.is_empty() {
return 0.0;
}
let total_idf: f32 = terms.iter().map(|t| self.bm25_scorer.idf(t)).sum();
total_idf / terms.len() as f32
}
/// Score documents using BM25
pub fn bm25_score(&self, term_freqs: &TermFrequencies, query_terms: &[String]) -> f32 {
let doc = Document::new(term_freqs);
self.bm25_scorer.score(&doc, query_terms)
}
/// Set document frequency for a term (for BM25 IDF calculation)
pub fn set_doc_freq(&self, term: &str, doc_freq: u64) {
self.bm25_scorer.set_doc_freq(term, doc_freq);
}
/// Get current corpus statistics
pub fn corpus_stats(&self) -> &CorpusStats {
self.bm25_scorer.corpus_stats()
}
}
/// Async hybrid execution using tokio
#[cfg(feature = "tokio")]
pub mod async_executor {
    use super::*;
    use std::future::Future;
    /// Execute hybrid search with both branches awaited concurrently.
    ///
    /// NOTE(review): unlike the sync path (which fuses `prefetch_k`
    /// candidates and then truncates to `k`), this fuses directly with a
    /// limit of `query.k` — confirm the intended candidate pool size.
    pub async fn parallel_hybrid<VF, KF, VFut, KFut>(
        query: &HybridQuery,
        vector_search: VF,
        keyword_search: KF,
        fusion_config: &FusionConfig,
    ) -> Vec<FusedResult>
    where
        VF: FnOnce(&[f32], usize) -> VFut,
        KF: FnOnce(&str, usize) -> KFut,
        VFut: Future<Output = BranchResults>,
        KFut: Future<Output = BranchResults>,
    {
        // tokio::join! drives both futures to completion concurrently.
        let (vector_results, keyword_results) = tokio::join!(
            vector_search(&query.embedding, query.prefetch_k),
            keyword_search(&query.text, query.prefetch_k),
        );
        fuse_results(
            &vector_results.results,
            &keyword_results.results,
            fusion_config,
            query.k,
        )
    }
}
// Unit tests for the hybrid executor: query builder, end-to-end execution
// with mock branches, strategy selection, and timing stats.
#[cfg(test)]
mod tests {
    use super::*;
    /// Mock vector branch: up to 5 hits with ascending "distances".
    fn mock_vector_search(_embedding: &[f32], k: usize) -> BranchResults {
        BranchResults {
            results: (0..k.min(5))
                .map(|i| (i as DocId + 1, 0.1 * (i + 1) as f32))
                .collect(),
            latency_ms: 1.0,
            candidates_evaluated: 100,
        }
    }
    /// Mock keyword branch: up to 5 hits (ids offset by 3, overlapping the
    /// vector branch) with descending BM25 scores.
    fn mock_keyword_search(_text: &str, k: usize) -> BranchResults {
        BranchResults {
            results: (0..k.min(5))
                .map(|i| ((i as DocId + 3), 10.0 - i as f32))
                .collect(),
            latency_ms: 0.5,
            candidates_evaluated: 50,
        }
    }
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("test query".into(), vec![0.1, 0.2, 0.3], 10)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.7)
            .with_prefetch(100)
            .with_filter("category = 'docs'".into());
        assert_eq!(query.k, 10);
        assert_eq!(query.prefetch_k, 100);
        assert_eq!(query.fusion_config.method, FusionMethod::Linear);
        assert!((query.fusion_config.alpha - 0.7).abs() < 0.01);
        assert!(query.filter.is_some());
    }
    #[test]
    fn test_hybrid_executor() {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        let executor = HybridExecutor::new(stats);
        let query = HybridQuery::new("database query".into(), vec![0.1; 128], 5);
        let (results, exec_stats) =
            executor.execute(&query, mock_vector_search, mock_keyword_search);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert!(exec_stats.total_latency_ms > 0.0);
    }
    #[test]
    fn test_strategy_selection() {
        // No filter -> Full
        assert_eq!(choose_strategy(None, 10000, false), HybridStrategy::Full);
        // Very selective filter -> PreFilter
        assert_eq!(
            choose_strategy(Some(0.005), 1000000, true),
            HybridStrategy::PreFilter
        );
        // Moderate selectivity, large corpus -> PostFilter
        assert_eq!(
            choose_strategy(Some(0.05), 5000000, true),
            HybridStrategy::PostFilter
        );
    }
    #[test]
    fn test_execution_stats() {
        let stats = CorpusStats::default();
        let executor = HybridExecutor::new(stats);
        let query = HybridQuery::new("test".into(), vec![0.1; 16], 5);
        let (_, exec_stats) = executor.execute(&query, mock_vector_search, mock_keyword_search);
        // Phase latencies are non-negative and bounded by the total.
        assert!(exec_stats.vector_latency_ms >= 0.0);
        assert!(exec_stats.keyword_latency_ms >= 0.0);
        assert!(exec_stats.fusion_latency_ms >= 0.0);
        assert!(exec_stats.total_latency_ms >= exec_stats.fusion_latency_ms);
    }
}

View File

@@ -0,0 +1,604 @@
//! Fusion algorithms for combining vector and keyword search results
//!
//! Provides:
//! - RRF (Reciprocal Rank Fusion) - default, robust
//! - Linear blend - simple weighted combination
//! - Learned fusion - query-adaptive weights
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Document ID type (matches with database row IDs)
pub type DocId = i64;
/// Default RRF rank-damping constant; 60 is the value used in the RRF paper
/// cited below at `rrf_fusion`.
pub const DEFAULT_RRF_K: usize = 60;
/// Default alpha for linear fusion (0.5 = equal weight between branches)
pub const DEFAULT_ALPHA: f32 = 0.5;
/// Selects how vector and keyword result lists are combined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default)
    Rrf,
    /// Linear weighted combination
    Linear,
    /// Learned/adaptive fusion based on query features
    Learned,
}
impl Default for FusionMethod {
    /// RRF: robust across score scales, no calibration needed.
    fn default() -> Self {
        Self::Rrf
    }
}
impl std::str::FromStr for FusionMethod {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"rrf" | "reciprocal" | "reciprocal_rank" => Ok(FusionMethod::Rrf),
"linear" | "blend" | "weighted" => Ok(FusionMethod::Linear),
"learned" | "adaptive" | "auto" => Ok(FusionMethod::Learned),
_ => Err(format!("Unknown fusion method: {}", s)),
}
}
}
/// Settings controlling result fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Fusion method to use
    pub method: FusionMethod,
    /// RRF constant (typically 60)
    pub rrf_k: usize,
    /// Alpha for linear fusion (0 = all keyword, 1 = all vector)
    pub alpha: f32,
}
impl Default for FusionConfig {
    /// RRF with the standard constant; alpha only matters for linear fusion.
    fn default() -> Self {
        FusionConfig {
            method: FusionMethod::default(),
            rrf_k: DEFAULT_RRF_K,
            alpha: DEFAULT_ALPHA,
        }
    }
}
/// Search result from a single branch (vector or keyword)
///
/// NOTE(review): this type is not referenced by the fusion functions in this
/// file (they operate on `(DocId, f32)` tuples directly) — confirm it is used
/// elsewhere or consider removing it.
#[derive(Debug, Clone)]
pub struct BranchResult {
    /// Document ID
    pub doc_id: DocId,
    /// Original score from the branch
    pub score: f32,
    /// Rank in the result set (0-indexed)
    pub rank: usize,
}
/// Fused search result combining both branches
///
/// A `None` branch score means the document did not appear in that branch's
/// candidate list.
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document ID
    pub doc_id: DocId,
    /// Final fused score
    pub hybrid_score: f32,
    /// Original vector score (if present)
    pub vector_score: Option<f32>,
    /// Original keyword score (if present)
    pub keyword_score: Option<f32>,
}
/// Reciprocal Rank Fusion (RRF)
///
/// `score(d) = Σ over sources of 1 / (k + rank + 1)` with 0-indexed ranks,
/// which is the conventional `1 / (k + rank)` for 1-indexed ranks.
///
/// Properties:
/// - Robust to different score scales
/// - No need for score calibration
/// - Works well with partial overlap
///
/// Reference: "Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"
pub fn rrf_fusion(
    vector_results: &[(DocId, f32)], // (id, distance - lower is better)
    keyword_results: &[(DocId, f32)], // (id, BM25 score - higher is better)
    k: usize,
    limit: usize,
) -> Vec<FusedResult> {
    // doc_id -> (accumulated RRF score, raw vector distance, raw keyword score)
    let mut accum: HashMap<DocId, (f32, Option<f32>, Option<f32>)> = HashMap::new();
    // Both inputs arrive sorted best-first, so the enumeration index IS the rank.
    for (rank, &(doc_id, distance)) in vector_results.iter().enumerate() {
        let slot = accum.entry(doc_id).or_insert((0.0, None, None));
        slot.0 += 1.0 / (k + rank + 1) as f32;
        slot.1 = Some(distance);
    }
    for (rank, &(doc_id, bm25)) in keyword_results.iter().enumerate() {
        let slot = accum.entry(doc_id).or_insert((0.0, None, None));
        slot.0 += 1.0 / (k + rank + 1) as f32;
        slot.2 = Some(bm25);
    }
    // Materialize and rank by fused score, descending; NaNs compare equal.
    let mut fused: Vec<FusedResult> = accum
        .into_iter()
        .map(
            |(doc_id, (hybrid_score, vector_score, keyword_score))| FusedResult {
                doc_id,
                hybrid_score,
                vector_score,
                keyword_score,
            },
        )
        .collect();
    fused.sort_by(|left, right| {
        right
            .hybrid_score
            .partial_cmp(&left.hybrid_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    fused.truncate(limit);
    fused
}
/// Convert distances (lower = better) into min-max normalized similarities in
/// `[0, 1]` (higher = better). Input order is preserved.
fn normalize_to_similarity(results: &[(DocId, f32)]) -> Vec<(DocId, f32)> {
    if results.is_empty() {
        return Vec::new();
    }
    // Scan once for the distance extremes.
    let mut min_dist = f32::MAX;
    let mut max_dist = f32::MIN;
    for &(_, d) in results {
        min_dist = min_dist.min(d);
        max_dist = max_dist.max(d);
    }
    // Guard against a zero range when all distances are equal.
    let range = (max_dist - min_dist).max(1e-6);
    results
        .iter()
        .map(|&(id, dist)| {
            // similarity = 1 - normalized distance
            (id, 1.0 - (dist - min_dist) / range)
        })
        .collect()
}
/// Min-max normalize scores into `[0, 1]`, preserving input order.
fn min_max_normalize(results: &[(DocId, f32)]) -> Vec<(DocId, f32)> {
    if results.is_empty() {
        return Vec::new();
    }
    // Scan once for the score extremes.
    let mut min_score = f32::MAX;
    let mut max_score = f32::MIN;
    for &(_, s) in results {
        min_score = min_score.min(s);
        max_score = max_score.max(s);
    }
    // Guard against a zero range when all scores are equal.
    let range = (max_score - min_score).max(1e-6);
    results
        .iter()
        .map(|&(id, score)| (id, (score - min_score) / range))
        .collect()
}
/// Linear fusion with alpha blending
///
/// score = alpha * vector_similarity + (1 - alpha) * keyword_score
///
/// Scores are min-max normalized per branch before blending; the raw branch
/// scores (distance / BM25) are carried through on each result.
///
/// Fix: the previous version recovered each document's raw score with a
/// linear `find` over the input slice per document (O(n^2) overall). This
/// version builds first-occurrence lookup maps once, keeping the same
/// first-match semantics at O(n).
pub fn linear_fusion(
    vector_results: &[(DocId, f32)], // (id, distance)
    keyword_results: &[(DocId, f32)], // (id, BM25 score)
    alpha: f32,
    limit: usize,
) -> Vec<FusedResult> {
    // Normalize vector scores (distance -> similarity) and keyword scores to [0, 1].
    let vec_scores: HashMap<DocId, f32> = normalize_to_similarity(vector_results)
        .into_iter()
        .collect();
    let kw_scores: HashMap<DocId, f32> = min_max_normalize(keyword_results).into_iter().collect();
    // First-occurrence raw-score lookups (matches the old `find` semantics
    // when an id appears more than once).
    let mut raw_dist: HashMap<DocId, f32> = HashMap::with_capacity(vector_results.len());
    for &(id, dist) in vector_results {
        raw_dist.entry(id).or_insert(dist);
    }
    let mut raw_bm25: HashMap<DocId, f32> = HashMap::with_capacity(keyword_results.len());
    for &(id, score) in keyword_results {
        raw_bm25.entry(id).or_insert(score);
    }
    // Blend the two branches.
    let mut combined: HashMap<DocId, (f32, Option<f32>, Option<f32>)> = HashMap::new();
    for (doc_id, sim) in &vec_scores {
        let entry = combined.entry(*doc_id).or_insert((0.0, None, None));
        entry.0 += alpha * sim;
        // Carry the original distance through
        entry.1 = raw_dist.get(doc_id).copied();
    }
    for (doc_id, norm_score) in &kw_scores {
        let entry = combined.entry(*doc_id).or_insert((0.0, None, None));
        entry.0 += (1.0 - alpha) * norm_score;
        // Carry the original BM25 score through
        entry.2 = raw_bm25.get(doc_id).copied();
    }
    // Rank by blended score, descending; NaNs compare equal.
    let mut results: Vec<FusedResult> = combined
        .into_iter()
        .map(
            |(doc_id, (hybrid_score, vector_score, keyword_score))| FusedResult {
                doc_id,
                hybrid_score,
                vector_score,
                keyword_score,
            },
        )
        .collect();
    results.sort_by(|a, b| {
        b.hybrid_score
            .partial_cmp(&a.hybrid_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results.truncate(limit);
    results
}
/// Query features for learned fusion
///
/// Consumed by `FusionModel::predict_alpha` to pick a blending weight.
#[derive(Debug, Clone)]
pub struct QueryFeatures {
    /// L2 norm of query embedding
    pub embedding_norm: f32,
    /// Number of query terms
    pub term_count: usize,
    /// Average IDF of query terms
    pub avg_term_idf: f32,
    /// Whether query appears to need exact matching
    pub has_exact_match: bool,
    /// Classified query type
    pub query_type: QueryType,
}
/// Query type classification for learned fusion
///
/// NOTE(review): the classifier (`classify_query_type`) is defined outside
/// this excerpt — confirm its heuristics match these variant descriptions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Navigational query (looking for specific entity)
    Navigational,
    /// Informational query (seeking information)
    Informational,
    /// Transactional query (action-oriented)
    Transactional,
    /// Unknown/mixed
    Unknown,
}
/// Simple fusion model for learned/adaptive fusion
///
/// A hand-tuned linear model over `QueryFeatures`: each weight nudges the
/// predicted alpha up (toward vector search) or down (toward keyword search).
/// In production, this would be a trained ML model (e.g., GNN, logistic regression).
pub struct FusionModel {
    /// Default alpha when model can't make prediction
    pub default_alpha: f32,
    /// Weight for embedding norm
    pub norm_weight: f32,
    /// Weight for term count
    pub term_weight: f32,
    /// Weight for avg IDF
    pub idf_weight: f32,
    /// Bias for exact match queries (negative favors keyword)
    pub exact_match_bias: f32,
}
impl Default for FusionModel {
    fn default() -> Self {
        Self {
            default_alpha: 0.5,
            norm_weight: 0.1,
            term_weight: -0.05, // More terms -> slight keyword preference
            idf_weight: 0.15,   // Rare terms -> vector preference
            exact_match_bias: -0.2, // Exact match -> keyword preference
        }
    }
}
impl FusionModel {
    /// Predict the fusion alpha (vector-branch weight) for a query.
    ///
    /// Starts from `default_alpha`, adds a clamped linear contribution per
    /// continuous feature, a fixed bias for exact-match intent, and a fixed
    /// offset per query type; the result is clamped into `[0.0, 1.0]`.
    pub fn predict_alpha(&self, features: &QueryFeatures) -> f32 {
        // Clamped linear contributions from the continuous features.
        let norm_adj = self.norm_weight * (features.embedding_norm - 1.0).clamp(-1.0, 1.0);
        let term_adj = self.term_weight * (features.term_count as f32 - 3.0).clamp(-3.0, 3.0);
        let idf_adj = self.idf_weight * (features.avg_term_idf - 3.0).clamp(-3.0, 3.0);
        // Fixed bias when the query looks like it wants exact matching.
        let exact_adj = if features.has_exact_match {
            self.exact_match_bias
        } else {
            0.0
        };
        // Fixed per-query-type offset.
        let type_adj = match features.query_type {
            QueryType::Navigational => -0.15, // favor keyword
            QueryType::Informational => 0.1,  // favor vector
            QueryType::Transactional => -0.05,
            QueryType::Unknown => 0.0,
        };
        // Same accumulation order as sequential += so f32 results match.
        (self.default_alpha + norm_adj + term_adj + idf_adj + exact_adj + type_adj).clamp(0.0, 1.0)
    }
}
/// Learned fusion using query characteristics
///
/// Extracts [`QueryFeatures`] from the query, asks `model` for an optimal
/// alpha, and delegates to [`linear_fusion`] with that alpha.
pub fn learned_fusion(
    query_embedding: &[f32],
    query_terms: &[String],
    vector_results: &[(DocId, f32)],
    keyword_results: &[(DocId, f32)],
    model: &FusionModel,
    avg_term_idf: f32, // Pre-computed average IDF
    limit: usize,
) -> Vec<FusedResult> {
    // Derive the per-query features the model scores against.
    let features = QueryFeatures {
        embedding_norm: l2_norm(query_embedding),
        term_count: query_terms.len(),
        avg_term_idf,
        has_exact_match: detect_exact_match_intent(query_terms),
        query_type: classify_query_type(query_terms),
    };
    // The predicted alpha plugs straight into a plain linear blend.
    let alpha = model.predict_alpha(&features);
    linear_fusion(vector_results, keyword_results, alpha, limit)
}
/// Euclidean (L2) norm of a vector; returns 0.0 for an empty slice.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = v.iter().fold(0.0_f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
/// Heuristically decide whether the query demands exact (keyword) matching.
///
/// Signals used:
/// - Quoted phrases (handled upstream, before tokenization)
/// - Very short queries (0-2 terms) usually target something specific
/// - Code/ID-shaped tokens: 3-20 chars containing at least one digit
fn detect_exact_match_intent(terms: &[String]) -> bool {
    if terms.len() <= 2 {
        return true;
    }
    terms
        .iter()
        .any(|t| (3..=20).contains(&t.len()) && t.chars().any(|c| c.is_numeric()))
}
/// Classify the query intent from its terms using small indicator word lists.
///
/// Checks run in priority order (navigational, then transactional, then
/// informational); a query matching none of the lists is `Unknown`.
fn classify_query_type(terms: &[String]) -> QueryType {
    const NAV_INDICATORS: [&str; 5] = ["website", "login", "home", "official", "download"];
    const TRANS_INDICATORS: [&str; 7] =
        ["buy", "purchase", "order", "price", "cheap", "best", "deal"];
    const INFO_INDICATORS: [&str; 8] = [
        "how", "what", "why", "when", "where", "guide", "tutorial", "explain",
    ];
    // Lowercase once so the indicator comparison is case-insensitive.
    let lowered: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
    let has_any =
        |indicators: &[&str]| lowered.iter().any(|t| indicators.contains(&t.as_str()));
    if has_any(&NAV_INDICATORS) {
        QueryType::Navigational
    } else if has_any(&TRANS_INDICATORS) {
        QueryType::Transactional
    } else if has_any(&INFO_INDICATORS) {
        QueryType::Informational
    } else {
        QueryType::Unknown
    }
}
/// Fuse results using the specified method
///
/// Dispatches to the concrete fusion algorithm selected by `config.method`.
/// `Learned` requires query features and a model that are not available at
/// this layer, so it degrades gracefully to rank-based RRF.
pub fn fuse_results(
    vector_results: &[(DocId, f32)],
    keyword_results: &[(DocId, f32)],
    config: &FusionConfig,
    limit: usize,
) -> Vec<FusedResult> {
    match config.method {
        FusionMethod::Linear => {
            linear_fusion(vector_results, keyword_results, config.alpha, limit)
        }
        // RRF proper, plus the Learned fallback (no model context here).
        FusionMethod::Rrf | FusionMethod::Learned => {
            rrf_fusion(vector_results, keyword_results, config.rrf_k, limit)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Fixture: vector branch results as (doc_id, distance) — lower is better.
    fn sample_vector_results() -> Vec<(DocId, f32)> {
        vec![
            (1, 0.1), // Best (lowest distance)
            (2, 0.2),
            (3, 0.3),
            (4, 0.5),
            (5, 0.8),
        ]
    }
    // Fixture: keyword branch results as (doc_id, BM25 score) — higher is better.
    // Docs 1, 2, 3 overlap with the vector fixture; 6 and 7 are keyword-only.
    fn sample_keyword_results() -> Vec<(DocId, f32)> {
        vec![
            (3, 8.5), // Best (highest BM25)
            (1, 7.2),
            (6, 5.0),
            (2, 3.5),
            (7, 2.0),
        ]
    }
    // Docs present in both branches accumulate two reciprocal-rank terms,
    // so they should outrank single-branch docs.
    #[test]
    fn test_rrf_fusion() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        let results = rrf_fusion(&vector, &keyword, 60, 5);
        assert!(!results.is_empty());
        // Doc 1 and 3 should be near top (they appear in both)
        let top_ids: Vec<DocId> = results.iter().map(|r| r.doc_id).collect();
        assert!(top_ids.contains(&1) || top_ids.contains(&3));
    }
    #[test]
    fn test_linear_fusion_alpha_1() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        // Alpha = 1.0 means only vector
        let results = linear_fusion(&vector, &keyword, 1.0, 5);
        assert!(!results.is_empty());
        // With alpha=1, vector-only docs should dominate
        let first = &results[0];
        assert!(first.vector_score.is_some());
    }
    #[test]
    fn test_linear_fusion_alpha_0() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        // Alpha = 0.0 means only keyword
        let results = linear_fusion(&vector, &keyword, 0.0, 5);
        assert!(!results.is_empty());
        // With alpha=0, keyword-only docs can appear
        let first = &results[0];
        assert!(first.keyword_score.is_some());
    }
    // Min-max normalization should map the best (lowest) distance to ~1.0.
    #[test]
    fn test_normalization() {
        let results = vec![(1, 0.1), (2, 0.5), (3, 0.9)];
        let normalized = normalize_to_similarity(&results);
        assert_eq!(normalized.len(), 3);
        // Lowest distance should have highest similarity
        let (id1, sim1) = normalized[0];
        assert_eq!(id1, 1);
        assert!((sim1 - 1.0).abs() < 0.01, "Best should be ~1.0");
    }
    // Sanity-checks the direction (not the magnitude) of the default
    // model's alpha adjustments for two contrasting query profiles.
    #[test]
    fn test_fusion_model() {
        let model = FusionModel::default();
        // Short navigational query
        let features = QueryFeatures {
            embedding_norm: 1.0,
            term_count: 2,
            avg_term_idf: 2.0,
            has_exact_match: true,
            query_type: QueryType::Navigational,
        };
        let alpha = model.predict_alpha(&features);
        assert!(
            alpha < 0.5,
            "Navigational query should favor keyword (alpha < 0.5)"
        );
        // Long informational query
        let features2 = QueryFeatures {
            embedding_norm: 1.2,
            term_count: 6,
            avg_term_idf: 5.0,
            has_exact_match: false,
            query_type: QueryType::Informational,
        };
        let alpha2 = model.predict_alpha(&features2);
        assert!(
            alpha2 > 0.4,
            "Informational query with rare terms should favor vector"
        );
    }
    #[test]
    fn test_query_type_classification() {
        let nav = classify_query_type(&["github".into(), "login".into()]);
        assert_eq!(nav, QueryType::Navigational);
        let info = classify_query_type(&["how".into(), "to".into(), "cook".into(), "pasta".into()]);
        assert_eq!(info, QueryType::Informational);
        let trans = classify_query_type(&["buy".into(), "laptop".into()]);
        assert_eq!(trans, QueryType::Transactional);
    }
    // Code-like tokens (digits, 3-20 chars) and short queries trigger
    // exact-match intent; a four-term prose query does not.
    #[test]
    fn test_exact_match_detection() {
        assert!(detect_exact_match_intent(&["ERR001".into()]));
        assert!(detect_exact_match_intent(&["SKU12345".into()]));
        assert!(!detect_exact_match_intent(&[
            "database".into(),
            "connection".into(),
            "error".into(),
            "handling".into()
        ]));
    }
    // Both fusion strategies must tolerate empty branches without panicking.
    #[test]
    fn test_empty_results() {
        let results = rrf_fusion(&[], &[], 60, 10);
        assert!(results.is_empty());
        let results2 = linear_fusion(&[], &[], 0.5, 10);
        assert!(results2.is_empty());
    }
}

View File

@@ -0,0 +1,533 @@
//! Hybrid Search (BM25 + Vector) for RuVector Postgres
//!
//! Provides combined keyword and semantic vector search with multiple fusion strategies.
//!
//! ## Features
//!
//! - **BM25 Scoring**: Proper BM25 implementation with document length normalization
//! - **Fusion Algorithms**: RRF (default), Linear blend, Learned/adaptive
//! - **Parallel Execution**: Vector and keyword branches can run concurrently
//! - **Registry System**: Track hybrid-enabled collections with per-collection settings
//!
//! ## SQL Interface
//!
//! ```sql
//! -- Register a collection for hybrid search
//! SELECT ruvector_register_hybrid(
//! collection := 'documents',
//! vector_column := 'embedding',
//! fts_column := 'fts',
//! text_column := 'content'
//! );
//!
//! -- Perform hybrid search
//! SELECT * FROM ruvector_hybrid_search(
//! 'documents',
//! query_text := 'database query optimization',
//! query_vector := $embedding,
//! k := 10,
//! fusion := 'rrf'
//! );
//! ```
pub mod bm25;
pub mod executor;
pub mod fusion;
pub mod registry;
// Re-exports
pub use bm25::{tokenize_query, BM25Config, BM25Scorer, CorpusStats, Document, TermFrequencies};
pub use executor::{
choose_strategy, BranchResults, ExecutionStats, HybridExecutor, HybridQuery, HybridResult,
HybridStrategy,
};
pub use fusion::{
fuse_results, learned_fusion, linear_fusion, rrf_fusion, DocId, FusedResult, FusionConfig,
FusionMethod, FusionModel, DEFAULT_ALPHA, DEFAULT_RRF_K,
};
pub use registry::{
get_registry, HybridCollectionConfig, HybridConfigUpdate, HybridRegistry, RegistryError,
HYBRID_REGISTRY,
};
use pgrx::prelude::*;
// ============================================================================
// SQL Functions
// ============================================================================
/// Register a collection for hybrid search
///
/// Creates the necessary metadata and computes initial corpus statistics.
///
/// # Arguments
/// * `collection` - Table name (optionally schema-qualified)
/// * `vector_column` - Name of the vector column
/// * `fts_column` - Name of the tsvector column
/// * `text_column` - Name of the original text column (for BM25 stats)
///
/// # Returns
/// JSON object with registration details
#[pg_extern]
fn ruvector_register_hybrid(
    collection: &str,
    vector_column: &str,
    fts_column: &str,
    text_column: &str,
) -> pgrx::JsonB {
    // Parse collection name into (schema, table)
    let (schema, table) = parse_collection_name(collection);
    // Derive a collection ID from the name using FNV-1a.
    // The previous additive byte fold was order-insensitive ("ab" == "ba"),
    // so distinct collection names could trivially collide; FNV-1a is
    // position-sensitive and spreads IDs across the full range.
    // In production, this would query the ruvector.collections table instead.
    let collection_id = collection
        .bytes()
        .fold(0x811c_9dc5_u32, |hash, byte| {
            (hash ^ byte as u32).wrapping_mul(0x0100_0193)
        }) as i32;
    // Reject double registration up front for a friendlier error payload.
    let registry = get_registry();
    if registry.is_registered(collection_id) {
        return pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": format!("Collection '{}' is already registered for hybrid search", collection),
            "collection_id": collection_id
        }));
    }
    // Build a default configuration carrying the provided column mappings.
    let mut config = HybridCollectionConfig::new(
        collection_id,
        table.to_string(),
        vector_column.to_string(),
        fts_column.to_string(),
        text_column.to_string(),
    );
    config.schema_name = schema.to_string();
    // Register; the registry re-checks for collisions under its own lock.
    match registry.register(config) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection_id": collection_id,
            "collection": collection,
            "vector_column": vector_column,
            "fts_column": fts_column,
            "text_column": text_column,
            "message": "Collection registered for hybrid search. Run ruvector_hybrid_update_stats() to compute corpus statistics."
        })),
        Err(e) => pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        })),
    }
}
/// Update BM25 corpus statistics for a hybrid collection
///
/// Computes average document length, document count, and term frequencies.
/// Should be run periodically or after bulk inserts.
#[pg_extern]
fn ruvector_hybrid_update_stats(collection: &str) -> pgrx::JsonB {
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);
    let registry = get_registry();
    // Unregistered collections are reported as a soft JSON error rather
    // than raising a Postgres ERROR.
    let config = if let Some(c) = registry.get_by_name(&qualified_name) {
        c
    } else {
        return pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": format!("Collection '{}' is not registered for hybrid search", collection)
        }));
    };
    // In the actual extension, we would run SQL to compute stats:
    //   SELECT AVG(LENGTH(text_column)), COUNT(*) FROM table
    // Here we only refresh the timestamp on the existing numbers.
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs() as i64;
    let refreshed = CorpusStats {
        last_update: now,
        ..config.corpus_stats.clone()
    };
    match registry.update_stats(config.collection_id, refreshed) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection": collection,
            "message": "Stats update initiated. In production, this would compute actual corpus statistics.",
            "note": "Use Spi::run to execute SQL for actual stats computation"
        })),
        Err(e) => pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        })),
    }
}
/// Configure hybrid search settings for a collection
///
/// # Arguments
/// * `collection` - Collection name
/// * `config` - JSON configuration object
///
/// # Example Configuration
/// ```json
/// {
///   "default_fusion": "rrf",
///   "default_alpha": 0.5,
///   "rrf_k": 60,
///   "prefetch_k": 100,
///   "bm25_k1": 1.2,
///   "bm25_b": 0.75,
///   "stats_refresh_interval": "1 hour",
///   "parallel_enabled": true
/// }
/// ```
#[pg_extern]
fn ruvector_hybrid_configure(collection: &str, config: pgrx::JsonB) -> pgrx::JsonB {
    // Uniform { success: false, error } failure payload.
    fn failure(message: String) -> pgrx::JsonB {
        pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": message
        }))
    }
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);
    let registry = get_registry();
    let mut current = match registry.get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return failure(format!(
                "Collection '{}' is not registered for hybrid search",
                collection
            ))
        }
    };
    // Deserialize the partial update; malformed shapes are rejected here.
    let update: HybridConfigUpdate = match serde_json::from_value(config.0.clone()) {
        Ok(u) => u,
        Err(e) => return failure(format!("Invalid configuration: {}", e)),
    };
    // Validate and apply each provided field onto the working copy.
    if let Err(e) = update.apply(&mut current) {
        return failure(e.to_string());
    }
    // Persist the merged configuration back to the registry.
    match registry.update(current.clone()) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection": collection,
            "config": {
                "fusion_method": format!("{:?}", current.fusion_config.method),
                "alpha": current.fusion_config.alpha,
                "rrf_k": current.fusion_config.rrf_k,
                "prefetch_k": current.prefetch_k,
                "bm25_k1": current.bm25_config.k1,
                "bm25_b": current.bm25_config.b,
                "stats_refresh_interval": current.stats_refresh_interval,
                "parallel_enabled": current.parallel_enabled
            }
        })),
        Err(e) => failure(e.to_string()),
    }
}
/// Perform hybrid search combining vector and keyword search
///
/// # Arguments
/// * `collection` - Table name
/// * `query_text` - Text query for keyword search
/// * `query_vector` - Vector for semantic search
/// * `k` - Number of results to return
/// * `fusion` - Fusion method ("rrf", "linear", "learned"); NULL keeps the
///   collection default, unrecognized strings are silently ignored
/// * `alpha` - Alpha for linear fusion (0-1, higher favors vector)
///
/// # Returns
/// Table of results with id, content, vector_score, keyword_score, hybrid_score
///
/// NOTE: the branch searches are mocked — real SPI-backed execution is
/// still to be wired in (see the "note" field of the returned JSON).
#[pg_extern]
fn ruvector_hybrid_search(
    collection: &str,
    query_text: &str,
    query_vector: Vec<f32>,
    k: i32,
    fusion: default!(Option<&str>, "NULL"),
    alpha: default!(Option<f32>, "NULL"),
) -> pgrx::JsonB {
    // Clamp k to at least 1 so downstream arithmetic stays valid.
    let k = k.max(1) as usize;
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);
    let registry = get_registry();
    let config = match registry.get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return pgrx::JsonB(serde_json::json!({
                "success": false,
                "error": format!("Collection '{}' is not registered for hybrid search. Run ruvector_register_hybrid first.", collection)
            }));
        }
    };
    // Start from the collection's stored fusion settings, then apply
    // per-query overrides where provided.
    let mut fusion_config = config.fusion_config.clone();
    if let Some(method) = fusion {
        // Unparseable method strings fall through to the stored default.
        if let Ok(m) = method.parse::<FusionMethod>() {
            fusion_config.method = m;
        }
    }
    if let Some(a) = alpha {
        fusion_config.alpha = a.clamp(0.0, 1.0);
    }
    // Build query; each branch prefetches at least 2*k candidates so the
    // fusion stage has enough overlap to work with.
    let query = HybridQuery {
        text: query_text.to_string(),
        embedding: query_vector,
        k,
        prefetch_k: config.prefetch_k.max(k * 2),
        fusion_config,
        filter: None,
    };
    // Create executor
    let executor = HybridExecutor::new(config.corpus_stats.clone());
    // In the actual extension, these would execute real searches via SPI
    // For now, return a demonstration response.
    // Mock vector branch: ids 1..=min(k,10) with increasing distance.
    let mock_vector_results: Vec<(DocId, f32)> = (1..=k.min(10) as i64)
        .map(|i| (i, 0.1 * i as f32))
        .collect();
    // Mock keyword branch: descending ids with decreasing BM25-like scores,
    // deliberately only partially overlapping the vector branch.
    let mock_keyword_results: Vec<(DocId, f32)> = (1..=k.min(10) as i64)
        .map(|i| ((k as i64 + 1 - i), 10.0 / i as f32))
        .collect();
    // Execute fusion; the closures stand in for the real branch searches
    // (each receives the query and a per-branch candidate count).
    let (results, stats) = executor.execute(
        &query,
        |_, k| BranchResults {
            results: mock_vector_results.clone().into_iter().take(k).collect(),
            latency_ms: 1.0,
            candidates_evaluated: 100,
        },
        |_, k| BranchResults {
            results: mock_keyword_results.clone().into_iter().take(k).collect(),
            latency_ms: 0.5,
            candidates_evaluated: 50,
        },
    );
    // Format fused results with 1-based ranks for the JSON payload.
    let result_json: Vec<serde_json::Value> = results
        .iter()
        .enumerate()
        .map(|(i, r)| {
            serde_json::json!({
                "rank": i + 1,
                "id": r.id,
                "hybrid_score": r.hybrid_score,
                "vector_score": r.vector_score,
                "keyword_score": r.keyword_score,
                "vector_rank": r.vector_rank,
                "keyword_rank": r.keyword_rank
            })
        })
        .collect();
    pgrx::JsonB(serde_json::json!({
        "success": true,
        "collection": collection,
        "query": {
            "text": query_text,
            "vector_dims": query.embedding.len(),
            "k": k,
            "fusion": format!("{:?}", query.fusion_config.method),
            "alpha": query.fusion_config.alpha
        },
        "results": result_json,
        "stats": {
            "total_latency_ms": stats.total_latency_ms,
            "vector_latency_ms": stats.vector_latency_ms,
            "keyword_latency_ms": stats.keyword_latency_ms,
            "fusion_latency_ms": stats.fusion_latency_ms,
            "result_count": stats.result_count
        },
        "note": "This is a demonstration. In production, actual vector/keyword searches would be executed via SPI."
    }))
}
/// Get hybrid search statistics for a collection
///
/// Returns corpus stats, BM25/fusion parameters, runtime settings, and
/// column metadata as JSON, or an `error` object if the collection is not
/// registered.
#[pg_extern]
fn ruvector_hybrid_stats(collection: &str) -> pgrx::JsonB {
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);
    let config = match get_registry().get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return pgrx::JsonB(serde_json::json!({
                "error": format!("Collection '{}' is not registered for hybrid search", collection)
            }))
        }
    };
    // Assemble the report section by section for readability.
    let corpus_stats = serde_json::json!({
        "avg_doc_length": config.corpus_stats.avg_doc_length,
        "doc_count": config.corpus_stats.doc_count,
        "total_terms": config.corpus_stats.total_terms,
        "last_update": config.corpus_stats.last_update
    });
    let bm25_config = serde_json::json!({
        "k1": config.bm25_config.k1,
        "b": config.bm25_config.b
    });
    let fusion_config = serde_json::json!({
        "method": format!("{:?}", config.fusion_config.method),
        "alpha": config.fusion_config.alpha,
        "rrf_k": config.fusion_config.rrf_k
    });
    let settings = serde_json::json!({
        "prefetch_k": config.prefetch_k,
        "parallel_enabled": config.parallel_enabled,
        "stats_refresh_interval": config.stats_refresh_interval
    });
    let metadata = serde_json::json!({
        "vector_column": config.vector_column,
        "fts_column": config.fts_column,
        "text_column": config.text_column,
        "created_at": config.created_at,
        "updated_at": config.updated_at
    });
    pgrx::JsonB(serde_json::json!({
        "collection": collection,
        "corpus_stats": corpus_stats,
        "bm25_config": bm25_config,
        "fusion_config": fusion_config,
        "settings": settings,
        "metadata": metadata
    }))
}
/// Compute hybrid score from vector distance and keyword score
///
/// Utility function for manual hybrid scoring in queries. Assumes cosine
/// distance in `[0, 2]` and an already-normalized keyword score.
#[pg_extern(immutable, parallel_safe)]
fn ruvector_hybrid_score(
    vector_distance: f32,
    keyword_score: f32,
    alpha: default!(Option<f32>, "0.5"),
) -> f32 {
    // NULL alpha falls back to an even blend; always kept inside [0, 1].
    let weight = alpha.unwrap_or(0.5).clamp(0.0, 1.0);
    // Map cosine distance in [0, 2] onto similarity in [0, 1].
    let vector_similarity = 1.0 - (vector_distance * 0.5).clamp(0.0, 1.0);
    // Linear blend: weight on the vector branch, remainder on keyword.
    weight * vector_similarity + (1.0 - weight) * keyword_score
}
/// List all collections registered for hybrid search
///
/// Returns a JSON object with a `count` and one summary entry per
/// registered collection.
#[pg_extern]
fn ruvector_hybrid_list() -> pgrx::JsonB {
    let mut collections = Vec::new();
    for c in get_registry().list() {
        collections.push(serde_json::json!({
            "collection_id": c.collection_id,
            "name": c.qualified_name(),
            "vector_column": c.vector_column,
            "fts_column": c.fts_column,
            "fusion_method": format!("{:?}", c.fusion_config.method),
            "doc_count": c.corpus_stats.doc_count,
            "needs_refresh": c.needs_stats_refresh()
        }));
    }
    pgrx::JsonB(serde_json::json!({
        "count": collections.len(),
        "collections": collections
    }))
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Split a possibly schema-qualified name into `(schema, table)`.
///
/// Splits on the first `.`; names without a separator default to the
/// `public` schema.
fn parse_collection_name(name: &str) -> (&str, &str) {
    name.split_once('.').unwrap_or(("public", name))
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    // Unqualified names default to "public"; qualified names split on '.'.
    #[test]
    fn test_parse_collection_name() {
        let (schema, table) = parse_collection_name("documents");
        assert_eq!(schema, "public");
        assert_eq!(table, "documents");
        let (schema, table) = parse_collection_name("myschema.mytable");
        assert_eq!(schema, "myschema");
        assert_eq!(table, "mytable");
    }
    // Compile-time smoke test: the re-exported types can be constructed.
    #[test]
    fn test_module_exports() {
        // Verify all expected types are accessible
        let _ = BM25Config::default();
        let _ = FusionConfig::default();
        let _ = CorpusStats::default();
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        let _ = BM25Scorer::new(stats.clone());
        let _ = HybridExecutor::new(stats);
    }
    // End-to-end register -> get -> list against a fresh (non-global) registry.
    #[test]
    fn test_registry_workflow() {
        let registry = HybridRegistry::new();
        // Register
        let config = HybridCollectionConfig::new(
            1,
            "test_table".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config).unwrap();
        // Get
        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "test_table");
        // List
        let list = registry.list();
        assert_eq!(list.len(), 1);
    }
}

View File

@@ -0,0 +1,651 @@
//! Hybrid Collections Registry
//!
//! Tracks collections with hybrid search enabled and stores:
//! - BM25 corpus statistics
//! - Per-collection fusion settings
//! - Column mappings for vector and FTS
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use super::bm25::{BM25Config, CorpusStats};
use super::fusion::FusionConfig;
#[cfg(test)]
use super::fusion::FusionMethod;
/// Hybrid collection configuration
///
/// One entry per hybrid-enabled table, holding the column mappings plus
/// the BM25/fusion tunables and cached corpus statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridCollectionConfig {
    /// Collection ID (from ruvector.collections table)
    pub collection_id: i32,
    /// Table name
    pub table_name: String,
    /// Schema name (default: public)
    pub schema_name: String,
    /// Vector column name
    pub vector_column: String,
    /// FTS tsvector column name
    pub fts_column: String,
    /// Original text column name (for BM25 stats)
    pub text_column: String,
    /// Primary key column name
    pub pk_column: String,
    /// BM25 configuration
    pub bm25_config: BM25Config,
    /// Fusion configuration
    pub fusion_config: FusionConfig,
    /// Corpus statistics
    pub corpus_stats: CorpusStats,
    /// Prefetch size for each branch (candidates fetched per branch before fusion)
    pub prefetch_k: usize,
    /// Stats refresh interval in seconds
    pub stats_refresh_interval: i64,
    /// Enable parallel branch execution
    pub parallel_enabled: bool,
    /// Created timestamp (Unix epoch seconds)
    pub created_at: i64,
    /// Last modified timestamp (Unix epoch seconds)
    pub updated_at: i64,
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Panics only if the system clock reports a time before 1970, which would
/// be a platform misconfiguration rather than a recoverable error.
fn unix_now_secs() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock before Unix epoch")
        .as_secs() as i64
}

impl HybridCollectionConfig {
    /// Create a new hybrid collection configuration
    ///
    /// All tunables start at their defaults (`public` schema, `id` primary
    /// key, default BM25/fusion parameters, 100-candidate prefetch, hourly
    /// stats refresh) and both timestamps are set to now.
    pub fn new(
        collection_id: i32,
        table_name: String,
        vector_column: String,
        fts_column: String,
        text_column: String,
    ) -> Self {
        let now = unix_now_secs();
        Self {
            collection_id,
            table_name,
            schema_name: "public".to_string(),
            vector_column,
            fts_column,
            text_column,
            pk_column: "id".to_string(),
            bm25_config: BM25Config::default(),
            fusion_config: FusionConfig::default(),
            corpus_stats: CorpusStats::default(),
            prefetch_k: 100,
            stats_refresh_interval: 3600, // 1 hour
            parallel_enabled: true,
            created_at: now,
            updated_at: now,
        }
    }
    /// Get fully qualified table name (`schema.table`)
    pub fn qualified_name(&self) -> String {
        format!("{}.{}", self.schema_name, self.table_name)
    }
    /// Check if stats need refresh
    ///
    /// True when more than `stats_refresh_interval` seconds have elapsed
    /// since `corpus_stats.last_update`.
    pub fn needs_stats_refresh(&self) -> bool {
        unix_now_secs() - self.corpus_stats.last_update > self.stats_refresh_interval
    }
    /// Update corpus statistics
    ///
    /// Replaces the cached stats wholesale and bumps `updated_at`.
    pub fn update_stats(&mut self, stats: CorpusStats) {
        self.corpus_stats = stats;
        self.updated_at = unix_now_secs();
    }
}
/// Registry entry for a hybrid collection
///
/// Pairs the persisted configuration with in-memory caches that are
/// invalidated whenever the corpus statistics change.
#[derive(Debug)]
struct RegistryEntry {
    /// Configuration
    config: HybridCollectionConfig,
    /// Cached IDF values (term -> idf), derived lazily from `df_cache`
    idf_cache: HashMap<String, f32>,
    /// Document frequency cache (term -> doc count)
    df_cache: HashMap<String, u64>,
}
impl RegistryEntry {
    /// Wrap a configuration with empty caches.
    fn new(config: HybridCollectionConfig) -> Self {
        Self {
            config,
            idf_cache: HashMap::new(),
            df_cache: HashMap::new(),
        }
    }
}
/// Hybrid Collections Registry
///
/// Global registry for hybrid-enabled collections.
/// In the PostgreSQL extension, this is backed by the ruvector.hybrid_collections table.
///
/// Two indexes are maintained: the authoritative by-ID map and a
/// name -> ID lookup; both are guarded by their own `RwLock`.
pub struct HybridRegistry {
    /// Collections by ID (authoritative store)
    collections_by_id: RwLock<HashMap<i32, RegistryEntry>>,
    /// Collections by name (schema.table -> id), kept in sync with the ID map
    collections_by_name: RwLock<HashMap<String, i32>>,
}
impl HybridRegistry {
    /// Create a new, empty registry.
    pub fn new() -> Self {
        Self {
            collections_by_id: RwLock::new(HashMap::new()),
            collections_by_name: RwLock::new(HashMap::new()),
        }
    }
    /// Register a collection for hybrid search.
    ///
    /// # Errors
    /// Returns [`RegistryError::AlreadyRegistered`] when either the
    /// qualified name or the collection ID is already taken.
    pub fn register(&self, config: HybridCollectionConfig) -> Result<(), RegistryError> {
        let qualified_name = config.qualified_name();
        let collection_id = config.collection_id;
        // Hold both write locks across the check-then-insert so two
        // concurrent registrations cannot both pass the duplicate check
        // (the previous read-check/write-insert split was a TOCTOU race).
        // Lock order: by_id before by_name, consistently everywhere both
        // locks are held, to rule out deadlock.
        let mut by_id = self.collections_by_id.write();
        let mut by_name = self.collections_by_name.write();
        if by_name.contains_key(&qualified_name) {
            return Err(RegistryError::AlreadyRegistered(qualified_name));
        }
        // Also reject a duplicate ID: silently overwriting the by-id entry
        // would leave a dangling name mapping for the old collection.
        if by_id.contains_key(&collection_id) {
            return Err(RegistryError::AlreadyRegistered(collection_id.to_string()));
        }
        by_id.insert(collection_id, RegistryEntry::new(config));
        by_name.insert(qualified_name, collection_id);
        Ok(())
    }
    /// Unregister a collection, removing it from both indexes.
    ///
    /// # Errors
    /// Returns [`RegistryError::NotFound`] when the ID is unknown.
    pub fn unregister(&self, collection_id: i32) -> Result<(), RegistryError> {
        // Lock order: by_id before by_name (see register).
        let mut by_id = self.collections_by_id.write();
        let mut by_name = self.collections_by_name.write();
        match by_id.remove(&collection_id) {
            Some(entry) => {
                by_name.remove(&entry.config.qualified_name());
                Ok(())
            }
            None => Err(RegistryError::NotFound(collection_id.to_string())),
        }
    }
    /// Get a clone of a collection's configuration by ID.
    pub fn get(&self, collection_id: i32) -> Option<HybridCollectionConfig> {
        self.collections_by_id
            .read()
            .get(&collection_id)
            .map(|e| e.config.clone())
    }
    /// Get a clone of a collection's configuration by qualified name.
    pub fn get_by_name(&self, name: &str) -> Option<HybridCollectionConfig> {
        let collection_id = self.collections_by_name.read().get(name).copied()?;
        self.get(collection_id)
    }
    /// Replace a collection's configuration.
    ///
    /// If the update renames the collection (schema or table changed), the
    /// name index is kept in sync; the previous implementation left a
    /// stale name mapping behind.
    ///
    /// # Errors
    /// Returns [`RegistryError::NotFound`] for an unknown ID, or
    /// [`RegistryError::AlreadyRegistered`] when a rename would collide
    /// with another collection's name.
    pub fn update(&self, config: HybridCollectionConfig) -> Result<(), RegistryError> {
        let collection_id = config.collection_id;
        // Lock order: by_id before by_name (see register).
        let mut by_id = self.collections_by_id.write();
        let mut by_name = self.collections_by_name.write();
        let entry = match by_id.get_mut(&collection_id) {
            Some(e) => e,
            None => return Err(RegistryError::NotFound(collection_id.to_string())),
        };
        let old_name = entry.config.qualified_name();
        let new_name = config.qualified_name();
        // A rename must not steal another collection's name.
        if new_name != old_name && by_name.contains_key(&new_name) {
            return Err(RegistryError::AlreadyRegistered(new_name));
        }
        entry.config = config;
        if new_name != old_name {
            by_name.remove(&old_name);
            by_name.insert(new_name, collection_id);
        }
        Ok(())
    }
    /// Update corpus statistics for a collection.
    ///
    /// Clears the IDF/DF caches, since cached values were derived from the
    /// previous statistics.
    pub fn update_stats(
        &self,
        collection_id: i32,
        stats: CorpusStats,
    ) -> Result<(), RegistryError> {
        let mut by_id = self.collections_by_id.write();
        if let Some(entry) = by_id.get_mut(&collection_id) {
            entry.config.update_stats(stats);
            // Clear caches when stats change
            entry.idf_cache.clear();
            entry.df_cache.clear();
            Ok(())
        } else {
            Err(RegistryError::NotFound(collection_id.to_string()))
        }
    }
    /// Set document frequency for a term in a collection.
    ///
    /// Invalidates that term's cached IDF so it is recomputed lazily.
    pub fn set_doc_freq(
        &self,
        collection_id: i32,
        term: &str,
        doc_freq: u64,
    ) -> Result<(), RegistryError> {
        let mut by_id = self.collections_by_id.write();
        if let Some(entry) = by_id.get_mut(&collection_id) {
            entry.df_cache.insert(term.to_string(), doc_freq);
            // Invalidate IDF cache for this term
            entry.idf_cache.remove(term);
            Ok(())
        } else {
            Err(RegistryError::NotFound(collection_id.to_string()))
        }
    }
    /// Get IDF for a term, computing and caching it if missing.
    ///
    /// Returns `None` only when the collection ID is unknown.
    pub fn get_idf(&self, collection_id: i32, term: &str) -> Option<f32> {
        // Fast path: cache hit under a shared read lock, so concurrent
        // readers don't serialize on the write lock (the previous version
        // took the write lock even for hits).
        {
            let by_id = self.collections_by_id.read();
            let entry = by_id.get(&collection_id)?;
            if let Some(&idf) = entry.idf_cache.get(term) {
                return Some(idf);
            }
        }
        // Slow path: compute under the write lock. Re-check the cache
        // because another thread may have filled it between locks.
        let mut by_id = self.collections_by_id.write();
        let entry = by_id.get_mut(&collection_id)?;
        if let Some(&idf) = entry.idf_cache.get(term) {
            return Some(idf);
        }
        // BM25-style IDF with +1 inside the log so it stays positive;
        // unseen terms (df == 0) fall back to ln(N + 0.5).
        let df = entry.df_cache.get(term).copied().unwrap_or(0);
        let n = entry.config.corpus_stats.doc_count as f32;
        let df_f = df as f32;
        let idf = if df == 0 {
            (n + 0.5).ln()
        } else {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        };
        entry.idf_cache.insert(term.to_string(), idf);
        Some(idf)
    }
    /// List clones of all registered collection configurations.
    pub fn list(&self) -> Vec<HybridCollectionConfig> {
        self.collections_by_id
            .read()
            .values()
            .map(|e| e.config.clone())
            .collect()
    }
    /// Check if a collection ID is registered.
    pub fn is_registered(&self, collection_id: i32) -> bool {
        self.collections_by_id.read().contains_key(&collection_id)
    }
    /// IDs of collections whose stats are older than their refresh interval.
    pub fn collections_needing_refresh(&self) -> Vec<i32> {
        self.collections_by_id
            .read()
            .iter()
            .filter(|(_, e)| e.config.needs_stats_refresh())
            .map(|(id, _)| *id)
            .collect()
    }
    /// Clear all per-collection IDF/DF caches.
    pub fn clear_caches(&self) {
        let mut by_id = self.collections_by_id.write();
        for entry in by_id.values_mut() {
            entry.idf_cache.clear();
            entry.df_cache.clear();
        }
    }
}
impl Default for HybridRegistry {
    /// An empty registry; equivalent to [`HybridRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Registry error types
///
/// Payload strings carry the offending collection name/ID or a
/// human-readable explanation, surfaced verbatim through `Display`.
#[derive(Debug, Clone)]
pub enum RegistryError {
    /// Collection already registered (payload: qualified name or ID)
    AlreadyRegistered(String),
    /// Collection not found (payload: qualified name or ID)
    NotFound(String),
    /// Invalid configuration (payload: reason)
    InvalidConfig(String),
    /// Database error (payload: underlying message)
    DatabaseError(String),
}
impl std::fmt::Display for RegistryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RegistryError::AlreadyRegistered(name) => {
write!(
f,
"Collection '{}' is already registered for hybrid search",
name
)
}
RegistryError::NotFound(name) => {
write!(f, "Hybrid collection '{}' not found", name)
}
RegistryError::InvalidConfig(msg) => {
write!(f, "Invalid hybrid configuration: {}", msg)
}
RegistryError::DatabaseError(msg) => {
write!(f, "Database error: {}", msg)
}
}
}
}
impl std::error::Error for RegistryError {}
// Global registry instance
lazy_static::lazy_static! {
/// Global hybrid collections registry
pub static ref HYBRID_REGISTRY: Arc<HybridRegistry> = Arc::new(HybridRegistry::new());
}
/// Get the global hybrid registry
pub fn get_registry() -> Arc<HybridRegistry> {
HYBRID_REGISTRY.clone()
}
/// Configuration update from JSONB
///
/// Partial-update payload for `ruvector_hybrid_configure`: every field is
/// optional and only `Some` values are applied (see [`HybridConfigUpdate::apply`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfigUpdate {
    /// New fusion method (parsed via `FusionMethod::from_str`)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_fusion: Option<String>,
    /// New alpha value (must be within [0, 1])
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_alpha: Option<f32>,
    /// New RRF k value (must be positive)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rrf_k: Option<usize>,
    /// New prefetch k value (must be positive)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prefetch_k: Option<usize>,
    /// BM25 k1 parameter (clamped to >= 0)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bm25_k1: Option<f32>,
    /// BM25 b parameter (clamped to [0, 1])
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bm25_b: Option<f32>,
    /// Stats refresh interval (e.g., "1 hour", "30 minutes")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stats_refresh_interval: Option<String>,
    /// Enable parallel execution
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_enabled: Option<bool>,
}
impl HybridConfigUpdate {
    /// Apply updates to a configuration
    ///
    /// Fields that are `Some` are validated and written onto `config`;
    /// `None` fields leave the existing value untouched. `updated_at` is
    /// refreshed on success.
    ///
    /// # Errors
    /// Returns [`RegistryError::InvalidConfig`] for an unparseable fusion
    /// method or refresh interval, an `alpha` outside `[0, 1]`, or a zero
    /// `rrf_k`/`prefetch_k`. Note that `config` may already be partially
    /// modified when an error is returned (no rollback).
    pub fn apply(&self, config: &mut HybridCollectionConfig) -> Result<(), RegistryError> {
        if let Some(ref fusion) = self.default_fusion {
            // FusionMethod::from_str reports errors as plain strings, so the
            // variant constructor can be passed directly to map_err.
            config.fusion_config.method =
                fusion.parse().map_err(RegistryError::InvalidConfig)?;
        }
        if let Some(alpha) = self.default_alpha {
            if !(0.0..=1.0).contains(&alpha) {
                return Err(RegistryError::InvalidConfig(
                    "alpha must be between 0 and 1".into(),
                ));
            }
            config.fusion_config.alpha = alpha;
        }
        if let Some(rrf_k) = self.rrf_k {
            if rrf_k == 0 {
                return Err(RegistryError::InvalidConfig(
                    "rrf_k must be positive".into(),
                ));
            }
            config.fusion_config.rrf_k = rrf_k;
        }
        if let Some(prefetch_k) = self.prefetch_k {
            if prefetch_k == 0 {
                return Err(RegistryError::InvalidConfig(
                    "prefetch_k must be positive".into(),
                ));
            }
            config.prefetch_k = prefetch_k;
        }
        // BM25 parameters are clamped (matching BM25Config::new semantics)
        // rather than rejected, staying lenient with out-of-range input.
        if let Some(k1) = self.bm25_k1 {
            config.bm25_config.k1 = k1.max(0.0);
        }
        if let Some(b) = self.bm25_b {
            config.bm25_config.b = b.clamp(0.0, 1.0);
        }
        if let Some(ref interval) = self.stats_refresh_interval {
            config.stats_refresh_interval = parse_interval(interval)?;
        }
        if let Some(parallel) = self.parallel_enabled {
            config.parallel_enabled = parallel;
        }
        config.updated_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock before Unix epoch")
            .as_secs() as i64;
        Ok(())
    }
}
/// Parse interval string to seconds
///
/// Accepts (case-insensitively, with surrounding whitespace ignored):
/// - `"N hour"` / `"N hours"`
/// - `"N minute"` / `"N minutes"`
/// - `"N second"` / `"N seconds"`
/// - a bare integer, interpreted as seconds
///
/// # Errors
///
/// Returns `RegistryError::InvalidConfig` when the number cannot be parsed,
/// the interval is negative, or the multiplication overflows `i64`.
fn parse_interval(s: &str) -> Result<i64, RegistryError> {
    let s = s.trim().to_lowercase();
    let invalid = || RegistryError::InvalidConfig(format!("Invalid interval: {}", s));
    // Strip a recognized unit suffix (plural first so "2 hours" matches
    // cleanly) and remember its multiplier; a bare number means seconds.
    let (number, multiplier) = if let Some(h) =
        s.strip_suffix(" hours").or_else(|| s.strip_suffix(" hour"))
    {
        (h, 3600)
    } else if let Some(m) = s
        .strip_suffix(" minutes")
        .or_else(|| s.strip_suffix(" minute"))
    {
        (m, 60)
    } else if let Some(sec) = s
        .strip_suffix(" seconds")
        .or_else(|| s.strip_suffix(" second"))
    {
        (sec, 1)
    } else {
        (s.as_str(), 1)
    };
    let n = number.trim().parse::<i64>().map_err(|_| invalid())?;
    // A negative refresh interval is meaningless; reject it explicitly
    // instead of letting it propagate into the refresh scheduling logic.
    if n < 0 {
        return Err(RegistryError::InvalidConfig(format!(
            "Interval must be non-negative: {}",
            s
        )));
    }
    // checked_mul guards against overflow on absurdly large inputs.
    n.checked_mul(multiplier).ok_or_else(invalid)
}
// Unit tests for the registry lifecycle, config updates, interval parsing,
// and stats-refresh bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;
    // A registered collection is retrievable by its numeric ID with its
    // column configuration intact.
    #[test]
    fn test_registry_register_get() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            1,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config.clone()).unwrap();
        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "documents");
        assert_eq!(retrieved.vector_column, "embedding");
    }
    // Registering the same collection ID twice must fail with AlreadyRegistered.
    #[test]
    fn test_registry_duplicate() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            1,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config.clone()).unwrap();
        let result = registry.register(config);
        assert!(matches!(result, Err(RegistryError::AlreadyRegistered(_))));
    }
    // Name lookup uses a schema-qualified key — presumably the registry
    // defaults to the `public` schema; confirm in HybridCollectionConfig::new.
    #[test]
    fn test_registry_get_by_name() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            42,
            "my_table".to_string(),
            "vec".to_string(),
            "tsv".to_string(),
            "text".to_string(),
        );
        registry.register(config).unwrap();
        let retrieved = registry.get_by_name("public.my_table").unwrap();
        assert_eq!(retrieved.collection_id, 42);
    }
    // update_stats replaces the stored corpus statistics for a collection.
    #[test]
    fn test_registry_update_stats() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );
        registry.register(config).unwrap();
        let new_stats = CorpusStats {
            avg_doc_length: 150.0,
            doc_count: 5000,
            total_terms: 500000,
            last_update: 12345,
        };
        registry.update_stats(1, new_stats).unwrap();
        let updated = registry.get(1).unwrap();
        assert!((updated.corpus_stats.avg_doc_length - 150.0).abs() < 0.01);
        assert_eq!(updated.corpus_stats.doc_count, 5000);
    }
    // A fully-populated HybridConfigUpdate applies every field in one pass.
    #[test]
    fn test_config_update() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );
        let update = HybridConfigUpdate {
            default_fusion: Some("linear".to_string()),
            default_alpha: Some(0.7),
            rrf_k: Some(40),
            prefetch_k: Some(200),
            bm25_k1: Some(1.5),
            bm25_b: Some(0.8),
            stats_refresh_interval: Some("2 hours".to_string()),
            parallel_enabled: Some(false),
        };
        update.apply(&mut config).unwrap();
        assert_eq!(config.fusion_config.method, FusionMethod::Linear);
        assert!((config.fusion_config.alpha - 0.7).abs() < 0.01);
        assert_eq!(config.fusion_config.rrf_k, 40);
        assert_eq!(config.prefetch_k, 200);
        assert!((config.bm25_config.k1 - 1.5).abs() < 0.01);
        assert!((config.bm25_config.b - 0.8).abs() < 0.01);
        // "2 hours" -> 7200 seconds via parse_interval.
        assert_eq!(config.stats_refresh_interval, 7200);
        assert!(!config.parallel_enabled);
    }
    // parse_interval handles hour/minute/second suffixes and bare seconds.
    #[test]
    fn test_parse_interval() {
        assert_eq!(parse_interval("1 hour").unwrap(), 3600);
        assert_eq!(parse_interval("2 hours").unwrap(), 7200);
        assert_eq!(parse_interval("30 minutes").unwrap(), 1800);
        assert_eq!(parse_interval("60 seconds").unwrap(), 60);
        assert_eq!(parse_interval("120").unwrap(), 120);
    }
    // needs_stats_refresh compares last_update age against the refresh interval.
    #[test]
    fn test_needs_refresh() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );
        // Fresh stats should not need refresh
        config.corpus_stats.last_update = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        config.stats_refresh_interval = 3600;
        assert!(!config.needs_stats_refresh());
        // Old stats should need refresh
        // (back-date by 2 hours, past the 1-hour interval)
        config.corpus_stats.last_update -= 7200;
        assert!(config.needs_stats_refresh());
    }
}

View File

@@ -0,0 +1,806 @@
//! Comprehensive tests for the hybrid search module
//!
//! Tests cover:
//! - BM25 scoring correctness
//! - Fusion algorithm behavior
//! - Executor integration
//! - Registry operations
// BM25 scoring tests: IDF monotonicity, length normalization, TF saturation,
// parameter sensitivity, tokenization, and tsvector parsing.
#[cfg(test)]
mod bm25_tests {
    use crate::hybrid::bm25::*;
    use std::collections::HashMap;
    /// Create a test scorer with known corpus statistics
    /// (10k docs averaging 100 terms each).
    fn test_scorer() -> BM25Scorer {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 10000,
            total_terms: 1000000,
            last_update: 0,
        };
        BM25Scorer::new(stats)
    }
    // IDF must be monotonically increasing as document frequency decreases,
    // and match the standard BM25 IDF formula.
    #[test]
    fn test_bm25_idf_formula() {
        let scorer = test_scorer();
        // Set known document frequencies
        scorer.set_doc_freq("common", 5000); // 50% of docs
        scorer.set_doc_freq("rare", 10); // 0.1% of docs
        scorer.set_doc_freq("unique", 1); // 0.01% of docs
        let idf_common = scorer.idf("common");
        let idf_rare = scorer.idf("rare");
        let idf_unique = scorer.idf("unique");
        // IDF should increase as term rarity increases
        assert!(idf_common < idf_rare, "Common term should have lower IDF");
        assert!(idf_rare < idf_unique, "Rare term should have lower IDF than unique");
        // Verify approximate values using BM25 formula
        // IDF = ln((N - df + 0.5) / (df + 0.5) + 1)
        // For common (df=5000, N=10000): ln((10000-5000+0.5)/(5000+0.5)+1) ~= ln(2) ~= 0.69
        assert!((idf_common - 0.69).abs() < 0.1, "IDF common: {}", idf_common);
    }
    // A document containing the query term must score strictly positive.
    #[test]
    fn test_bm25_score_single_term() {
        let scorer = test_scorer();
        scorer.set_doc_freq("test", 1000); // 10% of docs
        let mut freqs = HashMap::new();
        freqs.insert("test".to_string(), 5); // Term appears 5 times
        let term_freqs = TermFrequencies::new(freqs);
        let doc = Document::new(&term_freqs);
        let score = scorer.score(&doc, &["test".to_string()]);
        assert!(score > 0.0, "Score should be positive");
    }
    // BM25 sums per-term contributions, so matching more query terms
    // must score higher than matching a subset.
    #[test]
    fn test_bm25_score_multiple_terms() {
        let scorer = test_scorer();
        scorer.set_doc_freq("database", 500);
        scorer.set_doc_freq("query", 300);
        scorer.set_doc_freq("optimization", 100);
        let mut freqs = HashMap::new();
        freqs.insert("database".to_string(), 3);
        freqs.insert("query".to_string(), 2);
        freqs.insert("optimization".to_string(), 1);
        let term_freqs = TermFrequencies::new(freqs);
        let doc = Document::new(&term_freqs);
        let query_terms = vec![
            "database".to_string(),
            "query".to_string(),
            "optimization".to_string(),
        ];
        let score = scorer.score(&doc, &query_terms);
        assert!(score > 0.0);
        // Score with subset should be lower
        let partial_score = scorer.score(&doc, &["database".to_string()]);
        assert!(partial_score < score, "Partial match should score lower");
    }
    // The `b` parameter penalizes documents longer than the corpus average:
    // equal TF in a shorter document must yield a higher score.
    #[test]
    fn test_bm25_length_normalization() {
        let scorer = test_scorer();
        scorer.set_doc_freq("keyword", 1000);
        // Create two documents with same TF but different lengths
        // Short doc: 50 terms, avg is 100
        let mut short_freqs = HashMap::new();
        short_freqs.insert("keyword".to_string(), 2);
        for i in 0..48 {
            short_freqs.insert(format!("other{}", i), 1);
        }
        let short_tf = TermFrequencies::new(short_freqs);
        let short_doc = Document::new(&short_tf);
        // Long doc: 200 terms, avg is 100
        let mut long_freqs = HashMap::new();
        long_freqs.insert("keyword".to_string(), 2);
        for i in 0..198 {
            long_freqs.insert(format!("other{}", i), 1);
        }
        let long_tf = TermFrequencies::new(long_freqs);
        let long_doc = Document::new(&long_tf);
        let query = vec!["keyword".to_string()];
        let short_score = scorer.score(&short_doc, &query);
        let long_score = scorer.score(&long_doc, &query);
        // Short doc should score higher due to length normalization
        assert!(short_score > long_score,
            "Short doc ({}) should score higher than long doc ({})",
            short_score, long_score
        );
    }
    // The k1 saturation term bounds TF influence: 100x the term frequency
    // must not produce anywhere near 100x the score.
    #[test]
    fn test_bm25_tf_saturation() {
        let scorer = test_scorer();
        scorer.set_doc_freq("term", 500);
        // Document with low TF
        let mut low_tf_freqs = HashMap::new();
        low_tf_freqs.insert("term".to_string(), 1);
        let low_tf = TermFrequencies::new(low_tf_freqs);
        let low_doc = Document::new(&low_tf);
        // Document with high TF
        let mut high_tf_freqs = HashMap::new();
        high_tf_freqs.insert("term".to_string(), 100);
        let high_tf = TermFrequencies::new(high_tf_freqs);
        let high_doc = Document::new(&high_tf);
        let query = vec!["term".to_string()];
        let low_score = scorer.score(&low_doc, &query);
        let high_score = scorer.score(&high_doc, &query);
        // High TF should score higher, but not 100x higher (saturation)
        assert!(high_score > low_score);
        assert!(high_score < low_score * 10.0,
            "TF saturation should prevent linear scaling: {} vs {}",
            high_score, low_score
        );
    }
    // Different k1 values must produce measurably different scores for the
    // same document/query pair.
    #[test]
    fn test_bm25_config_params() {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        // High k1 = more weight to term frequency
        let high_k1 = BM25Config::new(2.0, 0.75);
        let scorer_high_k1 = BM25Scorer::with_config(stats.clone(), high_k1);
        // Low k1 = less weight to term frequency
        let low_k1 = BM25Config::new(0.5, 0.75);
        let scorer_low_k1 = BM25Scorer::with_config(stats, low_k1);
        scorer_high_k1.set_doc_freq("test", 100);
        scorer_low_k1.set_doc_freq("test", 100);
        let mut freqs = HashMap::new();
        freqs.insert("test".to_string(), 10);
        let tf = TermFrequencies::new(freqs);
        let doc = Document::new(&tf);
        let query = vec!["test".to_string()];
        let score_high = scorer_high_k1.score(&doc, &query);
        let score_low = scorer_low_k1.score(&doc, &query);
        // Different k1 should produce different scores
        assert!((score_high - score_low).abs() > 0.1,
            "k1 should affect scoring: {} vs {}", score_high, score_low
        );
    }
    // Tokenization lowercases and strips punctuation; hyphen handling is
    // implementation-defined, so the second assertion accepts either split.
    #[test]
    fn test_tokenize_query() {
        let tokens = tokenize_query("Hello, World! This is a TEST.");
        assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
        let tokens2 = tokenize_query("database-query optimization");
        assert!(tokens2.contains(&"database-query".to_string()) || tokens2.contains(&"database".to_string()));
    }
    // tsvector parsing: term frequency is the count of listed positions
    // (e.g. 'databas':1,4,7 -> frequency 3).
    #[test]
    fn test_parse_tsvector() {
        let tsvector = "'databas':1,4,7 'queri':2,5 'optim':3";
        let freqs = parse_tsvector(tsvector);
        assert_eq!(freqs.get("databas"), Some(&3));
        assert_eq!(freqs.get("queri"), Some(&2));
        assert_eq!(freqs.get("optim"), Some(&1));
    }
}
// Fusion algorithm tests: RRF, linear interpolation, score preservation,
// method parsing, query classification, and the learned fusion model.
#[cfg(test)]
mod fusion_tests {
    use crate::hybrid::fusion::*;
    // Vector branch fixture: (doc_id, distance) — lower distance = better.
    fn sample_vector_results() -> Vec<(DocId, f32)> {
        // Lower distance = better
        vec![
            (1, 0.1),
            (2, 0.15),
            (3, 0.25),
            (4, 0.4),
            (5, 0.6),
        ]
    }
    // Keyword branch fixture: (doc_id, BM25 score) — higher = better.
    // Docs 1, 2, 3 overlap with the vector fixture; 6 and 7 do not.
    fn sample_keyword_results() -> Vec<(DocId, f32)> {
        // Higher BM25 = better
        vec![
            (3, 9.5),
            (6, 8.0),
            (1, 7.2),
            (7, 5.0),
            (2, 3.5),
        ]
    }
    // Docs appearing in both branches accumulate two reciprocal-rank
    // contributions, so they should surface near the top.
    #[test]
    fn test_rrf_basic() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        let results = rrf_fusion(&vector, &keyword, 60, 10);
        assert!(!results.is_empty());
        // Doc 1 and 3 appear in both, should rank highly
        let top_3_ids: Vec<DocId> = results.iter().take(3).map(|r| r.doc_id).collect();
        assert!(top_3_ids.contains(&1) || top_3_ids.contains(&3));
    }
    // Varying k changes rank weighting but both settings must still fuse.
    #[test]
    fn test_rrf_k_parameter() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        // Low k = top ranks matter more
        let results_low_k = rrf_fusion(&vector, &keyword, 10, 5);
        // High k = ranks matter less
        let results_high_k = rrf_fusion(&vector, &keyword, 100, 5);
        // Both should produce results
        assert!(!results_low_k.is_empty());
        assert!(!results_high_k.is_empty());
        // Order might differ due to k
        let order_low: Vec<DocId> = results_low_k.iter().map(|r| r.doc_id).collect();
        let order_high: Vec<DocId> = results_high_k.iter().map(|r| r.doc_id).collect();
        // At least verify both have same elements (possibly different order)
        for id in &order_low {
            assert!(order_high.contains(id) || order_low.len() > order_high.len());
        }
    }
    // Alpha extremes degenerate to single-branch ranking: alpha=1 follows
    // the vector branch (best doc 1), alpha=0 the keyword branch (best doc 3).
    #[test]
    fn test_linear_fusion_alpha() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        // Alpha = 1.0 means only vector
        let results_vector_only = linear_fusion(&vector, &keyword, 1.0, 5);
        // Alpha = 0.0 means only keyword
        let results_keyword_only = linear_fusion(&vector, &keyword, 0.0, 5);
        // With alpha=1, best vector result (id=1) should be top
        assert_eq!(results_vector_only[0].doc_id, 1);
        // With alpha=0, best keyword result (id=3) should be top
        assert_eq!(results_keyword_only[0].doc_id, 3);
    }
    // A balanced blend must still yield strictly positive hybrid scores.
    #[test]
    fn test_linear_fusion_balanced() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();
        let results = linear_fusion(&vector, &keyword, 0.5, 5);
        // All results should have both scores if they appeared in both
        for r in &results {
            assert!(r.hybrid_score > 0.0);
        }
    }
    // Fusion must retain per-branch scores: Some for the branch(es) a doc
    // appeared in, None for the branch it did not.
    #[test]
    fn test_fusion_preserves_scores() {
        let vector = vec![(1, 0.1), (2, 0.2)];
        let keyword = vec![(1, 5.0), (3, 4.0)];
        let results = rrf_fusion(&vector, &keyword, 60, 10);
        let doc1 = results.iter().find(|r| r.doc_id == 1).unwrap();
        assert!(doc1.vector_score.is_some());
        assert!(doc1.keyword_score.is_some());
        let doc2 = results.iter().find(|r| r.doc_id == 2).unwrap();
        assert!(doc2.vector_score.is_some());
        assert!(doc2.keyword_score.is_none());
        let doc3 = results.iter().find(|r| r.doc_id == 3).unwrap();
        assert!(doc3.vector_score.is_none());
        assert!(doc3.keyword_score.is_some());
    }
    // FromStr accepts the three known methods and rejects unknown names.
    #[test]
    fn test_fusion_method_parse() {
        assert_eq!("rrf".parse::<FusionMethod>().unwrap(), FusionMethod::Rrf);
        assert_eq!("linear".parse::<FusionMethod>().unwrap(), FusionMethod::Linear);
        assert_eq!("learned".parse::<FusionMethod>().unwrap(), FusionMethod::Learned);
        assert!("invalid".parse::<FusionMethod>().is_err());
    }
    // Query classification buckets term sets into nav/info/transactional.
    #[test]
    fn test_query_type_classification() {
        let nav = classify_query_type(&["github".into(), "login".into()]);
        assert_eq!(nav, QueryType::Navigational);
        let info = classify_query_type(&["how".into(), "to".into(), "build".into()]);
        assert_eq!(info, QueryType::Informational);
        let trans = classify_query_type(&["buy".into(), "cheap".into(), "laptop".into()]);
        assert_eq!(trans, QueryType::Transactional);
    }
    // The default learned model should steer alpha below 0.5 for
    // navigational queries and above 0.4 for informational ones.
    #[test]
    fn test_fusion_model() {
        let model = FusionModel::default();
        // Test navigational query (should favor keyword)
        let nav_features = QueryFeatures {
            embedding_norm: 1.0,
            term_count: 2,
            avg_term_idf: 2.0,
            has_exact_match: true,
            query_type: QueryType::Navigational,
        };
        let nav_alpha = model.predict_alpha(&nav_features);
        assert!(nav_alpha < 0.5, "Nav query should favor keyword");
        // Test informational query (should favor vector)
        let info_features = QueryFeatures {
            embedding_norm: 1.2,
            term_count: 5,
            avg_term_idf: 4.5,
            has_exact_match: false,
            query_type: QueryType::Informational,
        };
        let info_alpha = model.predict_alpha(&info_features);
        assert!(info_alpha > 0.4, "Info query should favor vector");
    }
}
// Executor tests: query builder, end-to-end execution with mocked branches,
// fusion-method switching, strategy selection, and rank propagation.
#[cfg(test)]
mod executor_tests {
    use crate::hybrid::executor::*;
    use crate::hybrid::fusion::*;
    use crate::hybrid::bm25::CorpusStats;
    // Fixed corpus statistics shared by every executor test.
    fn mock_corpus_stats() -> CorpusStats {
        CorpusStats {
            avg_doc_length: 150.0,
            doc_count: 5000,
            total_terms: 750000,
            last_update: 0,
        }
    }
    // Deterministic vector branch: docs 1..k with increasing distance
    // (doc 1 is the best vector match).
    fn mock_vector_search(_embedding: &[f32], k: usize) -> BranchResults {
        BranchResults {
            results: (1..=k.min(10) as i64)
                .map(|i| (i, 0.05 * i as f32))
                .collect(),
            latency_ms: 2.5,
            candidates_evaluated: 500,
        }
    }
    // Deterministic keyword branch: ids emitted in reverse (10..1) so the
    // keyword ranking deliberately disagrees with the vector ranking.
    fn mock_keyword_search(_text: &str, k: usize) -> BranchResults {
        BranchResults {
            results: (1..=k.min(10) as i64)
                .map(|i| (10 - i + 1, 12.0 - i as f32))
                .collect(),
            latency_ms: 1.2,
            candidates_evaluated: 200,
        }
    }
    // Builder methods must set each field they advertise.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("test query".into(), vec![0.1; 128], 10)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.7)
            .with_prefetch(200)
            .with_rrf_k(40);
        assert_eq!(query.k, 10);
        assert_eq!(query.prefetch_k, 200);
        assert_eq!(query.fusion_config.method, FusionMethod::Linear);
        assert!((query.fusion_config.alpha - 0.7).abs() < 0.01);
        assert_eq!(query.fusion_config.rrf_k, 40);
    }
    // execute() must honor k, return non-empty fused results, and report
    // positive per-branch and total latencies.
    #[test]
    fn test_executor_execute() {
        let executor = HybridExecutor::new(mock_corpus_stats());
        let query = HybridQuery::new(
            "database optimization".into(),
            vec![0.1; 64],
            5,
        );
        let (results, stats) = executor.execute(&query, mock_vector_search, mock_keyword_search);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // Check stats
        assert!(stats.total_latency_ms > 0.0);
        assert!(stats.vector_latency_ms > 0.0);
        assert!(stats.keyword_latency_ms > 0.0);
        assert!(stats.result_count <= 5);
    }
    // Both supported fusion methods must produce results for the same query.
    #[test]
    fn test_executor_with_different_fusion() {
        let executor = HybridExecutor::new(mock_corpus_stats());
        // RRF
        let query_rrf = HybridQuery::new("test".into(), vec![0.1; 32], 5)
            .with_fusion(FusionMethod::Rrf);
        let (results_rrf, _) = executor.execute(&query_rrf, mock_vector_search, mock_keyword_search);
        // Linear
        let query_linear = HybridQuery::new("test".into(), vec![0.1; 32], 5)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.5);
        let (results_linear, _) = executor.execute(&query_linear, mock_vector_search, mock_keyword_search);
        assert!(!results_rrf.is_empty());
        assert!(!results_linear.is_empty());
    }
    // Strategy choice: args appear to be (selectivity, corpus_size, has_filter)
    // — confirm against choose_strategy's signature in the executor module.
    #[test]
    fn test_strategy_selection() {
        // No filter
        assert_eq!(choose_strategy(None, 10000, false), HybridStrategy::Full);
        // Very selective filter
        assert_eq!(choose_strategy(Some(0.005), 1000000, true), HybridStrategy::PreFilter);
        // Moderate selectivity, large corpus
        assert_eq!(choose_strategy(Some(0.05), 5000000, true), HybridStrategy::PostFilter);
        // Low selectivity
        assert_eq!(choose_strategy(Some(0.5), 10000, true), HybridStrategy::Full);
    }
    // Every fused result must carry a rank from at least one branch.
    #[test]
    fn test_result_has_ranks() {
        let executor = HybridExecutor::new(mock_corpus_stats());
        let query = HybridQuery::new("test".into(), vec![0.1; 16], 10);
        let (results, _) = executor.execute(&query, mock_vector_search, mock_keyword_search);
        // Check that rank information is populated
        for r in &results {
            // At least one rank should be present (doc appeared in at least one branch)
            assert!(r.vector_rank.is_some() || r.keyword_rank.is_some());
        }
    }
}
// Registry tests mirroring (and extending) registry.rs's own unit tests:
// full lifecycle, duplicate prevention, stats updates, config validation,
// IDF caching, and refresh-interval bookkeeping.
#[cfg(test)]
mod registry_tests {
    use crate::hybrid::registry::*;
    use crate::hybrid::bm25::CorpusStats;
    use crate::hybrid::fusion::FusionMethod;
    // Full lifecycle: register -> get by ID -> get by name -> list -> unregister.
    #[test]
    fn test_registry_lifecycle() {
        let registry = HybridRegistry::new();
        // Register
        let config = HybridCollectionConfig::new(
            1,
            "test_collection".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        assert!(registry.register(config).is_ok());
        // Get by ID
        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "test_collection");
        // Get by name
        let by_name = registry.get_by_name("public.test_collection").unwrap();
        assert_eq!(by_name.collection_id, 1);
        // List
        let list = registry.list();
        assert_eq!(list.len(), 1);
        // Unregister
        assert!(registry.unregister(1).is_ok());
        assert!(registry.get(1).is_none());
    }
    // Re-registering the same collection ID must return AlreadyRegistered.
    #[test]
    fn test_registry_duplicate_prevention() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            1,
            "unique_table".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );
        registry.register(config.clone()).unwrap();
        let result = registry.register(config);
        assert!(matches!(result, Err(RegistryError::AlreadyRegistered(_))));
    }
    // update_stats must replace the stored corpus statistics wholesale.
    #[test]
    fn test_registry_stats_update() {
        let registry = HybridRegistry::new();
        let config = HybridCollectionConfig::new(
            42,
            "stats_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        registry.register(config).unwrap();
        let new_stats = CorpusStats {
            avg_doc_length: 200.0,
            doc_count: 10000,
            total_terms: 2000000,
            last_update: 12345,
        };
        registry.update_stats(42, new_stats).unwrap();
        let updated = registry.get(42).unwrap();
        assert!((updated.corpus_stats.avg_doc_length - 200.0).abs() < 0.1);
        assert_eq!(updated.corpus_stats.doc_count, 10000);
    }
    // A full HybridConfigUpdate must apply every field in one pass.
    #[test]
    fn test_config_update_apply() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        let update = HybridConfigUpdate {
            default_fusion: Some("linear".to_string()),
            default_alpha: Some(0.8),
            rrf_k: Some(50),
            prefetch_k: Some(150),
            bm25_k1: Some(1.4),
            bm25_b: Some(0.7),
            stats_refresh_interval: Some("30 minutes".to_string()),
            parallel_enabled: Some(false),
        };
        update.apply(&mut config).unwrap();
        assert_eq!(config.fusion_config.method, FusionMethod::Linear);
        assert!((config.fusion_config.alpha - 0.8).abs() < 0.01);
        assert_eq!(config.fusion_config.rrf_k, 50);
        assert_eq!(config.prefetch_k, 150);
        assert!((config.bm25_config.k1 - 1.4).abs() < 0.01);
        assert!((config.bm25_config.b - 0.7).abs() < 0.01);
        // "30 minutes" -> 1800 seconds.
        assert_eq!(config.stats_refresh_interval, 1800);
        assert!(!config.parallel_enabled);
    }
    // Out-of-range alpha and a zero rrf_k must both be rejected.
    // Relies on HybridConfigUpdate implementing Default (all-None update).
    #[test]
    fn test_config_update_validation() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        // Invalid alpha
        let invalid_alpha = HybridConfigUpdate {
            default_alpha: Some(1.5),
            ..Default::default()
        };
        assert!(invalid_alpha.apply(&mut config).is_err());
        // Invalid rrf_k
        let invalid_rrf = HybridConfigUpdate {
            rrf_k: Some(0),
            ..Default::default()
        };
        assert!(invalid_rrf.apply(&mut config).is_err());
    }
    // Repeated IDF lookups must be deterministic (presumably served from a
    // cache after the first computation — confirm in the registry module).
    #[test]
    fn test_idf_caching() {
        let registry = HybridRegistry::new();
        let mut config = HybridCollectionConfig::new(
            1,
            "idf_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        config.corpus_stats.doc_count = 1000;
        registry.register(config).unwrap();
        // Set doc freq
        registry.set_doc_freq(1, "test_term", 100).unwrap();
        // Get IDF (should compute and cache)
        let idf1 = registry.get_idf(1, "test_term").unwrap();
        assert!(idf1 > 0.0);
        // Get again (should use cache)
        let idf2 = registry.get_idf(1, "test_term").unwrap();
        assert!((idf1 - idf2).abs() < 0.001);
    }
    // needs_stats_refresh compares last_update age against the interval.
    #[test]
    fn test_needs_refresh() {
        let mut config = HybridCollectionConfig::new(
            1,
            "refresh_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        // Set refresh interval to 1 hour
        config.stats_refresh_interval = 3600;
        // Fresh stats
        config.corpus_stats.last_update = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        assert!(!config.needs_stats_refresh());
        // Stale stats (2 hours old)
        config.corpus_stats.last_update -= 7200;
        assert!(config.needs_stats_refresh());
    }
}
// Integration tests wiring registry, config updates, executor, and BM25
// scorer together end-to-end with mocked search branches.
#[cfg(test)]
mod integration_tests {
    use crate::hybrid::*;
    // Walks the full workflow: register -> stats -> reconfigure -> execute.
    #[test]
    fn test_end_to_end_workflow() {
        // 1. Setup registry
        let registry = HybridRegistry::new();
        // 2. Register collection
        let config = HybridCollectionConfig::new(
            100,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config).unwrap();
        // 3. Update with corpus stats
        let stats = CorpusStats {
            avg_doc_length: 250.0,
            doc_count: 50000,
            total_terms: 12500000,
            last_update: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs() as i64,
        };
        registry.update_stats(100, stats).unwrap();
        // 4. Configure hybrid settings
        // (relies on HybridConfigUpdate implementing Default for the rest)
        let config_update = HybridConfigUpdate {
            default_fusion: Some("rrf".to_string()),
            rrf_k: Some(60),
            prefetch_k: Some(200),
            ..Default::default()
        };
        let mut updated_config = registry.get(100).unwrap();
        config_update.apply(&mut updated_config).unwrap();
        registry.update(updated_config.clone()).unwrap();
        // 5. Create executor with updated config
        let executor = HybridExecutor::new(updated_config.corpus_stats);
        // 6. Execute query
        let query = HybridQuery::new(
            "machine learning optimization".to_string(),
            vec![0.1; 768],
            10,
        )
        .with_fusion(FusionMethod::Rrf)
        .with_prefetch(200);
        // Vector branch: ids ascending with growing distance (1 is best).
        let mock_vector = |_: &[f32], k: usize| BranchResults {
            results: (1..=k.min(20) as i64).map(|i| (i, i as f32 * 0.02)).collect(),
            latency_ms: 3.0,
            candidates_evaluated: 1000,
        };
        // Keyword branch: ids 24..5 descending so the two branches disagree.
        let mock_keyword = |_: &str, k: usize| BranchResults {
            results: (1..=k.min(20) as i64).map(|i| (25 - i, 15.0 - i as f32 * 0.5)).collect(),
            latency_ms: 1.5,
            candidates_evaluated: 500,
        };
        let (results, stats) = executor.execute(&query, mock_vector, mock_keyword);
        // 7. Verify results
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        assert!(stats.total_latency_ms > 0.0);
        // Top result should have high hybrid score
        assert!(results[0].hybrid_score > 0.0);
    }
    // Exercises the BM25 scorer with realistic frequencies and checks
    // IDF ordering between rare and common terms.
    #[test]
    fn test_bm25_scorer_integration() {
        let corpus_stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        let scorer = BM25Scorer::new(corpus_stats);
        // Set up document frequencies
        scorer.set_doc_freq("machine", 200);
        scorer.set_doc_freq("learning", 150);
        scorer.set_doc_freq("deep", 50);
        // Create test document
        let mut doc_freqs = std::collections::HashMap::new();
        doc_freqs.insert("machine".to_string(), 3);
        doc_freqs.insert("learning".to_string(), 2);
        doc_freqs.insert("deep".to_string(), 1);
        doc_freqs.insert("neural".to_string(), 2);
        doc_freqs.insert("networks".to_string(), 2);
        let term_freqs = TermFrequencies::new(doc_freqs);
        let doc = Document::new(&term_freqs);
        // Score with query
        let query_terms = vec![
            "machine".to_string(),
            "learning".to_string(),
            "deep".to_string(),
        ];
        let score = scorer.score(&doc, &query_terms);
        assert!(score > 0.0);
        // "deep" is rarer, so query with just "deep" should have higher IDF contribution
        let deep_idf = scorer.idf("deep");
        let machine_idf = scorer.idf("machine");
        assert!(deep_idf > machine_idf, "Rare term should have higher IDF");
    }
}