Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
452
vendor/ruvector/crates/ruvector-postgres/src/hybrid/bm25.rs
vendored
Normal file
452
vendor/ruvector/crates/ruvector-postgres/src/hybrid/bm25.rs
vendored
Normal file
@@ -0,0 +1,452 @@
|
||||
//! BM25 (Best Matching 25) scoring implementation
|
||||
//!
|
||||
//! Provides proper BM25 scoring with:
|
||||
//! - Document length normalization
|
||||
//! - IDF weighting across corpus
|
||||
//! - Term frequency saturation
|
||||
//! - Configurable k1 and b parameters
|
||||
//!
|
||||
//! Unlike PostgreSQL's ts_rank, this is a proper BM25 implementation.
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Default BM25 k1 parameter (term frequency saturation).
///
/// Controls how quickly repeated occurrences of a term stop adding score.
pub const DEFAULT_K1: f32 = 1.2;

/// Default BM25 b parameter (length normalization).
///
/// 0.0 disables length normalization entirely; 1.0 applies it fully.
pub const DEFAULT_B: f32 = 0.75;
|
||||
|
||||
/// Corpus statistics for BM25 scoring.
///
/// These aggregates feed the BM25 formula: `avg_doc_length` drives length
/// normalization and `doc_count` drives the IDF term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusStats {
    /// Average document length in the corpus
    pub avg_doc_length: f32,
    /// Total number of documents
    pub doc_count: u64,
    /// Total number of terms across all documents
    pub total_terms: u64,
    /// Last update timestamp (Unix epoch seconds)
    pub last_update: i64,
}

impl Default for CorpusStats {
    /// An empty corpus: every counter zeroed.
    fn default() -> Self {
        Self {
            avg_doc_length: 0.0,
            doc_count: 0,
            total_terms: 0,
            last_update: 0,
        }
    }
}
|
||||
|
||||
/// BM25 configuration parameters
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct BM25Config {
|
||||
/// Term frequency saturation parameter (default: 1.2)
|
||||
/// Higher values give more weight to term frequency
|
||||
pub k1: f32,
|
||||
/// Length normalization parameter (default: 0.75)
|
||||
/// 0 = no length normalization, 1 = full normalization
|
||||
pub b: f32,
|
||||
}
|
||||
|
||||
impl Default for BM25Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
k1: DEFAULT_K1,
|
||||
b: DEFAULT_B,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BM25Config {
|
||||
/// Create a new BM25 configuration
|
||||
pub fn new(k1: f32, b: f32) -> Self {
|
||||
Self {
|
||||
k1: k1.max(0.0),
|
||||
b: b.clamp(0.0, 1.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Term frequency information for a document
#[derive(Debug, Clone)]
pub struct TermFrequencies {
    /// Term -> frequency map
    pub frequencies: HashMap<String, u32>,
    /// Total terms in document
    pub doc_length: u32,
}

impl TermFrequencies {
    /// Build from a term-frequency map.
    ///
    /// The document length is derived as the sum of all per-term counts.
    pub fn new(frequencies: HashMap<String, u32>) -> Self {
        let doc_length: u32 = frequencies.values().copied().sum();
        TermFrequencies {
            frequencies,
            doc_length,
        }
    }

    /// Frequency of `term`, or `None` if the document does not contain it.
    pub fn get(&self, term: &str) -> Option<u32> {
        self.frequencies.get(term).copied()
    }
}
|
||||
|
||||
/// Document information for BM25 scoring
|
||||
pub struct Document<'a> {
|
||||
/// Term frequencies in the document
|
||||
pub term_freqs: &'a TermFrequencies,
|
||||
}
|
||||
|
||||
impl<'a> Document<'a> {
|
||||
/// Create a new document wrapper
|
||||
pub fn new(term_freqs: &'a TermFrequencies) -> Self {
|
||||
Self { term_freqs }
|
||||
}
|
||||
|
||||
/// Get term frequency for a term
|
||||
pub fn term_freq(&self, term: &str) -> Option<u32> {
|
||||
self.term_freqs.get(term)
|
||||
}
|
||||
|
||||
/// Get document length (total terms)
|
||||
pub fn term_count(&self) -> u32 {
|
||||
self.term_freqs.doc_length
|
||||
}
|
||||
}
|
||||
|
||||
/// BM25 scorer with corpus statistics and IDF caching
|
||||
pub struct BM25Scorer {
|
||||
/// Configuration parameters
|
||||
config: BM25Config,
|
||||
/// Corpus statistics
|
||||
corpus_stats: CorpusStats,
|
||||
/// Cached IDF values (term -> IDF score)
|
||||
idf_cache: Arc<RwLock<HashMap<String, f32>>>,
|
||||
/// Document frequency cache (term -> doc count containing term)
|
||||
df_cache: Arc<RwLock<HashMap<String, u64>>>,
|
||||
}
|
||||
|
||||
impl BM25Scorer {
|
||||
/// Create a new BM25 scorer with default config
|
||||
pub fn new(corpus_stats: CorpusStats) -> Self {
|
||||
Self::with_config(corpus_stats, BM25Config::default())
|
||||
}
|
||||
|
||||
/// Create a new BM25 scorer with custom config
|
||||
pub fn with_config(corpus_stats: CorpusStats, config: BM25Config) -> Self {
|
||||
Self {
|
||||
config,
|
||||
corpus_stats,
|
||||
idf_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
df_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update corpus statistics
|
||||
pub fn update_corpus_stats(&mut self, stats: CorpusStats) {
|
||||
self.corpus_stats = stats;
|
||||
// Clear caches when stats change
|
||||
self.idf_cache.write().clear();
|
||||
}
|
||||
|
||||
/// Set document frequency for a term (used during index building)
|
||||
pub fn set_doc_freq(&self, term: &str, doc_freq: u64) {
|
||||
self.df_cache.write().insert(term.to_string(), doc_freq);
|
||||
// Invalidate IDF cache for this term
|
||||
self.idf_cache.write().remove(term);
|
||||
}
|
||||
|
||||
/// Compute IDF (Inverse Document Frequency) for a term
|
||||
///
|
||||
/// Uses the BM25 IDF formula:
|
||||
/// IDF(t) = ln((N - df(t) + 0.5) / (df(t) + 0.5) + 1)
|
||||
///
|
||||
/// where:
|
||||
/// - N = total documents in corpus
|
||||
/// - df(t) = number of documents containing term t
|
||||
pub fn idf(&self, term: &str) -> f32 {
|
||||
// Check cache first
|
||||
if let Some(&cached) = self.idf_cache.read().get(term) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
// Get document frequency
|
||||
let df = self.df_cache.read().get(term).copied().unwrap_or(0);
|
||||
|
||||
// Compute IDF using BM25 formula
|
||||
let n = self.corpus_stats.doc_count as f32;
|
||||
let df_f = df as f32;
|
||||
|
||||
// Prevent division by zero and handle edge cases
|
||||
let idf = if df == 0 {
|
||||
// Term not in corpus - use max IDF
|
||||
(n + 0.5).ln()
|
||||
} else {
|
||||
((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
|
||||
};
|
||||
|
||||
// Cache the result
|
||||
self.idf_cache.write().insert(term.to_string(), idf);
|
||||
|
||||
idf
|
||||
}
|
||||
|
||||
/// Compute IDF with known document frequency (bypasses cache lookup)
|
||||
pub fn idf_with_df(&self, doc_freq: u64) -> f32 {
|
||||
let n = self.corpus_stats.doc_count as f32;
|
||||
let df = doc_freq as f32;
|
||||
|
||||
if doc_freq == 0 {
|
||||
(n + 0.5).ln()
|
||||
} else {
|
||||
((n - df + 0.5) / (df + 0.5) + 1.0).ln()
|
||||
}
|
||||
}
|
||||
|
||||
/// Score a document for a query
|
||||
///
|
||||
/// BM25 formula:
|
||||
/// score(D, Q) = sum over t in Q of: IDF(t) * (tf(t,D) * (k1 + 1)) / (tf(t,D) + k1 * (1 - b + b * |D|/avgdl))
|
||||
///
|
||||
/// where:
|
||||
/// - tf(t,D) = term frequency of t in document D
|
||||
/// - |D| = document length
|
||||
/// - avgdl = average document length
|
||||
/// - k1 = term saturation parameter
|
||||
/// - b = length normalization parameter
|
||||
pub fn score(&self, doc: &Document, query_terms: &[String]) -> f32 {
|
||||
let doc_len = doc.term_count() as f32;
|
||||
let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
|
||||
|
||||
// Length normalization factor
|
||||
let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
|
||||
|
||||
query_terms
|
||||
.iter()
|
||||
.filter_map(|term| {
|
||||
let tf = doc.term_freq(term)? as f32;
|
||||
let idf = self.idf(term);
|
||||
|
||||
// BM25 term score
|
||||
let numerator = tf * (self.config.k1 + 1.0);
|
||||
let denominator = tf + self.config.k1 * len_norm;
|
||||
|
||||
Some(idf * numerator / denominator)
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Score a document with pre-computed term frequencies and document frequencies
|
||||
///
|
||||
/// This is the optimized version for batch scoring where IDF values are known.
|
||||
pub fn score_with_freqs(
|
||||
&self,
|
||||
term_freqs: &[(String, u32, u64)], // (term, tf in doc, df in corpus)
|
||||
doc_length: u32,
|
||||
) -> f32 {
|
||||
let doc_len = doc_length as f32;
|
||||
let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
|
||||
|
||||
let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
|
||||
|
||||
term_freqs
|
||||
.iter()
|
||||
.map(|(_, tf, df)| {
|
||||
let tf = *tf as f32;
|
||||
let idf = self.idf_with_df(*df);
|
||||
|
||||
let numerator = tf * (self.config.k1 + 1.0);
|
||||
let denominator = tf + self.config.k1 * len_norm;
|
||||
|
||||
idf * numerator / denominator
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Batch score multiple documents for the same query
|
||||
pub fn score_batch<'a>(
|
||||
&self,
|
||||
docs: impl Iterator<Item = &'a Document<'a>>,
|
||||
query_terms: &[String],
|
||||
) -> Vec<f32> {
|
||||
docs.map(|doc| self.score(doc, query_terms)).collect()
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub fn config(&self) -> &BM25Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get corpus statistics
|
||||
pub fn corpus_stats(&self) -> &CorpusStats {
|
||||
&self.corpus_stats
|
||||
}
|
||||
|
||||
/// Clear IDF cache
|
||||
pub fn clear_cache(&self) {
|
||||
self.idf_cache.write().clear();
|
||||
self.df_cache.write().clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple query tokenizer for BM25
///
/// Lowercases, splits on whitespace, strips leading/trailing non-alphanumeric
/// characters, and drops tokens of one character or fewer.
///
/// Fix over the previous version: punctuation is trimmed BEFORE the length
/// filter, so a token like `"a!"` no longer slips through as the single
/// character `"a"`.
///
/// Note: In production, this should use PostgreSQL's text search configuration
/// for proper stemming, stopword removal, etc.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|s| s.len() > 1) // Skip single characters (measured after trimming)
        .map(str::to_string)
        .collect()
}
|
||||
|
||||
/// Parse tsvector-style string to term frequencies
///
/// Expected format: `'term1':1,2,3 'term2':4,5` (position lists optional).
/// The frequency of a term is the number of listed positions (minimum 1).
///
/// Fix over the previous version: degenerate tokens such as `':1` or a bare
/// `'` made the old byte-index slicing (`&part[1..0]`) panic; this rewrite
/// uses `strip_prefix`/`split_once`/`strip_suffix`, which cannot panic, and
/// skips empty terms instead of inserting an empty-string key.
pub fn parse_tsvector(tsvector_str: &str) -> HashMap<String, u32> {
    let mut frequencies = HashMap::new();

    for part in tsvector_str.split_whitespace() {
        // Every valid entry starts with an opening quote.
        let rest = match part.strip_prefix('\'') {
            Some(r) => r,
            None => continue,
        };

        if let Some((term, positions)) = rest.split_once("':") {
            // Term with a position list: count positions as frequency
            if !term.is_empty() {
                let freq = positions.split(',').count() as u32;
                frequencies.insert(term.to_string(), freq.max(1));
            }
        } else if let Some(term) = rest.strip_suffix('\'') {
            // Term without positions
            if !term.is_empty() {
                frequencies.insert(term.to_string(), 1);
            }
        }
    }

    frequencies
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Scorer over a synthetic corpus: 1000 documents averaging 100 terms.
    fn create_test_scorer() -> BM25Scorer {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        BM25Scorer::new(stats)
    }

    // A term appearing in 90% of documents should have a small (but positive) IDF.
    #[test]
    fn test_idf_common_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("the", 900); // Very common term

        let idf = scorer.idf("the");
        assert!(idf > 0.0, "IDF should be positive");
        assert!(idf < 1.0, "IDF for common term should be low");
    }

    // A term in only 5/1000 documents should have a large IDF.
    #[test]
    fn test_idf_rare_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("xyzzy", 5); // Rare term

        let idf = scorer.idf("xyzzy");
        assert!(idf > 4.0, "IDF for rare term should be high");
    }

    // Terms with no recorded document frequency fall back to the maximum IDF.
    #[test]
    fn test_idf_unknown_term() {
        let scorer = create_test_scorer();

        let idf = scorer.idf("unknown_term_xyz");
        assert!(idf > 5.0, "IDF for unknown term should be maximum");
    }

    // End-to-end score over a small document; only checks positivity.
    #[test]
    fn test_bm25_score() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("database", 100);
        scorer.set_doc_freq("query", 50);

        let mut freqs = HashMap::new();
        freqs.insert("database".to_string(), 3);
        freqs.insert("query".to_string(), 2);
        freqs.insert("other".to_string(), 5);

        let term_freqs = TermFrequencies::new(freqs);
        let doc = Document::new(&term_freqs);

        let query_terms = vec!["database".to_string(), "query".to_string()];
        let score = scorer.score(&doc, &query_terms);

        assert!(score > 0.0, "Score should be positive");
    }

    // With identical term frequency, BM25's length normalization must favor
    // the shorter document (b > 0 penalizes longer documents).
    #[test]
    fn test_length_normalization() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("test", 100);

        // Short document (50 terms)
        let mut short_freqs = HashMap::new();
        short_freqs.insert("test".to_string(), 2);
        for i in 0..48 {
            short_freqs.insert(format!("filler{}", i), 1);
        }
        let short_tf = TermFrequencies::new(short_freqs);
        let short_doc = Document::new(&short_tf);

        // Long document (200 terms)
        let mut long_freqs = HashMap::new();
        long_freqs.insert("test".to_string(), 2);
        for i in 0..198 {
            long_freqs.insert(format!("filler{}", i), 1);
        }
        let long_tf = TermFrequencies::new(long_freqs);
        let long_doc = Document::new(&long_tf);

        let query_terms = vec!["test".to_string()];
        let short_score = scorer.score(&short_doc, &query_terms);
        let long_score = scorer.score(&long_doc, &query_terms);

        // Short doc should score higher (same tf, less length penalty)
        assert!(
            short_score > long_score,
            "Short doc should score higher than long doc with same TF"
        );
    }

    #[test]
    fn test_tokenize_query() {
        let tokens = tokenize_query("Hello World! Database Query.");
        assert_eq!(tokens, vec!["hello", "world", "database", "query"]);
    }

    // Position counts become frequencies: 3 positions -> tf 3, 2 -> tf 2.
    #[test]
    fn test_parse_tsvector() {
        let tsvector = "'database':1,3,5 'query':2,4";
        let freqs = parse_tsvector(tsvector);

        assert_eq!(freqs.get("database"), Some(&3));
        assert_eq!(freqs.get("query"), Some(&2));
    }

    // BM25Config::new must clamp k1 to >= 0 and b into [0, 1].
    #[test]
    fn test_config_clamping() {
        let config = BM25Config::new(-1.0, 1.5);
        assert_eq!(config.k1, 0.0, "k1 should be clamped to 0");
        assert_eq!(config.b, 1.0, "b should be clamped to 1");
    }
}
|
||||
522
vendor/ruvector/crates/ruvector-postgres/src/hybrid/executor.rs
vendored
Normal file
522
vendor/ruvector/crates/ruvector-postgres/src/hybrid/executor.rs
vendored
Normal file
@@ -0,0 +1,522 @@
|
||||
//! Hybrid Query Executor
|
||||
//!
|
||||
//! Executes vector and keyword search branches, optionally in parallel,
|
||||
//! and fuses results using the configured algorithm.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::bm25::{tokenize_query, BM25Scorer, CorpusStats, Document, TermFrequencies};
|
||||
use super::fusion::{
|
||||
fuse_results, learned_fusion, DocId, FusedResult, FusionConfig, FusionMethod, FusionModel,
|
||||
};
|
||||
|
||||
/// Hybrid search query
///
/// Pairs a free-text component (keyword branch) with an embedding (vector
/// branch); `fusion_config` controls how the two result lists are merged.
#[derive(Debug, Clone)]
pub struct HybridQuery {
    /// Query text for keyword search
    pub text: String,
    /// Query embedding for vector search
    pub embedding: Vec<f32>,
    /// Number of final results to return
    pub k: usize,
    /// Number of results to prefetch from each branch
    pub prefetch_k: usize,
    /// Fusion configuration
    pub fusion_config: FusionConfig,
    /// Optional filter expression
    pub filter: Option<String>,
}

impl HybridQuery {
    /// Create a new hybrid query with default fusion settings.
    ///
    /// `prefetch_k` defaults to `k * 10` so each branch over-fetches enough
    /// candidates for fusion to re-rank meaningfully.
    pub fn new(text: String, embedding: Vec<f32>, k: usize) -> Self {
        Self {
            text,
            embedding,
            k,
            prefetch_k: k * 10, // Default prefetch 10x final k
            fusion_config: FusionConfig::default(),
            filter: None,
        }
    }

    /// Set fusion method (builder style; consumes and returns `self`)
    pub fn with_fusion(mut self, method: FusionMethod) -> Self {
        self.fusion_config.method = method;
        self
    }

    /// Set fusion alpha (for linear fusion)
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.fusion_config.alpha = alpha;
        self
    }

    /// Set RRF k constant
    pub fn with_rrf_k(mut self, rrf_k: usize) -> Self {
        self.fusion_config.rrf_k = rrf_k;
        self
    }

    /// Set prefetch size (overrides the `k * 10` default)
    pub fn with_prefetch(mut self, prefetch_k: usize) -> Self {
        self.prefetch_k = prefetch_k;
        self
    }

    /// Set filter expression
    pub fn with_filter(mut self, filter: String) -> Self {
        self.filter = Some(filter);
        self
    }
}
|
||||
|
||||
/// Execution strategy for hybrid search
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HybridStrategy {
    /// Execute both branches fully
    Full,
    /// Pre-filter using keyword/metadata, then vector search on subset
    PreFilter,
    /// Execute hybrid search, then apply filter
    PostFilter,
    /// Vector search only (degraded mode)
    VectorOnly,
    /// Keyword search only (degraded mode)
    KeywordOnly,
}

impl Default for HybridStrategy {
    /// `Full` is the safe default: both branches run with no filtering tricks.
    fn default() -> Self {
        HybridStrategy::Full
    }
}
|
||||
|
||||
/// Choose execution strategy based on query characteristics
|
||||
pub fn choose_strategy(
|
||||
filter_selectivity: Option<f32>,
|
||||
corpus_size: u64,
|
||||
has_filter: bool,
|
||||
) -> HybridStrategy {
|
||||
if !has_filter {
|
||||
return HybridStrategy::Full;
|
||||
}
|
||||
|
||||
match filter_selectivity {
|
||||
Some(sel) if sel < 0.01 => {
|
||||
// Very selective filter: pre-filter first
|
||||
HybridStrategy::PreFilter
|
||||
}
|
||||
Some(sel) if sel < 0.1 && corpus_size > 1_000_000 => {
|
||||
// Moderately selective on large corpus: post-filter
|
||||
HybridStrategy::PostFilter
|
||||
}
|
||||
_ => HybridStrategy::Full,
|
||||
}
|
||||
}
|
||||
|
||||
/// Result from a single search branch
#[derive(Debug, Clone)]
pub struct BranchResults {
    /// Document IDs and scores
    pub results: Vec<(DocId, f32)>,
    /// Execution time in milliseconds
    pub latency_ms: f64,
    /// Number of candidates evaluated
    pub candidates_evaluated: usize,
}

impl BranchResults {
    /// Create empty results (no hits, zero latency/candidates).
    pub fn empty() -> Self {
        Self {
            results: Vec::new(),
            latency_ms: 0.0,
            candidates_evaluated: 0,
        }
    }
}
|
||||
|
||||
/// Hybrid search result with detailed scores
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridResult {
    /// Document ID
    pub id: DocId,
    /// Final hybrid score
    pub hybrid_score: f32,
    /// Vector similarity score (1 - distance for cosine)
    pub vector_score: Option<f32>,
    /// BM25 keyword score
    pub keyword_score: Option<f32>,
    /// Rank in vector results (None if not present)
    pub vector_rank: Option<usize>,
    /// Rank in keyword results (None if not present)
    pub keyword_rank: Option<usize>,
}

impl From<FusedResult> for HybridResult {
    /// Convert a fusion-stage result, carrying scores over.
    ///
    /// Ranks are initialized to `None`; the executor fills them in afterwards
    /// from its per-branch rank tables.
    fn from(fused: FusedResult) -> Self {
        Self {
            id: fused.doc_id,
            hybrid_score: fused.hybrid_score,
            vector_score: fused.vector_score,
            keyword_score: fused.keyword_score,
            vector_rank: None,
            keyword_rank: None,
        }
    }
}
|
||||
|
||||
/// Execution statistics for hybrid search
///
/// All latencies are wall-clock milliseconds measured by the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
    /// Total execution time in milliseconds
    pub total_latency_ms: f64,
    /// Vector branch latency
    pub vector_latency_ms: f64,
    /// Keyword branch latency
    pub keyword_latency_ms: f64,
    /// Fusion latency
    pub fusion_latency_ms: f64,
    /// Strategy used
    pub strategy: String,
    /// Number of vector candidates
    pub vector_candidates: usize,
    /// Number of keyword candidates
    pub keyword_candidates: usize,
    /// Final result count
    pub result_count: usize,
}
|
||||
|
||||
/// Hybrid Query Executor
|
||||
///
|
||||
/// Coordinates vector and keyword search branches, handles parallel execution,
|
||||
/// and manages score fusion.
|
||||
pub struct HybridExecutor {
|
||||
/// BM25 scorer for keyword scoring
|
||||
bm25_scorer: BM25Scorer,
|
||||
/// Fusion model for learned fusion
|
||||
fusion_model: FusionModel,
|
||||
/// Whether to use parallel execution
|
||||
parallel_enabled: bool,
|
||||
/// Default prefetch multiplier
|
||||
prefetch_multiplier: usize,
|
||||
}
|
||||
|
||||
impl HybridExecutor {
|
||||
/// Create a new hybrid executor
|
||||
pub fn new(corpus_stats: CorpusStats) -> Self {
|
||||
Self {
|
||||
bm25_scorer: BM25Scorer::new(corpus_stats),
|
||||
fusion_model: FusionModel::default(),
|
||||
parallel_enabled: true,
|
||||
prefetch_multiplier: 10,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update corpus statistics
|
||||
pub fn update_corpus_stats(&mut self, stats: CorpusStats) {
|
||||
self.bm25_scorer.update_corpus_stats(stats);
|
||||
}
|
||||
|
||||
/// Set whether to use parallel execution
|
||||
pub fn set_parallel(&mut self, enabled: bool) {
|
||||
self.parallel_enabled = enabled;
|
||||
}
|
||||
|
||||
/// Set prefetch multiplier
|
||||
pub fn set_prefetch_multiplier(&mut self, multiplier: usize) {
|
||||
self.prefetch_multiplier = multiplier;
|
||||
}
|
||||
|
||||
/// Execute hybrid search
|
||||
///
|
||||
/// This is the main entry point for hybrid search. In the PostgreSQL extension,
|
||||
/// this would call into the actual vector index and tsvector search.
|
||||
pub fn execute(
|
||||
&self,
|
||||
query: &HybridQuery,
|
||||
vector_search_fn: impl FnOnce(&[f32], usize) -> BranchResults,
|
||||
keyword_search_fn: impl FnOnce(&str, usize) -> BranchResults,
|
||||
) -> (Vec<HybridResult>, ExecutionStats) {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Execute both branches
|
||||
let (vector_results, keyword_results) = if self.parallel_enabled {
|
||||
// In async context, use tokio::join!
|
||||
// For sync, execute sequentially (rayon could parallelize)
|
||||
let v_start = std::time::Instant::now();
|
||||
let vector = vector_search_fn(&query.embedding, query.prefetch_k);
|
||||
let v_elapsed = v_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let k_start = std::time::Instant::now();
|
||||
let keyword = keyword_search_fn(&query.text, query.prefetch_k);
|
||||
let k_elapsed = k_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
(
|
||||
BranchResults {
|
||||
latency_ms: v_elapsed,
|
||||
..vector
|
||||
},
|
||||
BranchResults {
|
||||
latency_ms: k_elapsed,
|
||||
..keyword
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let v_start = std::time::Instant::now();
|
||||
let vector = vector_search_fn(&query.embedding, query.prefetch_k);
|
||||
let v_elapsed = v_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let k_start = std::time::Instant::now();
|
||||
let keyword = keyword_search_fn(&query.text, query.prefetch_k);
|
||||
let k_elapsed = k_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
(
|
||||
BranchResults {
|
||||
latency_ms: v_elapsed,
|
||||
..vector
|
||||
},
|
||||
BranchResults {
|
||||
latency_ms: k_elapsed,
|
||||
..keyword
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
// Fuse results
|
||||
let fusion_start = std::time::Instant::now();
|
||||
let fused = self.fuse(&query, &vector_results.results, &keyword_results.results);
|
||||
let fusion_elapsed = fusion_start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
// Add rank information
|
||||
let vector_ranks: HashMap<DocId, usize> = vector_results
|
||||
.results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (id, _))| (*id, i))
|
||||
.collect();
|
||||
|
||||
let keyword_ranks: HashMap<DocId, usize> = keyword_results
|
||||
.results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (id, _))| (*id, i))
|
||||
.collect();
|
||||
|
||||
let results: Vec<HybridResult> = fused
|
||||
.into_iter()
|
||||
.take(query.k)
|
||||
.map(|f| {
|
||||
let mut result = HybridResult::from(f);
|
||||
result.vector_rank = vector_ranks.get(&result.id).copied();
|
||||
result.keyword_rank = keyword_ranks.get(&result.id).copied();
|
||||
result
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total_elapsed = start.elapsed().as_secs_f64() * 1000.0;
|
||||
|
||||
let stats = ExecutionStats {
|
||||
total_latency_ms: total_elapsed,
|
||||
vector_latency_ms: vector_results.latency_ms,
|
||||
keyword_latency_ms: keyword_results.latency_ms,
|
||||
fusion_latency_ms: fusion_elapsed,
|
||||
strategy: "full".to_string(),
|
||||
vector_candidates: vector_results.candidates_evaluated,
|
||||
keyword_candidates: keyword_results.candidates_evaluated,
|
||||
result_count: results.len(),
|
||||
};
|
||||
|
||||
(results, stats)
|
||||
}
|
||||
|
||||
/// Fuse vector and keyword results
|
||||
fn fuse(
|
||||
&self,
|
||||
query: &HybridQuery,
|
||||
vector_results: &[(DocId, f32)],
|
||||
keyword_results: &[(DocId, f32)],
|
||||
) -> Vec<FusedResult> {
|
||||
match query.fusion_config.method {
|
||||
FusionMethod::Learned => {
|
||||
// Use learned fusion with query features
|
||||
let query_terms = tokenize_query(&query.text);
|
||||
let avg_idf = self.compute_avg_idf(&query_terms);
|
||||
|
||||
learned_fusion(
|
||||
&query.embedding,
|
||||
&query_terms,
|
||||
vector_results,
|
||||
keyword_results,
|
||||
&self.fusion_model,
|
||||
avg_idf,
|
||||
query.prefetch_k,
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// Use standard fusion
|
||||
fuse_results(
|
||||
vector_results,
|
||||
keyword_results,
|
||||
&query.fusion_config,
|
||||
query.prefetch_k,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute average IDF for query terms
|
||||
fn compute_avg_idf(&self, terms: &[String]) -> f32 {
|
||||
if terms.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let total_idf: f32 = terms.iter().map(|t| self.bm25_scorer.idf(t)).sum();
|
||||
total_idf / terms.len() as f32
|
||||
}
|
||||
|
||||
/// Score documents using BM25
|
||||
pub fn bm25_score(&self, term_freqs: &TermFrequencies, query_terms: &[String]) -> f32 {
|
||||
let doc = Document::new(term_freqs);
|
||||
self.bm25_scorer.score(&doc, query_terms)
|
||||
}
|
||||
|
||||
/// Set document frequency for a term (for BM25 IDF calculation)
|
||||
pub fn set_doc_freq(&self, term: &str, doc_freq: u64) {
|
||||
self.bm25_scorer.set_doc_freq(term, doc_freq);
|
||||
}
|
||||
|
||||
/// Get current corpus statistics
|
||||
pub fn corpus_stats(&self) -> &CorpusStats {
|
||||
self.bm25_scorer.corpus_stats()
|
||||
}
|
||||
}
|
||||
|
||||
/// Async hybrid execution using tokio
#[cfg(feature = "tokio")]
pub mod async_executor {
    use super::*;
    use std::future::Future;

    /// Execute hybrid search with parallel branches
    ///
    /// Both branch futures are awaited concurrently via `tokio::join!`, then
    /// the two result lists are fused with the standard (non-learned) fusion
    /// and truncated to `query.k`.
    pub async fn parallel_hybrid<VF, KF, VFut, KFut>(
        query: &HybridQuery,
        vector_search: VF,
        keyword_search: KF,
        fusion_config: &FusionConfig,
    ) -> Vec<FusedResult>
    where
        VF: FnOnce(&[f32], usize) -> VFut,
        KF: FnOnce(&str, usize) -> KFut,
        VFut: Future<Output = BranchResults>,
        KFut: Future<Output = BranchResults>,
    {
        // Run both branches concurrently on the same task
        let (vector_results, keyword_results) = tokio::join!(
            vector_search(&query.embedding, query.prefetch_k),
            keyword_search(&query.text, query.prefetch_k),
        );

        fuse_results(
            &vector_results.results,
            &keyword_results.results,
            fusion_config,
            query.k,
        )
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Mock vector branch: up to 5 docs (ids 1..=5) with ascending scores.
    fn mock_vector_search(_embedding: &[f32], k: usize) -> BranchResults {
        BranchResults {
            results: (0..k.min(5))
                .map(|i| (i as DocId + 1, 0.1 * (i + 1) as f32))
                .collect(),
            latency_ms: 1.0,
            candidates_evaluated: 100,
        }
    }

    // Mock keyword branch: up to 5 docs (ids 3..=7), overlapping ids 3-5
    // with the vector branch so fusion has shared documents to merge.
    fn mock_keyword_search(_text: &str, k: usize) -> BranchResults {
        BranchResults {
            results: (0..k.min(5))
                .map(|i| ((i as DocId + 3), 10.0 - i as f32))
                .collect(),
            latency_ms: 0.5,
            candidates_evaluated: 50,
        }
    }

    // Builder methods should each override exactly their field.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("test query".into(), vec![0.1, 0.2, 0.3], 10)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.7)
            .with_prefetch(100)
            .with_filter("category = 'docs'".into());

        assert_eq!(query.k, 10);
        assert_eq!(query.prefetch_k, 100);
        assert_eq!(query.fusion_config.method, FusionMethod::Linear);
        assert!((query.fusion_config.alpha - 0.7).abs() < 0.01);
        assert!(query.filter.is_some());
    }

    // Full pipeline with mocked branches: non-empty, capped at k, timed.
    #[test]
    fn test_hybrid_executor() {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };

        let executor = HybridExecutor::new(stats);

        let query = HybridQuery::new("database query".into(), vec![0.1; 128], 5);

        let (results, exec_stats) =
            executor.execute(&query, mock_vector_search, mock_keyword_search);

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert!(exec_stats.total_latency_ms > 0.0);
    }

    // Exercises all three heuristic branches of choose_strategy.
    #[test]
    fn test_strategy_selection() {
        // No filter -> Full
        assert_eq!(choose_strategy(None, 10000, false), HybridStrategy::Full);

        // Very selective filter -> PreFilter
        assert_eq!(
            choose_strategy(Some(0.005), 1000000, true),
            HybridStrategy::PreFilter
        );

        // Moderate selectivity, large corpus -> PostFilter
        assert_eq!(
            choose_strategy(Some(0.05), 5000000, true),
            HybridStrategy::PostFilter
        );
    }

    // Sanity-check the timing fields: non-negative, total covers fusion.
    #[test]
    fn test_execution_stats() {
        let stats = CorpusStats::default();
        let executor = HybridExecutor::new(stats);

        let query = HybridQuery::new("test".into(), vec![0.1; 16], 5);

        let (_, exec_stats) = executor.execute(&query, mock_vector_search, mock_keyword_search);

        assert!(exec_stats.vector_latency_ms >= 0.0);
        assert!(exec_stats.keyword_latency_ms >= 0.0);
        assert!(exec_stats.fusion_latency_ms >= 0.0);
        assert!(exec_stats.total_latency_ms >= exec_stats.fusion_latency_ms);
    }
}
|
||||
604
vendor/ruvector/crates/ruvector-postgres/src/hybrid/fusion.rs
vendored
Normal file
604
vendor/ruvector/crates/ruvector-postgres/src/hybrid/fusion.rs
vendored
Normal file
@@ -0,0 +1,604 @@
|
||||
//! Fusion algorithms for combining vector and keyword search results
|
||||
//!
|
||||
//! Provides:
|
||||
//! - RRF (Reciprocal Rank Fusion) - default, robust
|
||||
//! - Linear blend - simple weighted combination
|
||||
//! - Learned fusion - query-adaptive weights
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Document ID type (matches with database row IDs)
pub type DocId = i64;

/// Default RRF constant
///
/// The `k` in `1 / (k + rank)`; 60 is a commonly used default that keeps
/// top-rank contributions from dominating.
pub const DEFAULT_RRF_K: usize = 60;

/// Default alpha for linear fusion (0.5 = equal weight)
pub const DEFAULT_ALPHA: f32 = 0.5;
|
||||
|
||||
/// Fusion method selection
///
/// Serialized in lowercase (`"rrf"`, `"linear"`, `"learned"`); see the
/// `FromStr` impl for the accepted textual aliases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default)
    ///
    /// Rank-based, so it is robust to differing score scales between branches.
    Rrf,
    /// Linear weighted combination
    ///
    /// Blends normalized branch scores: `alpha * vector + (1 - alpha) * keyword`.
    Linear,
    /// Learned/adaptive fusion based on query features
    Learned,
}
|
||||
|
||||
impl Default for FusionMethod {
|
||||
fn default() -> Self {
|
||||
FusionMethod::Rrf
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for FusionMethod {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"rrf" | "reciprocal" | "reciprocal_rank" => Ok(FusionMethod::Rrf),
|
||||
"linear" | "blend" | "weighted" => Ok(FusionMethod::Linear),
|
||||
"learned" | "adaptive" | "auto" => Ok(FusionMethod::Learned),
|
||||
_ => Err(format!("Unknown fusion method: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fusion configuration
///
/// Bundles the chosen fusion method with its tuning parameters; fields not
/// used by the selected method are simply ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Fusion method to use
    pub method: FusionMethod,
    /// RRF constant (typically 60)
    ///
    /// Consulted only when `method` is RRF (or Learned, which falls back to RRF
    /// in `fuse_results`).
    pub rrf_k: usize,
    /// Alpha for linear fusion (0 = all keyword, 1 = all vector)
    pub alpha: f32,
}
|
||||
|
||||
impl Default for FusionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
method: FusionMethod::Rrf,
|
||||
rrf_k: DEFAULT_RRF_K,
|
||||
alpha: DEFAULT_ALPHA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search result from a single branch (vector or keyword)
///
/// Carries both the branch-native score and the position within that branch's
/// ranked result list.
#[derive(Debug, Clone)]
pub struct BranchResult {
    /// Document ID
    pub doc_id: DocId,
    /// Original score from the branch
    ///
    /// Semantics depend on the branch: distance for vector search (lower is
    /// better), BM25 for keyword search (higher is better).
    pub score: f32,
    /// Rank in the result set (0-indexed)
    pub rank: usize,
}
|
||||
|
||||
/// Fused search result combining both branches
///
/// Branch scores are `None` when the document appeared in only one branch.
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Document ID
    pub doc_id: DocId,
    /// Final fused score
    ///
    /// Higher is better, regardless of fusion method.
    pub hybrid_score: f32,
    /// Original vector score (if present)
    ///
    /// This is the raw distance from the vector branch (lower is better).
    pub vector_score: Option<f32>,
    /// Original keyword score (if present)
    ///
    /// This is the raw BM25 score (higher is better).
    pub keyword_score: Option<f32>,
}
|
||||
|
||||
/// Reciprocal Rank Fusion (RRF)
|
||||
///
|
||||
/// RRF score = sum over sources: 1 / (k + rank)
|
||||
///
|
||||
/// Properties:
|
||||
/// - Robust to different score scales
|
||||
/// - No need for score calibration
|
||||
/// - Works well with partial overlap
|
||||
///
|
||||
/// Reference: "Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"
|
||||
pub fn rrf_fusion(
|
||||
vector_results: &[(DocId, f32)], // (id, distance - lower is better)
|
||||
keyword_results: &[(DocId, f32)], // (id, BM25 score - higher is better)
|
||||
k: usize,
|
||||
limit: usize,
|
||||
) -> Vec<FusedResult> {
|
||||
let mut scores: HashMap<DocId, (f32, Option<f32>, Option<f32>)> = HashMap::new();
|
||||
|
||||
// Vector ranking (lower distance = higher rank, so already sorted best first)
|
||||
for (rank, (doc_id, distance)) in vector_results.iter().enumerate() {
|
||||
let rrf_score = 1.0 / (k + rank + 1) as f32;
|
||||
let entry = scores.entry(*doc_id).or_insert((0.0, None, None));
|
||||
entry.0 += rrf_score;
|
||||
entry.1 = Some(*distance);
|
||||
}
|
||||
|
||||
// Keyword ranking (higher BM25 = higher rank, already sorted best first)
|
||||
for (rank, (doc_id, bm25_score)) in keyword_results.iter().enumerate() {
|
||||
let rrf_score = 1.0 / (k + rank + 1) as f32;
|
||||
let entry = scores.entry(*doc_id).or_insert((0.0, None, None));
|
||||
entry.0 += rrf_score;
|
||||
entry.2 = Some(*bm25_score);
|
||||
}
|
||||
|
||||
// Sort by fused score (descending)
|
||||
let mut results: Vec<FusedResult> = scores
|
||||
.into_iter()
|
||||
.map(
|
||||
|(doc_id, (hybrid_score, vector_score, keyword_score))| FusedResult {
|
||||
doc_id,
|
||||
hybrid_score,
|
||||
vector_score,
|
||||
keyword_score,
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
b.hybrid_score
|
||||
.partial_cmp(&a.hybrid_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results.truncate(limit);
|
||||
results
|
||||
}
|
||||
|
||||
/// Normalize vector distances to similarity scores [0, 1]
|
||||
///
|
||||
/// Converts distance (lower = better) to similarity (higher = better)
|
||||
fn normalize_to_similarity(results: &[(DocId, f32)]) -> Vec<(DocId, f32)> {
|
||||
if results.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Find min/max distances
|
||||
let (min_dist, max_dist) = results
|
||||
.iter()
|
||||
.fold((f32::MAX, f32::MIN), |(min, max), (_, d)| {
|
||||
(min.min(*d), max.max(*d))
|
||||
});
|
||||
|
||||
let range = (max_dist - min_dist).max(1e-6);
|
||||
|
||||
results
|
||||
.iter()
|
||||
.map(|(id, dist)| {
|
||||
// Convert distance to similarity: 1 - normalized_distance
|
||||
let similarity = 1.0 - (dist - min_dist) / range;
|
||||
(*id, similarity)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Min-max normalize scores to [0, 1]
|
||||
fn min_max_normalize(results: &[(DocId, f32)]) -> Vec<(DocId, f32)> {
|
||||
if results.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let (min_score, max_score) = results
|
||||
.iter()
|
||||
.fold((f32::MAX, f32::MIN), |(min, max), (_, s)| {
|
||||
(min.min(*s), max.max(*s))
|
||||
});
|
||||
|
||||
let range = (max_score - min_score).max(1e-6);
|
||||
|
||||
results
|
||||
.iter()
|
||||
.map(|(id, score)| {
|
||||
let normalized = (score - min_score) / range;
|
||||
(*id, normalized)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Linear fusion with alpha blending
|
||||
///
|
||||
/// score = alpha * vector_similarity + (1 - alpha) * keyword_score
|
||||
///
|
||||
/// Note: Scores must be normalized before fusion
|
||||
pub fn linear_fusion(
|
||||
vector_results: &[(DocId, f32)], // (id, distance)
|
||||
keyword_results: &[(DocId, f32)], // (id, BM25 score)
|
||||
alpha: f32,
|
||||
limit: usize,
|
||||
) -> Vec<FusedResult> {
|
||||
// Normalize vector scores (distance -> similarity)
|
||||
let vec_scores: HashMap<DocId, f32> = normalize_to_similarity(vector_results)
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
// Normalize keyword scores to [0, 1]
|
||||
let kw_scores: HashMap<DocId, f32> = min_max_normalize(keyword_results).into_iter().collect();
|
||||
|
||||
// Combine scores
|
||||
let mut combined: HashMap<DocId, (f32, Option<f32>, Option<f32>)> = HashMap::new();
|
||||
|
||||
// Add vector results
|
||||
for (doc_id, sim) in &vec_scores {
|
||||
let entry = combined.entry(*doc_id).or_insert((0.0, None, None));
|
||||
entry.0 += alpha * sim;
|
||||
// Store original distance
|
||||
if let Some((_, dist)) = vector_results.iter().find(|(id, _)| id == doc_id) {
|
||||
entry.1 = Some(*dist);
|
||||
}
|
||||
}
|
||||
|
||||
// Add keyword results
|
||||
for (doc_id, norm_score) in &kw_scores {
|
||||
let entry = combined.entry(*doc_id).or_insert((0.0, None, None));
|
||||
entry.0 += (1.0 - alpha) * norm_score;
|
||||
// Store original BM25 score
|
||||
if let Some((_, score)) = keyword_results.iter().find(|(id, _)| id == doc_id) {
|
||||
entry.2 = Some(*score);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by fused score
|
||||
let mut results: Vec<FusedResult> = combined
|
||||
.into_iter()
|
||||
.map(
|
||||
|(doc_id, (hybrid_score, vector_score, keyword_score))| FusedResult {
|
||||
doc_id,
|
||||
hybrid_score,
|
||||
vector_score,
|
||||
keyword_score,
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
b.hybrid_score
|
||||
.partial_cmp(&a.hybrid_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
results.truncate(limit);
|
||||
results
|
||||
}
|
||||
|
||||
/// Query features for learned fusion
///
/// Computed once per query (see `learned_fusion`) and fed to
/// `FusionModel::predict_alpha` to pick a blend weight.
#[derive(Debug, Clone)]
pub struct QueryFeatures {
    /// L2 norm of query embedding
    pub embedding_norm: f32,
    /// Number of query terms
    pub term_count: usize,
    /// Average IDF of query terms
    ///
    /// Supplied pre-computed by the caller; high values mean rare terms.
    pub avg_term_idf: f32,
    /// Whether query appears to need exact matching
    ///
    /// Heuristic: very short queries or code/ID-like terms (see
    /// `detect_exact_match_intent`).
    pub has_exact_match: bool,
    /// Classified query type
    pub query_type: QueryType,
}
|
||||
|
||||
/// Query type classification for learned fusion
///
/// Derived from keyword heuristics in `classify_query_type`; each type biases
/// the predicted alpha differently in `FusionModel::predict_alpha`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Navigational query (looking for specific entity)
    Navigational,
    /// Informational query (seeking information)
    Informational,
    /// Transactional query (action-oriented)
    Transactional,
    /// Unknown/mixed
    Unknown,
}
|
||||
|
||||
/// Simple fusion model for learned/adaptive fusion
///
/// A linear model over `QueryFeatures`: each weight scales one feature's
/// contribution to the predicted alpha (higher alpha favors the vector branch).
///
/// In production, this would be a trained ML model (e.g., GNN, logistic regression)
pub struct FusionModel {
    /// Default alpha when model can't make prediction
    pub default_alpha: f32,
    /// Weight for embedding norm
    pub norm_weight: f32,
    /// Weight for term count
    ///
    /// Negative values mean longer queries shift toward keyword search.
    pub term_weight: f32,
    /// Weight for avg IDF
    pub idf_weight: f32,
    /// Bias for exact match queries (favor keyword)
    ///
    /// Added to alpha when exact-match intent is detected; negative favors keyword.
    pub exact_match_bias: f32,
}
|
||||
|
||||
impl Default for FusionModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_alpha: 0.5,
|
||||
norm_weight: 0.1,
|
||||
term_weight: -0.05, // More terms -> slight keyword preference
|
||||
idf_weight: 0.15, // Rare terms -> vector preference
|
||||
exact_match_bias: -0.2, // Exact match -> keyword preference
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FusionModel {
|
||||
/// Predict optimal alpha for a query
|
||||
pub fn predict_alpha(&self, features: &QueryFeatures) -> f32 {
|
||||
let mut alpha = self.default_alpha;
|
||||
|
||||
// Adjust based on embedding norm (high norm -> more distinctive)
|
||||
alpha += self.norm_weight * (features.embedding_norm - 1.0).clamp(-1.0, 1.0);
|
||||
|
||||
// Adjust based on term count
|
||||
alpha += self.term_weight * (features.term_count as f32 - 3.0).clamp(-3.0, 3.0);
|
||||
|
||||
// Adjust based on avg IDF (high IDF = rare terms, favor vector)
|
||||
alpha += self.idf_weight * (features.avg_term_idf - 3.0).clamp(-3.0, 3.0);
|
||||
|
||||
// Adjust for exact match intent
|
||||
if features.has_exact_match {
|
||||
alpha += self.exact_match_bias;
|
||||
}
|
||||
|
||||
// Adjust based on query type
|
||||
match features.query_type {
|
||||
QueryType::Navigational => alpha -= 0.15, // Favor keyword
|
||||
QueryType::Informational => alpha += 0.1, // Favor vector
|
||||
QueryType::Transactional => alpha -= 0.05,
|
||||
QueryType::Unknown => {}
|
||||
}
|
||||
|
||||
// Clamp to valid range
|
||||
alpha.clamp(0.0, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Learned fusion using query characteristics
|
||||
pub fn learned_fusion(
|
||||
query_embedding: &[f32],
|
||||
query_terms: &[String],
|
||||
vector_results: &[(DocId, f32)],
|
||||
keyword_results: &[(DocId, f32)],
|
||||
model: &FusionModel,
|
||||
avg_term_idf: f32, // Pre-computed average IDF
|
||||
limit: usize,
|
||||
) -> Vec<FusedResult> {
|
||||
// Compute query features
|
||||
let embedding_norm = l2_norm(query_embedding);
|
||||
let has_exact_match = detect_exact_match_intent(query_terms);
|
||||
let query_type = classify_query_type(query_terms);
|
||||
|
||||
let features = QueryFeatures {
|
||||
embedding_norm,
|
||||
term_count: query_terms.len(),
|
||||
avg_term_idf,
|
||||
has_exact_match,
|
||||
query_type,
|
||||
};
|
||||
|
||||
// Predict optimal alpha
|
||||
let alpha = model.predict_alpha(&features);
|
||||
|
||||
// Use linear fusion with predicted alpha
|
||||
linear_fusion(vector_results, keyword_results, alpha, limit)
|
||||
}
|
||||
|
||||
/// Compute L2 norm of a vector
///
/// Returns 0.0 for an empty slice.
fn l2_norm(v: &[f32]) -> f32 {
    v.iter().fold(0.0_f32, |acc, &x| acc + x * x).sqrt()
}
|
||||
|
||||
/// Detect if query seems to need exact matching
///
/// Heuristics:
/// - Quoted phrases (handled upstream)
/// - Product codes, error codes, IDs (terms mixing digits, length 3..=20)
/// - Very short queries (1-2 terms)
fn detect_exact_match_intent(terms: &[String]) -> bool {
    // Short queries are treated as lookups.
    if terms.len() <= 2 {
        return true;
    }

    // Any term that mixes in digits and is code-sized looks like an ID/SKU.
    terms
        .iter()
        .any(|t| t.len() >= 3 && t.len() <= 20 && t.chars().any(|c| c.is_numeric()))
}
|
||||
|
||||
/// Classify query type based on terms
|
||||
fn classify_query_type(terms: &[String]) -> QueryType {
|
||||
let terms_lower: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
|
||||
|
||||
// Navigational indicators
|
||||
let nav_indicators = ["website", "login", "home", "official", "download"];
|
||||
if terms_lower
|
||||
.iter()
|
||||
.any(|t| nav_indicators.contains(&t.as_str()))
|
||||
{
|
||||
return QueryType::Navigational;
|
||||
}
|
||||
|
||||
// Transactional indicators
|
||||
let trans_indicators = ["buy", "purchase", "order", "price", "cheap", "best", "deal"];
|
||||
if terms_lower
|
||||
.iter()
|
||||
.any(|t| trans_indicators.contains(&t.as_str()))
|
||||
{
|
||||
return QueryType::Transactional;
|
||||
}
|
||||
|
||||
// Informational indicators
|
||||
let info_indicators = [
|
||||
"how", "what", "why", "when", "where", "guide", "tutorial", "explain",
|
||||
];
|
||||
if terms_lower
|
||||
.iter()
|
||||
.any(|t| info_indicators.contains(&t.as_str()))
|
||||
{
|
||||
return QueryType::Informational;
|
||||
}
|
||||
|
||||
QueryType::Unknown
|
||||
}
|
||||
|
||||
/// Fuse results using the specified method
|
||||
pub fn fuse_results(
|
||||
vector_results: &[(DocId, f32)],
|
||||
keyword_results: &[(DocId, f32)],
|
||||
config: &FusionConfig,
|
||||
limit: usize,
|
||||
) -> Vec<FusedResult> {
|
||||
match config.method {
|
||||
FusionMethod::Rrf => rrf_fusion(vector_results, keyword_results, config.rrf_k, limit),
|
||||
FusionMethod::Linear => linear_fusion(vector_results, keyword_results, config.alpha, limit),
|
||||
FusionMethod::Learned => {
|
||||
// Learned fusion requires additional context
|
||||
// Fall back to RRF if no model is available
|
||||
rrf_fusion(vector_results, keyword_results, config.rrf_k, limit)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: vector branch results, sorted best-first by distance.
    fn sample_vector_results() -> Vec<(DocId, f32)> {
        vec![
            (1, 0.1), // Best (lowest distance)
            (2, 0.2),
            (3, 0.3),
            (4, 0.5),
            (5, 0.8),
        ]
    }

    // Fixture: keyword branch results, sorted best-first by BM25 score.
    // Docs 1, 2, 3 overlap with the vector fixture; 6 and 7 are keyword-only.
    fn sample_keyword_results() -> Vec<(DocId, f32)> {
        vec![
            (3, 8.5), // Best (highest BM25)
            (1, 7.2),
            (6, 5.0),
            (2, 3.5),
            (7, 2.0),
        ]
    }

    // Docs appearing in both branches should rank near the top under RRF.
    #[test]
    fn test_rrf_fusion() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        let results = rrf_fusion(&vector, &keyword, 60, 5);

        assert!(!results.is_empty());
        // Doc 1 and 3 should be near top (they appear in both)
        let top_ids: Vec<DocId> = results.iter().map(|r| r.doc_id).collect();
        assert!(top_ids.contains(&1) || top_ids.contains(&3));
    }

    // alpha = 1.0: only the vector branch contributes to the blended score.
    #[test]
    fn test_linear_fusion_alpha_1() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        // Alpha = 1.0 means only vector
        let results = linear_fusion(&vector, &keyword, 1.0, 5);

        assert!(!results.is_empty());
        // With alpha=1, vector-only docs should dominate
        let first = &results[0];
        assert!(first.vector_score.is_some());
    }

    // alpha = 0.0: only the keyword branch contributes to the blended score.
    #[test]
    fn test_linear_fusion_alpha_0() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        // Alpha = 0.0 means only keyword
        let results = linear_fusion(&vector, &keyword, 0.0, 5);

        assert!(!results.is_empty());
        // With alpha=0, keyword-only docs can appear
        let first = &results[0];
        assert!(first.keyword_score.is_some());
    }

    // Min-max distance-to-similarity: the lowest distance maps to ~1.0.
    #[test]
    fn test_normalization() {
        let results = vec![(1, 0.1), (2, 0.5), (3, 0.9)];
        let normalized = normalize_to_similarity(&results);

        assert_eq!(normalized.len(), 3);
        // Lowest distance should have highest similarity
        let (id1, sim1) = normalized[0];
        assert_eq!(id1, 1);
        assert!((sim1 - 1.0).abs() < 0.01, "Best should be ~1.0");
    }

    // The default model should shift alpha below 0.5 for navigational /
    // exact-match queries and keep it high for informational ones.
    #[test]
    fn test_fusion_model() {
        let model = FusionModel::default();

        // Short navigational query
        let features = QueryFeatures {
            embedding_norm: 1.0,
            term_count: 2,
            avg_term_idf: 2.0,
            has_exact_match: true,
            query_type: QueryType::Navigational,
        };

        let alpha = model.predict_alpha(&features);
        assert!(
            alpha < 0.5,
            "Navigational query should favor keyword (alpha < 0.5)"
        );

        // Long informational query
        let features2 = QueryFeatures {
            embedding_norm: 1.2,
            term_count: 6,
            avg_term_idf: 5.0,
            has_exact_match: false,
            query_type: QueryType::Informational,
        };

        let alpha2 = model.predict_alpha(&features2);
        assert!(
            alpha2 > 0.4,
            "Informational query with rare terms should favor vector"
        );
    }

    // Indicator-word classification covers all three known query types.
    #[test]
    fn test_query_type_classification() {
        let nav = classify_query_type(&["github".into(), "login".into()]);
        assert_eq!(nav, QueryType::Navigational);

        let info = classify_query_type(&["how".into(), "to".into(), "cook".into(), "pasta".into()]);
        assert_eq!(info, QueryType::Informational);

        let trans = classify_query_type(&["buy".into(), "laptop".into()]);
        assert_eq!(trans, QueryType::Transactional);
    }

    // Code-like terms and short queries trigger exact-match intent;
    // ordinary multi-word queries do not.
    #[test]
    fn test_exact_match_detection() {
        assert!(detect_exact_match_intent(&["ERR001".into()]));
        assert!(detect_exact_match_intent(&["SKU12345".into()]));
        assert!(!detect_exact_match_intent(&[
            "database".into(),
            "connection".into(),
            "error".into(),
            "handling".into()
        ]));
    }

    // Both fusion paths must tolerate empty inputs without panicking.
    #[test]
    fn test_empty_results() {
        let results = rrf_fusion(&[], &[], 60, 10);
        assert!(results.is_empty());

        let results2 = linear_fusion(&[], &[], 0.5, 10);
        assert!(results2.is_empty());
    }
}
|
||||
533
vendor/ruvector/crates/ruvector-postgres/src/hybrid/mod.rs
vendored
Normal file
533
vendor/ruvector/crates/ruvector-postgres/src/hybrid/mod.rs
vendored
Normal file
@@ -0,0 +1,533 @@
|
||||
//! Hybrid Search (BM25 + Vector) for RuVector Postgres
|
||||
//!
|
||||
//! Provides combined keyword and semantic vector search with multiple fusion strategies.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **BM25 Scoring**: Proper BM25 implementation with document length normalization
|
||||
//! - **Fusion Algorithms**: RRF (default), Linear blend, Learned/adaptive
|
||||
//! - **Parallel Execution**: Vector and keyword branches can run concurrently
|
||||
//! - **Registry System**: Track hybrid-enabled collections with per-collection settings
|
||||
//!
|
||||
//! ## SQL Interface
|
||||
//!
|
||||
//! ```sql
|
||||
//! -- Register a collection for hybrid search
|
||||
//! SELECT ruvector_register_hybrid(
|
||||
//! collection := 'documents',
|
||||
//! vector_column := 'embedding',
|
||||
//! fts_column := 'fts',
|
||||
//! text_column := 'content'
|
||||
//! );
|
||||
//!
|
||||
//! -- Perform hybrid search
|
||||
//! SELECT * FROM ruvector_hybrid_search(
|
||||
//! 'documents',
|
||||
//! query_text := 'database query optimization',
|
||||
//! query_vector := $embedding,
|
||||
//! k := 10,
|
||||
//! fusion := 'rrf'
|
||||
//! );
|
||||
//! ```
|
||||
|
||||
pub mod bm25;
|
||||
pub mod executor;
|
||||
pub mod fusion;
|
||||
pub mod registry;
|
||||
|
||||
// Re-exports
|
||||
pub use bm25::{tokenize_query, BM25Config, BM25Scorer, CorpusStats, Document, TermFrequencies};
|
||||
pub use executor::{
|
||||
choose_strategy, BranchResults, ExecutionStats, HybridExecutor, HybridQuery, HybridResult,
|
||||
HybridStrategy,
|
||||
};
|
||||
pub use fusion::{
|
||||
fuse_results, learned_fusion, linear_fusion, rrf_fusion, DocId, FusedResult, FusionConfig,
|
||||
FusionMethod, FusionModel, DEFAULT_ALPHA, DEFAULT_RRF_K,
|
||||
};
|
||||
pub use registry::{
|
||||
get_registry, HybridCollectionConfig, HybridConfigUpdate, HybridRegistry, RegistryError,
|
||||
HYBRID_REGISTRY,
|
||||
};
|
||||
|
||||
use pgrx::prelude::*;
|
||||
|
||||
// ============================================================================
|
||||
// SQL Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Register a collection for hybrid search
///
/// Creates the necessary metadata and computes initial corpus statistics.
///
/// # Arguments
/// * `collection` - Table name (optionally schema-qualified)
/// * `vector_column` - Name of the vector column
/// * `fts_column` - Name of the tsvector column
/// * `text_column` - Name of the original text column (for BM25 stats)
///
/// # Returns
/// JSON object with registration details
#[pg_extern]
fn ruvector_register_hybrid(
    collection: &str,
    vector_column: &str,
    fts_column: &str,
    text_column: &str,
) -> pgrx::JsonB {
    // Parse collection name
    let (schema, table) = parse_collection_name(collection);

    // For now, use a simple hash as collection ID
    // In production, this would query ruvector.collections table
    // NOTE(review): a wrapping byte-sum is order-insensitive, so anagram names
    // (e.g. "ab"/"ba") collide on the same ID — confirm this is acceptable
    // until real collection IDs are wired in.
    let collection_id = collection
        .bytes()
        .fold(0i32, |acc, b| acc.wrapping_add(b as i32));

    // Check if already registered
    let registry = get_registry();
    if registry.is_registered(collection_id) {
        // Duplicate registration is reported as a JSON error, not a PG error.
        return pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": format!("Collection '{}' is already registered for hybrid search", collection),
            "collection_id": collection_id
        }));
    }

    // Create configuration
    let mut config = HybridCollectionConfig::new(
        collection_id,
        table.to_string(),
        vector_column.to_string(),
        fts_column.to_string(),
        text_column.to_string(),
    );
    config.schema_name = schema.to_string();

    // Register
    match registry.register(config) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection_id": collection_id,
            "collection": collection,
            "vector_column": vector_column,
            "fts_column": fts_column,
            "text_column": text_column,
            "message": "Collection registered for hybrid search. Run ruvector_hybrid_update_stats() to compute corpus statistics."
        })),
        Err(e) => pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        })),
    }
}
|
||||
|
||||
/// Update BM25 corpus statistics for a hybrid collection
///
/// Computes average document length, document count, and term frequencies.
/// Should be run periodically or after bulk inserts.
///
/// # Arguments
/// * `collection` - Table name (optionally schema-qualified)
///
/// # Returns
/// JSON object reporting success or a lookup error. Currently only refreshes
/// the `last_update` timestamp; the existing stats values are carried over.
#[pg_extern]
fn ruvector_hybrid_update_stats(collection: &str) -> pgrx::JsonB {
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);

    // Unregistered collections are a JSON-level error, not a PG error.
    let registry = get_registry();
    let config = match registry.get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return pgrx::JsonB(serde_json::json!({
                "success": false,
                "error": format!("Collection '{}' is not registered for hybrid search", collection)
            }));
        }
    };

    // In the actual extension, we would run SQL to compute stats:
    // SELECT AVG(LENGTH(text_column)), COUNT(*) FROM table
    // For now, we return a placeholder indicating the function works

    // unwrap: only fails if the system clock reads before the Unix epoch.
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs() as i64;

    // Carry over the stored stats; only the refresh timestamp changes here.
    let stats = CorpusStats {
        avg_doc_length: config.corpus_stats.avg_doc_length,
        doc_count: config.corpus_stats.doc_count,
        total_terms: config.corpus_stats.total_terms,
        last_update: now,
    };

    match registry.update_stats(config.collection_id, stats) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection": collection,
            "message": "Stats update initiated. In production, this would compute actual corpus statistics.",
            "note": "Use Spi::run to execute SQL for actual stats computation"
        })),
        Err(e) => pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        })),
    }
}
|
||||
|
||||
/// Configure hybrid search settings for a collection
///
/// Applies a partial update: only the keys present in `config` are changed,
/// the rest of the collection's settings are preserved.
///
/// # Arguments
/// * `collection` - Collection name
/// * `config` - JSON configuration object
///
/// # Example Configuration
/// ```json
/// {
///     "default_fusion": "rrf",
///     "default_alpha": 0.5,
///     "rrf_k": 60,
///     "prefetch_k": 100,
///     "bm25_k1": 1.2,
///     "bm25_b": 0.75,
///     "stats_refresh_interval": "1 hour",
///     "parallel_enabled": true
/// }
/// ```
///
/// # Returns
/// JSON with the effective configuration on success, or an error object when
/// the collection is unknown, the JSON is malformed, or the update fails.
#[pg_extern]
fn ruvector_hybrid_configure(collection: &str, config: pgrx::JsonB) -> pgrx::JsonB {
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);

    // Look up the current settings; all errors are returned as JSON.
    let registry = get_registry();
    let mut existing_config = match registry.get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return pgrx::JsonB(serde_json::json!({
                "success": false,
                "error": format!("Collection '{}' is not registered for hybrid search", collection)
            }));
        }
    };

    // Parse and apply updates
    let update: HybridConfigUpdate = match serde_json::from_value(config.0.clone()) {
        Ok(u) => u,
        Err(e) => {
            return pgrx::JsonB(serde_json::json!({
                "success": false,
                "error": format!("Invalid configuration: {}", e)
            }));
        }
    };

    // Mutates the local copy; persisted only by the registry.update below.
    if let Err(e) = update.apply(&mut existing_config) {
        return pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        }));
    }

    // Persist and echo back the effective configuration.
    match registry.update(existing_config.clone()) {
        Ok(_) => pgrx::JsonB(serde_json::json!({
            "success": true,
            "collection": collection,
            "config": {
                "fusion_method": format!("{:?}", existing_config.fusion_config.method),
                "alpha": existing_config.fusion_config.alpha,
                "rrf_k": existing_config.fusion_config.rrf_k,
                "prefetch_k": existing_config.prefetch_k,
                "bm25_k1": existing_config.bm25_config.k1,
                "bm25_b": existing_config.bm25_config.b,
                "stats_refresh_interval": existing_config.stats_refresh_interval,
                "parallel_enabled": existing_config.parallel_enabled
            }
        })),
        Err(e) => pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e.to_string()
        })),
    }
}
|
||||
|
||||
/// Perform hybrid search combining vector and keyword search
///
/// # Arguments
/// * `collection` - Table name
/// * `query_text` - Text query for keyword search
/// * `query_vector` - Vector for semantic search
/// * `k` - Number of results to return
/// * `fusion` - Fusion method ("rrf", "linear", "learned")
/// * `alpha` - Alpha for linear fusion (0-1, higher favors vector)
///
/// # Returns
/// Table of results with id, content, vector_score, keyword_score, hybrid_score
///
/// NOTE(review): the branch searches below are mocked — no SPI queries are
/// executed yet, so the returned rows are synthetic (see the trailing "note"
/// field in the response).
#[pg_extern]
fn ruvector_hybrid_search(
    collection: &str,
    query_text: &str,
    query_vector: Vec<f32>,
    k: i32,
    fusion: default!(Option<&str>, "NULL"),
    alpha: default!(Option<f32>, "NULL"),
) -> pgrx::JsonB {
    // Clamp k to at least 1 before converting to usize.
    let k = k.max(1) as usize;
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);

    // The collection must have been registered first.
    let registry = get_registry();
    let config = match registry.get_by_name(&qualified_name) {
        Some(c) => c,
        None => {
            return pgrx::JsonB(serde_json::json!({
                "success": false,
                "error": format!("Collection '{}' is not registered for hybrid search. Run ruvector_register_hybrid first.", collection)
            }));
        }
    };

    // Build fusion config
    // Per-call overrides layered on top of the stored per-collection defaults.
    // An unrecognized fusion name is silently ignored (defaults are kept).
    let mut fusion_config = config.fusion_config.clone();
    if let Some(method) = fusion {
        if let Ok(m) = method.parse::<FusionMethod>() {
            fusion_config.method = m;
        }
    }
    if let Some(a) = alpha {
        fusion_config.alpha = a.clamp(0.0, 1.0);
    }

    // Build query
    let query = HybridQuery {
        text: query_text.to_string(),
        embedding: query_vector,
        k,
        // Each branch fetches at least 2k candidates so fusion has overlap.
        prefetch_k: config.prefetch_k.max(k * 2),
        fusion_config,
        filter: None,
    };

    // Create executor
    let executor = HybridExecutor::new(config.corpus_stats.clone());

    // In the actual extension, these would execute real searches via SPI
    // For now, return a demonstration response
    // Mock vector branch: ids 1..=min(k,10), ascending distance.
    let mock_vector_results: Vec<(DocId, f32)> = (1..=k.min(10) as i64)
        .map(|i| (i, 0.1 * i as f32))
        .collect();

    // Mock keyword branch: descending ids with descending BM25 scores.
    let mock_keyword_results: Vec<(DocId, f32)> = (1..=k.min(10) as i64)
        .map(|i| ((k as i64 + 1 - i), 10.0 / i as f32))
        .collect();

    // Execute fusion
    let (results, stats) = executor.execute(
        &query,
        |_, k| BranchResults {
            results: mock_vector_results.clone().into_iter().take(k).collect(),
            latency_ms: 1.0,
            candidates_evaluated: 100,
        },
        |_, k| BranchResults {
            results: mock_keyword_results.clone().into_iter().take(k).collect(),
            latency_ms: 0.5,
            candidates_evaluated: 50,
        },
    );

    // Format results
    let result_json: Vec<serde_json::Value> = results
        .iter()
        .enumerate()
        .map(|(i, r)| {
            serde_json::json!({
                "rank": i + 1,
                "id": r.id,
                "hybrid_score": r.hybrid_score,
                "vector_score": r.vector_score,
                "keyword_score": r.keyword_score,
                "vector_rank": r.vector_rank,
                "keyword_rank": r.keyword_rank
            })
        })
        .collect();

    pgrx::JsonB(serde_json::json!({
        "success": true,
        "collection": collection,
        "query": {
            "text": query_text,
            "vector_dims": query.embedding.len(),
            "k": k,
            "fusion": format!("{:?}", query.fusion_config.method),
            "alpha": query.fusion_config.alpha
        },
        "results": result_json,
        "stats": {
            "total_latency_ms": stats.total_latency_ms,
            "vector_latency_ms": stats.vector_latency_ms,
            "keyword_latency_ms": stats.keyword_latency_ms,
            "fusion_latency_ms": stats.fusion_latency_ms,
            "result_count": stats.result_count
        },
        "note": "This is a demonstration. In production, actual vector/keyword searches would be executed via SPI."
    }))
}
|
||||
|
||||
/// Get hybrid search statistics for a collection
///
/// Returns a JSONB document describing the collection's BM25 corpus
/// statistics, fusion settings, tuning knobs, and column metadata.
/// If the collection was never registered in the hybrid registry, a
/// JSONB object with a single `error` field is returned instead.
#[pg_extern]
fn ruvector_hybrid_stats(collection: &str) -> pgrx::JsonB {
    // Normalize "table" / "schema.table" into the qualified lookup key
    // used by the registry (bare names default to the "public" schema).
    let (schema, table) = parse_collection_name(collection);
    let qualified_name = format!("{}.{}", schema, table);

    let registry = get_registry();
    match registry.get_by_name(&qualified_name) {
        Some(config) => pgrx::JsonB(serde_json::json!({
            "collection": collection,
            "corpus_stats": {
                "avg_doc_length": config.corpus_stats.avg_doc_length,
                "doc_count": config.corpus_stats.doc_count,
                "total_terms": config.corpus_stats.total_terms,
                // Unix epoch seconds of the last statistics refresh.
                "last_update": config.corpus_stats.last_update
            },
            "bm25_config": {
                "k1": config.bm25_config.k1,
                "b": config.bm25_config.b
            },
            "fusion_config": {
                // Debug-formatted fusion method variant name.
                "method": format!("{:?}", config.fusion_config.method),
                "alpha": config.fusion_config.alpha,
                "rrf_k": config.fusion_config.rrf_k
            },
            "settings": {
                "prefetch_k": config.prefetch_k,
                "parallel_enabled": config.parallel_enabled,
                // Seconds between automatic corpus-stats refreshes.
                "stats_refresh_interval": config.stats_refresh_interval
            },
            "metadata": {
                "vector_column": config.vector_column,
                "fts_column": config.fts_column,
                "text_column": config.text_column,
                "created_at": config.created_at,
                "updated_at": config.updated_at
            }
        })),
        None => pgrx::JsonB(serde_json::json!({
            "error": format!("Collection '{}' is not registered for hybrid search", collection)
        })),
    }
}
|
||||
|
||||
/// Compute hybrid score from vector distance and keyword score
|
||||
///
|
||||
/// Utility function for manual hybrid scoring in queries.
|
||||
#[pg_extern(immutable, parallel_safe)]
|
||||
fn ruvector_hybrid_score(
|
||||
vector_distance: f32,
|
||||
keyword_score: f32,
|
||||
alpha: default!(Option<f32>, "0.5"),
|
||||
) -> f32 {
|
||||
let alpha = alpha.unwrap_or(0.5).clamp(0.0, 1.0);
|
||||
|
||||
// Convert distance to similarity (assuming cosine distance in [0, 2])
|
||||
let vector_similarity = 1.0 - (vector_distance / 2.0).clamp(0.0, 1.0);
|
||||
|
||||
// Simple linear blend (normalized keyword scores assumed)
|
||||
alpha * vector_similarity + (1.0 - alpha) * keyword_score
|
||||
}
|
||||
|
||||
/// List all collections registered for hybrid search
|
||||
#[pg_extern]
|
||||
fn ruvector_hybrid_list() -> pgrx::JsonB {
|
||||
let registry = get_registry();
|
||||
let collections: Vec<serde_json::Value> = registry
|
||||
.list()
|
||||
.iter()
|
||||
.map(|c| {
|
||||
serde_json::json!({
|
||||
"collection_id": c.collection_id,
|
||||
"name": c.qualified_name(),
|
||||
"vector_column": c.vector_column,
|
||||
"fts_column": c.fts_column,
|
||||
"fusion_method": format!("{:?}", c.fusion_config.method),
|
||||
"doc_count": c.corpus_stats.doc_count,
|
||||
"needs_refresh": c.needs_stats_refresh()
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
pgrx::JsonB(serde_json::json!({
|
||||
"count": collections.len(),
|
||||
"collections": collections
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Parse collection name into schema and table
///
/// Splits at the first `.`; a bare name is placed in the `public` schema.
fn parse_collection_name(name: &str) -> (&str, &str) {
    name.split_once('.').unwrap_or(("public", name))
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare names default to the "public" schema; dotted names split
    /// at the first dot.
    #[test]
    fn test_parse_collection_name() {
        let (schema, table) = parse_collection_name("documents");
        assert_eq!(schema, "public");
        assert_eq!(table, "documents");

        let (schema, table) = parse_collection_name("myschema.mytable");
        assert_eq!(schema, "myschema");
        assert_eq!(table, "mytable");
    }

    /// Smoke test: the public hybrid-search types construct without panicking.
    #[test]
    fn test_module_exports() {
        // Verify all expected types are accessible
        let _ = BM25Config::default();
        let _ = FusionConfig::default();
        let _ = CorpusStats::default();

        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        let _ = BM25Scorer::new(stats.clone());
        let _ = HybridExecutor::new(stats);
    }

    /// End-to-end registry flow: register, fetch by id, list.
    #[test]
    fn test_registry_workflow() {
        let registry = HybridRegistry::new();

        // Register
        let config = HybridCollectionConfig::new(
            1,
            "test_table".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config).unwrap();

        // Get
        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "test_table");

        // List
        let list = registry.list();
        assert_eq!(list.len(), 1);
    }
}
|
||||
651
vendor/ruvector/crates/ruvector-postgres/src/hybrid/registry.rs
vendored
Normal file
651
vendor/ruvector/crates/ruvector-postgres/src/hybrid/registry.rs
vendored
Normal file
@@ -0,0 +1,651 @@
|
||||
//! Hybrid Collections Registry
|
||||
//!
|
||||
//! Tracks collections with hybrid search enabled and stores:
|
||||
//! - BM25 corpus statistics
|
||||
//! - Per-collection fusion settings
|
||||
//! - Column mappings for vector and FTS
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::bm25::{BM25Config, CorpusStats};
|
||||
use super::fusion::FusionConfig;
|
||||
#[cfg(test)]
|
||||
use super::fusion::FusionMethod;
|
||||
|
||||
/// Hybrid collection configuration
///
/// Per-collection settings for hybrid (vector + keyword) search:
/// column mappings, BM25/fusion parameters, and cached corpus statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridCollectionConfig {
    /// Collection ID (from ruvector.collections table)
    pub collection_id: i32,
    /// Table name
    pub table_name: String,
    /// Schema name (default: public)
    pub schema_name: String,
    /// Vector column name
    pub vector_column: String,
    /// FTS tsvector column name
    pub fts_column: String,
    /// Original text column name (for BM25 stats)
    pub text_column: String,
    /// Primary key column name (default: "id")
    pub pk_column: String,

    /// BM25 configuration (k1 / b parameters)
    pub bm25_config: BM25Config,
    /// Fusion configuration (method, alpha, rrf_k)
    pub fusion_config: FusionConfig,
    /// Corpus statistics used by BM25 scoring
    pub corpus_stats: CorpusStats,

    /// Prefetch size for each branch (vector and keyword)
    pub prefetch_k: usize,
    /// Stats refresh interval in seconds
    pub stats_refresh_interval: i64,
    /// Enable parallel branch execution
    pub parallel_enabled: bool,

    /// Created timestamp (Unix epoch seconds)
    pub created_at: i64,
    /// Last modified timestamp (Unix epoch seconds)
    pub updated_at: i64,
}
|
||||
|
||||
impl HybridCollectionConfig {
|
||||
/// Create a new hybrid collection configuration
|
||||
pub fn new(
|
||||
collection_id: i32,
|
||||
table_name: String,
|
||||
vector_column: String,
|
||||
fts_column: String,
|
||||
text_column: String,
|
||||
) -> Self {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
Self {
|
||||
collection_id,
|
||||
table_name,
|
||||
schema_name: "public".to_string(),
|
||||
vector_column,
|
||||
fts_column,
|
||||
text_column,
|
||||
pk_column: "id".to_string(),
|
||||
bm25_config: BM25Config::default(),
|
||||
fusion_config: FusionConfig::default(),
|
||||
corpus_stats: CorpusStats::default(),
|
||||
prefetch_k: 100,
|
||||
stats_refresh_interval: 3600, // 1 hour
|
||||
parallel_enabled: true,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get fully qualified table name
|
||||
pub fn qualified_name(&self) -> String {
|
||||
format!("{}.{}", self.schema_name, self.table_name)
|
||||
}
|
||||
|
||||
/// Check if stats need refresh
|
||||
pub fn needs_stats_refresh(&self) -> bool {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
now - self.corpus_stats.last_update > self.stats_refresh_interval
|
||||
}
|
||||
|
||||
/// Update corpus statistics
|
||||
pub fn update_stats(&mut self, stats: CorpusStats) {
|
||||
self.corpus_stats = stats;
|
||||
self.updated_at = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
}
|
||||
}
|
||||
|
||||
/// Registry entry for a hybrid collection
///
/// Pairs the persistent configuration with in-memory scoring caches.
#[derive(Debug)]
struct RegistryEntry {
    /// Configuration
    config: HybridCollectionConfig,
    /// Cached IDF values (term -> idf); cleared whenever corpus stats change
    idf_cache: HashMap<String, f32>,
    /// Document frequency cache (term -> doc count)
    df_cache: HashMap<String, u64>,
}

impl RegistryEntry {
    /// Wrap a configuration with empty caches.
    fn new(config: HybridCollectionConfig) -> Self {
        Self {
            config,
            idf_cache: HashMap::new(),
            df_cache: HashMap::new(),
        }
    }
}
|
||||
|
||||
/// Hybrid Collections Registry
///
/// Global registry for hybrid-enabled collections.
/// In the PostgreSQL extension, this is backed by the ruvector.hybrid_collections table.
///
/// Two locked maps are kept: the primary id-keyed map holding the
/// entries, and a secondary name-to-id index for lookups by
/// `schema.table`.
pub struct HybridRegistry {
    /// Collections by ID (primary store)
    collections_by_id: RwLock<HashMap<i32, RegistryEntry>>,
    /// Collections by name (schema.table -> id) — secondary index
    collections_by_name: RwLock<HashMap<String, i32>>,
}
|
||||
|
||||
impl HybridRegistry {
|
||||
/// Create a new registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
collections_by_id: RwLock::new(HashMap::new()),
|
||||
collections_by_name: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a collection for hybrid search
|
||||
pub fn register(&self, config: HybridCollectionConfig) -> Result<(), RegistryError> {
|
||||
let qualified_name = config.qualified_name();
|
||||
let collection_id = config.collection_id;
|
||||
|
||||
// Check for duplicates
|
||||
{
|
||||
let by_name = self.collections_by_name.read();
|
||||
if by_name.contains_key(&qualified_name) {
|
||||
return Err(RegistryError::AlreadyRegistered(qualified_name));
|
||||
}
|
||||
}
|
||||
|
||||
// Insert into both maps
|
||||
let entry = RegistryEntry::new(config);
|
||||
|
||||
self.collections_by_id.write().insert(collection_id, entry);
|
||||
self.collections_by_name
|
||||
.write()
|
||||
.insert(qualified_name, collection_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unregister a collection
|
||||
pub fn unregister(&self, collection_id: i32) -> Result<(), RegistryError> {
|
||||
let entry = self.collections_by_id.write().remove(&collection_id);
|
||||
|
||||
if let Some(entry) = entry {
|
||||
let qualified_name = entry.config.qualified_name();
|
||||
self.collections_by_name.write().remove(&qualified_name);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(RegistryError::NotFound(collection_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get collection by ID
|
||||
pub fn get(&self, collection_id: i32) -> Option<HybridCollectionConfig> {
|
||||
self.collections_by_id
|
||||
.read()
|
||||
.get(&collection_id)
|
||||
.map(|e| e.config.clone())
|
||||
}
|
||||
|
||||
/// Get collection by name
|
||||
pub fn get_by_name(&self, name: &str) -> Option<HybridCollectionConfig> {
|
||||
let collection_id = self.collections_by_name.read().get(name).copied()?;
|
||||
self.get(collection_id)
|
||||
}
|
||||
|
||||
/// Update collection configuration
|
||||
pub fn update(&self, config: HybridCollectionConfig) -> Result<(), RegistryError> {
|
||||
let collection_id = config.collection_id;
|
||||
|
||||
let mut by_id = self.collections_by_id.write();
|
||||
if let Some(entry) = by_id.get_mut(&collection_id) {
|
||||
entry.config = config;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(RegistryError::NotFound(collection_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Update corpus statistics for a collection
|
||||
pub fn update_stats(
|
||||
&self,
|
||||
collection_id: i32,
|
||||
stats: CorpusStats,
|
||||
) -> Result<(), RegistryError> {
|
||||
let mut by_id = self.collections_by_id.write();
|
||||
if let Some(entry) = by_id.get_mut(&collection_id) {
|
||||
entry.config.update_stats(stats);
|
||||
// Clear caches when stats change
|
||||
entry.idf_cache.clear();
|
||||
entry.df_cache.clear();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(RegistryError::NotFound(collection_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Set document frequency for a term in a collection
|
||||
pub fn set_doc_freq(
|
||||
&self,
|
||||
collection_id: i32,
|
||||
term: &str,
|
||||
doc_freq: u64,
|
||||
) -> Result<(), RegistryError> {
|
||||
let mut by_id = self.collections_by_id.write();
|
||||
if let Some(entry) = by_id.get_mut(&collection_id) {
|
||||
entry.df_cache.insert(term.to_string(), doc_freq);
|
||||
// Invalidate IDF cache for this term
|
||||
entry.idf_cache.remove(term);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(RegistryError::NotFound(collection_id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get IDF for a term, computing if not cached
|
||||
pub fn get_idf(&self, collection_id: i32, term: &str) -> Option<f32> {
|
||||
let mut by_id = self.collections_by_id.write();
|
||||
let entry = by_id.get_mut(&collection_id)?;
|
||||
|
||||
// Check cache
|
||||
if let Some(&idf) = entry.idf_cache.get(term) {
|
||||
return Some(idf);
|
||||
}
|
||||
|
||||
// Compute IDF
|
||||
let df = entry.df_cache.get(term).copied().unwrap_or(0);
|
||||
let n = entry.config.corpus_stats.doc_count as f32;
|
||||
let df_f = df as f32;
|
||||
|
||||
let idf = if df == 0 {
|
||||
(n + 0.5).ln()
|
||||
} else {
|
||||
((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
|
||||
};
|
||||
|
||||
// Cache and return
|
||||
entry.idf_cache.insert(term.to_string(), idf);
|
||||
Some(idf)
|
||||
}
|
||||
|
||||
/// List all registered collections
|
||||
pub fn list(&self) -> Vec<HybridCollectionConfig> {
|
||||
self.collections_by_id
|
||||
.read()
|
||||
.values()
|
||||
.map(|e| e.config.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a collection is registered
|
||||
pub fn is_registered(&self, collection_id: i32) -> bool {
|
||||
self.collections_by_id.read().contains_key(&collection_id)
|
||||
}
|
||||
|
||||
/// Get collections needing stats refresh
|
||||
pub fn collections_needing_refresh(&self) -> Vec<i32> {
|
||||
self.collections_by_id
|
||||
.read()
|
||||
.iter()
|
||||
.filter(|(_, e)| e.config.needs_stats_refresh())
|
||||
.map(|(id, _)| *id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clear all caches
|
||||
pub fn clear_caches(&self) {
|
||||
let mut by_id = self.collections_by_id.write();
|
||||
for entry in by_id.values_mut() {
|
||||
entry.idf_cache.clear();
|
||||
entry.df_cache.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HybridRegistry {
    /// Equivalent to [`HybridRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Registry error types
///
/// Returned by [`HybridRegistry`] operations and configuration updates.
#[derive(Debug, Clone)]
pub enum RegistryError {
    /// Collection already registered (payload: qualified name)
    AlreadyRegistered(String),
    /// Collection not found (payload: id or name as text)
    NotFound(String),
    /// Invalid configuration (payload: human-readable reason)
    InvalidConfig(String),
    /// Database error (payload: underlying error text)
    DatabaseError(String),
}
|
||||
|
||||
impl std::fmt::Display for RegistryError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RegistryError::AlreadyRegistered(name) => {
|
||||
write!(
|
||||
f,
|
||||
"Collection '{}' is already registered for hybrid search",
|
||||
name
|
||||
)
|
||||
}
|
||||
RegistryError::NotFound(name) => {
|
||||
write!(f, "Hybrid collection '{}' not found", name)
|
||||
}
|
||||
RegistryError::InvalidConfig(msg) => {
|
||||
write!(f, "Invalid hybrid configuration: {}", msg)
|
||||
}
|
||||
RegistryError::DatabaseError(msg) => {
|
||||
write!(f, "Database error: {}", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for RegistryError {}
|
||||
|
||||
// Global registry instance
lazy_static::lazy_static! {
    /// Global hybrid collections registry (one per backend process)
    pub static ref HYBRID_REGISTRY: Arc<HybridRegistry> = Arc::new(HybridRegistry::new());
}

/// Get the global hybrid registry
///
/// Returns a cheap `Arc` clone (reference-count bump); all callers share
/// the same underlying registry maps.
pub fn get_registry() -> Arc<HybridRegistry> {
    HYBRID_REGISTRY.clone()
}
|
||||
|
||||
/// Configuration update from JSONB
///
/// Partial update: every field is optional, and only present fields are
/// applied to a [`HybridCollectionConfig`] via [`HybridConfigUpdate::apply`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfigUpdate {
    /// New fusion method (parsed from its string name)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_fusion: Option<String>,
    /// New alpha value (must be in [0, 1])
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_alpha: Option<f32>,
    /// New RRF k value (must be positive)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rrf_k: Option<usize>,
    /// New prefetch k value (must be positive)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prefetch_k: Option<usize>,
    /// BM25 k1 parameter (clamped to >= 0 on apply)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bm25_k1: Option<f32>,
    /// BM25 b parameter (clamped to [0, 1] on apply)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bm25_b: Option<f32>,
    /// Stats refresh interval (e.g. "1 hour", "30 minutes", or plain seconds)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stats_refresh_interval: Option<String>,
    /// Enable parallel execution
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_enabled: Option<bool>,
}
|
||||
|
||||
impl HybridConfigUpdate {
    /// Apply updates to a configuration
    ///
    /// Only fields that are `Some` are applied. Validation rules:
    /// alpha must be in [0, 1], rrf_k and prefetch_k must be positive,
    /// and the refresh interval must parse via `parse_interval`;
    /// violations return `InvalidConfig` and leave later fields unapplied.
    /// NOTE(review): bm25_k1 and bm25_b are silently clamped rather than
    /// rejected, unlike alpha — an intentional-looking asymmetry, but worth
    /// confirming.
    ///
    /// On success, bumps `updated_at` to the current Unix time.
    pub fn apply(&self, config: &mut HybridCollectionConfig) -> Result<(), RegistryError> {
        // Fusion method is parsed from its string name via FromStr.
        if let Some(ref fusion) = self.default_fusion {
            config.fusion_config.method = fusion
                .parse()
                .map_err(|e: String| RegistryError::InvalidConfig(e))?;
        }

        if let Some(alpha) = self.default_alpha {
            if !(0.0..=1.0).contains(&alpha) {
                return Err(RegistryError::InvalidConfig(
                    "alpha must be between 0 and 1".into(),
                ));
            }
            config.fusion_config.alpha = alpha;
        }

        if let Some(rrf_k) = self.rrf_k {
            if rrf_k == 0 {
                return Err(RegistryError::InvalidConfig(
                    "rrf_k must be positive".into(),
                ));
            }
            config.fusion_config.rrf_k = rrf_k;
        }

        if let Some(prefetch_k) = self.prefetch_k {
            if prefetch_k == 0 {
                return Err(RegistryError::InvalidConfig(
                    "prefetch_k must be positive".into(),
                ));
            }
            config.prefetch_k = prefetch_k;
        }

        // BM25 parameters are clamped into their valid ranges.
        if let Some(k1) = self.bm25_k1 {
            config.bm25_config.k1 = k1.max(0.0);
        }

        if let Some(b) = self.bm25_b {
            config.bm25_config.b = b.clamp(0.0, 1.0);
        }

        if let Some(ref interval) = self.stats_refresh_interval {
            config.stats_refresh_interval = parse_interval(interval)?;
        }

        if let Some(parallel) = self.parallel_enabled {
            config.parallel_enabled = parallel;
        }

        // Record when the configuration was last modified.
        config.updated_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;

        Ok(())
    }
}
|
||||
|
||||
/// Parse interval string to seconds
|
||||
fn parse_interval(s: &str) -> Result<i64, RegistryError> {
|
||||
let s = s.trim().to_lowercase();
|
||||
|
||||
// Try common formats
|
||||
if let Some(hours) = s.strip_suffix(" hour").or_else(|| s.strip_suffix(" hours")) {
|
||||
return hours
|
||||
.trim()
|
||||
.parse::<i64>()
|
||||
.map(|h| h * 3600)
|
||||
.map_err(|_| RegistryError::InvalidConfig(format!("Invalid interval: {}", s)));
|
||||
}
|
||||
|
||||
if let Some(mins) = s
|
||||
.strip_suffix(" minute")
|
||||
.or_else(|| s.strip_suffix(" minutes"))
|
||||
{
|
||||
return mins
|
||||
.trim()
|
||||
.parse::<i64>()
|
||||
.map(|m| m * 60)
|
||||
.map_err(|_| RegistryError::InvalidConfig(format!("Invalid interval: {}", s)));
|
||||
}
|
||||
|
||||
if let Some(secs) = s
|
||||
.strip_suffix(" second")
|
||||
.or_else(|| s.strip_suffix(" seconds"))
|
||||
{
|
||||
return secs
|
||||
.trim()
|
||||
.parse::<i64>()
|
||||
.map_err(|_| RegistryError::InvalidConfig(format!("Invalid interval: {}", s)));
|
||||
}
|
||||
|
||||
// Try as plain seconds
|
||||
s.parse::<i64>()
|
||||
.map_err(|_| RegistryError::InvalidConfig(format!("Invalid interval: {}", s)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Register a collection, then fetch it back by id.
    #[test]
    fn test_registry_register_get() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            1,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );

        registry.register(config.clone()).unwrap();

        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "documents");
        assert_eq!(retrieved.vector_column, "embedding");
    }

    /// Registering the same qualified name twice must fail.
    #[test]
    fn test_registry_duplicate() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            1,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );

        registry.register(config.clone()).unwrap();
        let result = registry.register(config);

        assert!(matches!(result, Err(RegistryError::AlreadyRegistered(_))));
    }

    /// Lookup by qualified "schema.table" name resolves to the right id.
    #[test]
    fn test_registry_get_by_name() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            42,
            "my_table".to_string(),
            "vec".to_string(),
            "tsv".to_string(),
            "text".to_string(),
        );

        registry.register(config).unwrap();

        // Default schema is "public".
        let retrieved = registry.get_by_name("public.my_table").unwrap();
        assert_eq!(retrieved.collection_id, 42);
    }

    /// update_stats replaces the corpus statistics of a registered entry.
    #[test]
    fn test_registry_update_stats() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );

        registry.register(config).unwrap();

        let new_stats = CorpusStats {
            avg_doc_length: 150.0,
            doc_count: 5000,
            total_terms: 500000,
            last_update: 12345,
        };

        registry.update_stats(1, new_stats).unwrap();

        let updated = registry.get(1).unwrap();
        assert!((updated.corpus_stats.avg_doc_length - 150.0).abs() < 0.01);
        assert_eq!(updated.corpus_stats.doc_count, 5000);
    }

    /// HybridConfigUpdate::apply should write every provided field.
    #[test]
    fn test_config_update() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );

        let update = HybridConfigUpdate {
            default_fusion: Some("linear".to_string()),
            default_alpha: Some(0.7),
            rrf_k: Some(40),
            prefetch_k: Some(200),
            bm25_k1: Some(1.5),
            bm25_b: Some(0.8),
            stats_refresh_interval: Some("2 hours".to_string()),
            parallel_enabled: Some(false),
        };

        update.apply(&mut config).unwrap();

        assert_eq!(config.fusion_config.method, FusionMethod::Linear);
        assert!((config.fusion_config.alpha - 0.7).abs() < 0.01);
        assert_eq!(config.fusion_config.rrf_k, 40);
        assert_eq!(config.prefetch_k, 200);
        assert!((config.bm25_config.k1 - 1.5).abs() < 0.01);
        assert!((config.bm25_config.b - 0.8).abs() < 0.01);
        assert_eq!(config.stats_refresh_interval, 7200);
        assert!(!config.parallel_enabled);
    }

    /// Interval parsing: unit suffixes and bare seconds.
    #[test]
    fn test_parse_interval() {
        assert_eq!(parse_interval("1 hour").unwrap(), 3600);
        assert_eq!(parse_interval("2 hours").unwrap(), 7200);
        assert_eq!(parse_interval("30 minutes").unwrap(), 1800);
        assert_eq!(parse_interval("60 seconds").unwrap(), 60);
        assert_eq!(parse_interval("120").unwrap(), 120);
    }

    /// Staleness check depends on last_update vs. refresh interval.
    #[test]
    fn test_needs_refresh() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );

        // Fresh stats should not need refresh
        config.corpus_stats.last_update = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        config.stats_refresh_interval = 3600;

        assert!(!config.needs_stats_refresh());

        // Old stats should need refresh
        config.corpus_stats.last_update -= 7200;
        assert!(config.needs_stats_refresh());
    }
}
|
||||
806
vendor/ruvector/crates/ruvector-postgres/src/hybrid/tests.rs
vendored
Normal file
806
vendor/ruvector/crates/ruvector-postgres/src/hybrid/tests.rs
vendored
Normal file
@@ -0,0 +1,806 @@
|
||||
//! Comprehensive tests for the hybrid search module
|
||||
//!
|
||||
//! Tests cover:
|
||||
//! - BM25 scoring correctness
|
||||
//! - Fusion algorithm behavior
|
||||
//! - Executor integration
|
||||
//! - Registry operations
|
||||
|
||||
#[cfg(test)]
|
||||
mod bm25_tests {
|
||||
use crate::hybrid::bm25::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Create a test scorer with known corpus statistics
|
||||
fn test_scorer() -> BM25Scorer {
|
||||
let stats = CorpusStats {
|
||||
avg_doc_length: 100.0,
|
||||
doc_count: 10000,
|
||||
total_terms: 1000000,
|
||||
last_update: 0,
|
||||
};
|
||||
BM25Scorer::new(stats)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_idf_formula() {
|
||||
let scorer = test_scorer();
|
||||
|
||||
// Set known document frequencies
|
||||
scorer.set_doc_freq("common", 5000); // 50% of docs
|
||||
scorer.set_doc_freq("rare", 10); // 0.1% of docs
|
||||
scorer.set_doc_freq("unique", 1); // 0.01% of docs
|
||||
|
||||
let idf_common = scorer.idf("common");
|
||||
let idf_rare = scorer.idf("rare");
|
||||
let idf_unique = scorer.idf("unique");
|
||||
|
||||
// IDF should increase as term rarity increases
|
||||
assert!(idf_common < idf_rare, "Common term should have lower IDF");
|
||||
assert!(idf_rare < idf_unique, "Rare term should have lower IDF than unique");
|
||||
|
||||
// Verify approximate values using BM25 formula
|
||||
// IDF = ln((N - df + 0.5) / (df + 0.5) + 1)
|
||||
// For common (df=5000, N=10000): ln((10000-5000+0.5)/(5000+0.5)+1) ~= ln(2) ~= 0.69
|
||||
assert!((idf_common - 0.69).abs() < 0.1, "IDF common: {}", idf_common);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_score_single_term() {
|
||||
let scorer = test_scorer();
|
||||
scorer.set_doc_freq("test", 1000); // 10% of docs
|
||||
|
||||
let mut freqs = HashMap::new();
|
||||
freqs.insert("test".to_string(), 5); // Term appears 5 times
|
||||
let term_freqs = TermFrequencies::new(freqs);
|
||||
let doc = Document::new(&term_freqs);
|
||||
|
||||
let score = scorer.score(&doc, &["test".to_string()]);
|
||||
|
||||
assert!(score > 0.0, "Score should be positive");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_score_multiple_terms() {
|
||||
let scorer = test_scorer();
|
||||
scorer.set_doc_freq("database", 500);
|
||||
scorer.set_doc_freq("query", 300);
|
||||
scorer.set_doc_freq("optimization", 100);
|
||||
|
||||
let mut freqs = HashMap::new();
|
||||
freqs.insert("database".to_string(), 3);
|
||||
freqs.insert("query".to_string(), 2);
|
||||
freqs.insert("optimization".to_string(), 1);
|
||||
let term_freqs = TermFrequencies::new(freqs);
|
||||
let doc = Document::new(&term_freqs);
|
||||
|
||||
let query_terms = vec![
|
||||
"database".to_string(),
|
||||
"query".to_string(),
|
||||
"optimization".to_string(),
|
||||
];
|
||||
|
||||
let score = scorer.score(&doc, &query_terms);
|
||||
assert!(score > 0.0);
|
||||
|
||||
// Score with subset should be lower
|
||||
let partial_score = scorer.score(&doc, &["database".to_string()]);
|
||||
assert!(partial_score < score, "Partial match should score lower");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_length_normalization() {
|
||||
let scorer = test_scorer();
|
||||
scorer.set_doc_freq("keyword", 1000);
|
||||
|
||||
// Create two documents with same TF but different lengths
|
||||
// Short doc: 50 terms, avg is 100
|
||||
let mut short_freqs = HashMap::new();
|
||||
short_freqs.insert("keyword".to_string(), 2);
|
||||
for i in 0..48 {
|
||||
short_freqs.insert(format!("other{}", i), 1);
|
||||
}
|
||||
let short_tf = TermFrequencies::new(short_freqs);
|
||||
let short_doc = Document::new(&short_tf);
|
||||
|
||||
// Long doc: 200 terms, avg is 100
|
||||
let mut long_freqs = HashMap::new();
|
||||
long_freqs.insert("keyword".to_string(), 2);
|
||||
for i in 0..198 {
|
||||
long_freqs.insert(format!("other{}", i), 1);
|
||||
}
|
||||
let long_tf = TermFrequencies::new(long_freqs);
|
||||
let long_doc = Document::new(&long_tf);
|
||||
|
||||
let query = vec!["keyword".to_string()];
|
||||
let short_score = scorer.score(&short_doc, &query);
|
||||
let long_score = scorer.score(&long_doc, &query);
|
||||
|
||||
// Short doc should score higher due to length normalization
|
||||
assert!(short_score > long_score,
|
||||
"Short doc ({}) should score higher than long doc ({})",
|
||||
short_score, long_score
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bm25_tf_saturation() {
|
||||
let scorer = test_scorer();
|
||||
scorer.set_doc_freq("term", 500);
|
||||
|
||||
// Document with low TF
|
||||
let mut low_tf_freqs = HashMap::new();
|
||||
low_tf_freqs.insert("term".to_string(), 1);
|
||||
let low_tf = TermFrequencies::new(low_tf_freqs);
|
||||
let low_doc = Document::new(&low_tf);
|
||||
|
||||
// Document with high TF
|
||||
let mut high_tf_freqs = HashMap::new();
|
||||
high_tf_freqs.insert("term".to_string(), 100);
|
||||
let high_tf = TermFrequencies::new(high_tf_freqs);
|
||||
let high_doc = Document::new(&high_tf);
|
||||
|
||||
let query = vec!["term".to_string()];
|
||||
let low_score = scorer.score(&low_doc, &query);
|
||||
let high_score = scorer.score(&high_doc, &query);
|
||||
|
||||
// High TF should score higher, but not 100x higher (saturation)
|
||||
assert!(high_score > low_score);
|
||||
assert!(high_score < low_score * 10.0,
|
||||
"TF saturation should prevent linear scaling: {} vs {}",
|
||||
high_score, low_score
|
||||
);
|
||||
}
|
||||
|
||||
/// Two scorers that differ only in k1 must produce measurably
/// different scores for the same document/query pair.
#[test]
fn test_bm25_config_params() {
    let stats = CorpusStats {
        avg_doc_length: 100.0,
        doc_count: 1000,
        total_terms: 100000,
        last_update: 0,
    };

    // High k1 weights term frequency more; low k1 weights it less.
    // b (length normalization) is held constant at 0.75.
    let scorer_high_k1 = BM25Scorer::with_config(stats.clone(), BM25Config::new(2.0, 0.75));
    let scorer_low_k1 = BM25Scorer::with_config(stats, BM25Config::new(0.5, 0.75));

    scorer_high_k1.set_doc_freq("test", 100);
    scorer_low_k1.set_doc_freq("test", 100);

    // Shared fixture: a document with TF("test") = 10.
    let mut freqs = HashMap::new();
    freqs.insert("test".to_string(), 10);
    let tf = TermFrequencies::new(freqs);
    let doc = Document::new(&tf);
    let query = vec!["test".to_string()];

    let score_high = scorer_high_k1.score(&doc, &query);
    let score_low = scorer_low_k1.score(&doc, &query);

    // The two k1 settings must visibly diverge.
    assert!(
        (score_high - score_low).abs() > 0.1,
        "k1 should affect scoring: {} vs {}", score_high, score_low
    );
}
|
||||
|
||||
/// Tokenization lowercases input and strips punctuation; hyphenated
/// terms may be kept whole or split — either is acceptable.
#[test]
fn test_tokenize_query() {
    let tokens = tokenize_query("Hello, World! This is a TEST.");
    assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);

    let hyphenated = tokenize_query("database-query optimization");
    let kept_whole = hyphenated.contains(&"database-query".to_string());
    let split_apart = hyphenated.contains(&"database".to_string());
    assert!(kept_whole || split_apart);
}
|
||||
|
||||
/// A PostgreSQL tsvector literal maps each lexeme to its number of
/// listed positions — parse_tsvector must recover those counts.
#[test]
fn test_parse_tsvector() {
    let input = "'databas':1,4,7 'queri':2,5 'optim':3";
    let freqs = parse_tsvector(input);

    // 'databas' has 3 positions, 'queri' 2, 'optim' 1.
    assert_eq!(freqs.get("databas"), Some(&3));
    assert_eq!(freqs.get("queri"), Some(&2));
    assert_eq!(freqs.get("optim"), Some(&1));
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod fusion_tests {
    //! Tests for the result-fusion stage: reciprocal-rank fusion (RRF),
    //! linear score blending, fusion-method parsing, query-type
    //! classification, and the learned fusion model.

    use crate::hybrid::fusion::*;

    /// Vector-branch fixture; scores are distances, so lower = better.
    fn sample_vector_results() -> Vec<(DocId, f32)> {
        // Lower distance = better
        vec![
            (1, 0.1),
            (2, 0.15),
            (3, 0.25),
            (4, 0.4),
            (5, 0.6),
        ]
    }

    /// Keyword-branch fixture; scores are BM25, so higher = better.
    /// Docs 1, 2 and 3 overlap with the vector fixture; 6 and 7 are
    /// keyword-only.
    fn sample_keyword_results() -> Vec<(DocId, f32)> {
        // Higher BM25 = better
        vec![
            (3, 9.5),
            (6, 8.0),
            (1, 7.2),
            (7, 5.0),
            (2, 3.5),
        ]
    }

    /// RRF should surface documents that rank well in *both* branches.
    #[test]
    fn test_rrf_basic() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        let results = rrf_fusion(&vector, &keyword, 60, 10);

        assert!(!results.is_empty());
        // Doc 1 and 3 appear in both, should rank highly
        let top_3_ids: Vec<DocId> = results.iter().take(3).map(|r| r.doc_id).collect();
        assert!(top_3_ids.contains(&1) || top_3_ids.contains(&3));
    }

    /// The RRF `k` constant controls how steeply rank position is
    /// discounted; both extremes must still yield results.
    #[test]
    fn test_rrf_k_parameter() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        // Low k = top ranks matter more
        let results_low_k = rrf_fusion(&vector, &keyword, 10, 5);

        // High k = ranks matter less
        let results_high_k = rrf_fusion(&vector, &keyword, 100, 5);

        // Both should produce results
        assert!(!results_low_k.is_empty());
        assert!(!results_high_k.is_empty());

        // Order might differ due to k
        let order_low: Vec<DocId> = results_low_k.iter().map(|r| r.doc_id).collect();
        let order_high: Vec<DocId> = results_high_k.iter().map(|r| r.doc_id).collect();

        // At least verify both have same elements (possibly different order)
        // NOTE(review): the right-hand disjunct makes this assertion
        // nearly vacuous whenever the lists differ in length — confirm
        // this is the intended tolerance.
        for id in &order_low {
            assert!(order_high.contains(id) || order_low.len() > order_high.len());
        }
    }

    /// Alpha endpoints: 1.0 must use only the vector branch, 0.0 only
    /// the keyword branch.
    #[test]
    fn test_linear_fusion_alpha() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        // Alpha = 1.0 means only vector
        let results_vector_only = linear_fusion(&vector, &keyword, 1.0, 5);
        // Alpha = 0.0 means only keyword
        let results_keyword_only = linear_fusion(&vector, &keyword, 0.0, 5);

        // With alpha=1, best vector result (id=1) should be top
        assert_eq!(results_vector_only[0].doc_id, 1);

        // With alpha=0, best keyword result (id=3) should be top
        assert_eq!(results_keyword_only[0].doc_id, 3);
    }

    /// A 50/50 blend must still give every returned doc a positive
    /// hybrid score.
    #[test]
    fn test_linear_fusion_balanced() {
        let vector = sample_vector_results();
        let keyword = sample_keyword_results();

        let results = linear_fusion(&vector, &keyword, 0.5, 5);

        // All results should have both scores if they appeared in both
        for r in &results {
            assert!(r.hybrid_score > 0.0);
        }
    }

    /// Fusion must retain per-branch scores: Some(..) when the doc
    /// appeared in that branch, None otherwise.
    #[test]
    fn test_fusion_preserves_scores() {
        let vector = vec![(1, 0.1), (2, 0.2)];
        let keyword = vec![(1, 5.0), (3, 4.0)];

        let results = rrf_fusion(&vector, &keyword, 60, 10);

        // Doc 1 appeared in both branches.
        let doc1 = results.iter().find(|r| r.doc_id == 1).unwrap();
        assert!(doc1.vector_score.is_some());
        assert!(doc1.keyword_score.is_some());

        // Doc 2 was vector-only.
        let doc2 = results.iter().find(|r| r.doc_id == 2).unwrap();
        assert!(doc2.vector_score.is_some());
        assert!(doc2.keyword_score.is_none());

        // Doc 3 was keyword-only.
        let doc3 = results.iter().find(|r| r.doc_id == 3).unwrap();
        assert!(doc3.vector_score.is_none());
        assert!(doc3.keyword_score.is_some());
    }

    /// `FromStr` for the three fusion methods; unknown strings must be
    /// rejected.
    #[test]
    fn test_fusion_method_parse() {
        assert_eq!("rrf".parse::<FusionMethod>().unwrap(), FusionMethod::Rrf);
        assert_eq!("linear".parse::<FusionMethod>().unwrap(), FusionMethod::Linear);
        assert_eq!("learned".parse::<FusionMethod>().unwrap(), FusionMethod::Learned);
        assert!("invalid".parse::<FusionMethod>().is_err());
    }

    /// Spot-check the query classifier on one query of each type.
    #[test]
    fn test_query_type_classification() {
        let nav = classify_query_type(&["github".into(), "login".into()]);
        assert_eq!(nav, QueryType::Navigational);

        let info = classify_query_type(&["how".into(), "to".into(), "build".into()]);
        assert_eq!(info, QueryType::Informational);

        let trans = classify_query_type(&["buy".into(), "cheap".into(), "laptop".into()]);
        assert_eq!(trans, QueryType::Transactional);
    }

    /// The default learned model should push alpha toward keyword for
    /// navigational queries and toward vector for informational ones.
    #[test]
    fn test_fusion_model() {
        let model = FusionModel::default();

        // Test navigational query (should favor keyword)
        let nav_features = QueryFeatures {
            embedding_norm: 1.0,
            term_count: 2,
            avg_term_idf: 2.0,
            has_exact_match: true,
            query_type: QueryType::Navigational,
        };
        let nav_alpha = model.predict_alpha(&nav_features);
        assert!(nav_alpha < 0.5, "Nav query should favor keyword");

        // Test informational query (should favor vector)
        let info_features = QueryFeatures {
            embedding_norm: 1.2,
            term_count: 5,
            avg_term_idf: 4.5,
            has_exact_match: false,
            query_type: QueryType::Informational,
        };
        let info_alpha = model.predict_alpha(&info_features);
        assert!(info_alpha > 0.4, "Info query should favor vector");
    }
}
|
||||
|
||||
#[cfg(test)]
mod executor_tests {
    //! Tests for `HybridExecutor`: query building, execution against
    //! mock search branches, fusion-method selection, strategy choice,
    //! and per-branch rank bookkeeping.

    use crate::hybrid::executor::*;
    use crate::hybrid::fusion::*;
    use crate::hybrid::bm25::CorpusStats;

    /// Small, fixed corpus profile shared by every executor test.
    fn mock_corpus_stats() -> CorpusStats {
        CorpusStats {
            avg_doc_length: 150.0,
            doc_count: 5000,
            total_terms: 750000,
            last_update: 0,
        }
    }

    /// Mock vector branch: docs 1..=min(k, 10) with linearly growing
    /// distance, so doc 1 is the best vector match.
    fn mock_vector_search(_embedding: &[f32], k: usize) -> BranchResults {
        BranchResults {
            results: (1..=k.min(10) as i64)
                .map(|i| (i, 0.05 * i as f32))
                .collect(),
            latency_ms: 2.5,
            candidates_evaluated: 500,
        }
    }

    /// Mock keyword branch: descending doc ids (10, 9, ...) with
    /// decreasing BM25 scores — roughly the reverse ordering of the
    /// vector branch, so fusion has real work to do.
    fn mock_keyword_search(_text: &str, k: usize) -> BranchResults {
        BranchResults {
            results: (1..=k.min(10) as i64)
                .map(|i| (10 - i + 1, 12.0 - i as f32))
                .collect(),
            latency_ms: 1.2,
            candidates_evaluated: 200,
        }
    }

    /// Each builder method must land in the corresponding query field.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("test query".into(), vec![0.1; 128], 10)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.7)
            .with_prefetch(200)
            .with_rrf_k(40);

        assert_eq!(query.k, 10);
        assert_eq!(query.prefetch_k, 200);
        assert_eq!(query.fusion_config.method, FusionMethod::Linear);
        assert!((query.fusion_config.alpha - 0.7).abs() < 0.01);
        assert_eq!(query.fusion_config.rrf_k, 40);
    }

    /// execute() with both mock branches: result count respects k and
    /// all latency/count stats are populated.
    #[test]
    fn test_executor_execute() {
        let executor = HybridExecutor::new(mock_corpus_stats());

        let query = HybridQuery::new(
            "database optimization".into(),
            vec![0.1; 64],
            5,
        );

        let (results, stats) = executor.execute(&query, mock_vector_search, mock_keyword_search);

        assert!(!results.is_empty());
        assert!(results.len() <= 5);

        // Check stats
        assert!(stats.total_latency_ms > 0.0);
        assert!(stats.vector_latency_ms > 0.0);
        assert!(stats.keyword_latency_ms > 0.0);
        assert!(stats.result_count <= 5);
    }

    /// Both RRF and linear fusion must succeed through the executor.
    #[test]
    fn test_executor_with_different_fusion() {
        let executor = HybridExecutor::new(mock_corpus_stats());

        // RRF
        let query_rrf = HybridQuery::new("test".into(), vec![0.1; 32], 5)
            .with_fusion(FusionMethod::Rrf);
        let (results_rrf, _) = executor.execute(&query_rrf, mock_vector_search, mock_keyword_search);

        // Linear
        let query_linear = HybridQuery::new("test".into(), vec![0.1; 32], 5)
            .with_fusion(FusionMethod::Linear)
            .with_alpha(0.5);
        let (results_linear, _) = executor.execute(&query_linear, mock_vector_search, mock_keyword_search);

        assert!(!results_rrf.is_empty());
        assert!(!results_linear.is_empty());
    }

    /// choose_strategy(selectivity, corpus_size, has_filter):
    /// no filter -> Full; very selective -> PreFilter; moderate
    /// selectivity on a large corpus -> PostFilter; unselective -> Full.
    #[test]
    fn test_strategy_selection() {
        // No filter
        assert_eq!(choose_strategy(None, 10000, false), HybridStrategy::Full);

        // Very selective filter
        assert_eq!(choose_strategy(Some(0.005), 1000000, true), HybridStrategy::PreFilter);

        // Moderate selectivity, large corpus
        assert_eq!(choose_strategy(Some(0.05), 5000000, true), HybridStrategy::PostFilter);

        // Low selectivity
        assert_eq!(choose_strategy(Some(0.5), 10000, true), HybridStrategy::Full);
    }

    /// Every fused result must carry a rank from at least one branch.
    #[test]
    fn test_result_has_ranks() {
        let executor = HybridExecutor::new(mock_corpus_stats());

        let query = HybridQuery::new("test".into(), vec![0.1; 16], 10);
        let (results, _) = executor.execute(&query, mock_vector_search, mock_keyword_search);

        // Check that rank information is populated
        for r in &results {
            // At least one rank should be present (doc appeared in at least one branch)
            assert!(r.vector_rank.is_some() || r.keyword_rank.is_some());
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod registry_tests {
    //! Tests for `HybridRegistry`: collection lifecycle, duplicate
    //! prevention, corpus-stats updates, config update application and
    //! validation, IDF caching, and the stats-refresh timer.

    use crate::hybrid::registry::*;
    use crate::hybrid::bm25::CorpusStats;
    use crate::hybrid::fusion::FusionMethod;

    /// Full round-trip: register, get by id, get by name, list,
    /// unregister.
    #[test]
    fn test_registry_lifecycle() {
        let registry = HybridRegistry::new();

        // Register
        let config = HybridCollectionConfig::new(
            1,
            "test_collection".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        assert!(registry.register(config).is_ok());

        // Get by ID
        let retrieved = registry.get(1).unwrap();
        assert_eq!(retrieved.table_name, "test_collection");

        // Get by name — presumably the registry schema-qualifies bare
        // table names with "public." by default; verify against the
        // registry implementation.
        let by_name = registry.get_by_name("public.test_collection").unwrap();
        assert_eq!(by_name.collection_id, 1);

        // List
        let list = registry.list();
        assert_eq!(list.len(), 1);

        // Unregister
        assert!(registry.unregister(1).is_ok());
        assert!(registry.get(1).is_none());
    }

    /// Registering the same collection id twice must fail with
    /// `AlreadyRegistered`.
    #[test]
    fn test_registry_duplicate_prevention() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            1,
            "unique_table".to_string(),
            "vec".to_string(),
            "fts".to_string(),
            "text".to_string(),
        );

        registry.register(config.clone()).unwrap();
        let result = registry.register(config);

        assert!(matches!(result, Err(RegistryError::AlreadyRegistered(_))));
    }

    /// update_stats must replace the stored corpus statistics.
    #[test]
    fn test_registry_stats_update() {
        let registry = HybridRegistry::new();

        let config = HybridCollectionConfig::new(
            42,
            "stats_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        registry.register(config).unwrap();

        let new_stats = CorpusStats {
            avg_doc_length: 200.0,
            doc_count: 10000,
            total_terms: 2000000,
            last_update: 12345,
        };

        registry.update_stats(42, new_stats).unwrap();

        let updated = registry.get(42).unwrap();
        assert!((updated.corpus_stats.avg_doc_length - 200.0).abs() < 0.1);
        assert_eq!(updated.corpus_stats.doc_count, 10000);
    }

    /// A fully-populated HybridConfigUpdate must land every field on
    /// the config, including interval parsing ("30 minutes" -> 1800 s).
    #[test]
    fn test_config_update_apply() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );

        let update = HybridConfigUpdate {
            default_fusion: Some("linear".to_string()),
            default_alpha: Some(0.8),
            rrf_k: Some(50),
            prefetch_k: Some(150),
            bm25_k1: Some(1.4),
            bm25_b: Some(0.7),
            stats_refresh_interval: Some("30 minutes".to_string()),
            parallel_enabled: Some(false),
        };

        update.apply(&mut config).unwrap();

        assert_eq!(config.fusion_config.method, FusionMethod::Linear);
        assert!((config.fusion_config.alpha - 0.8).abs() < 0.01);
        assert_eq!(config.fusion_config.rrf_k, 50);
        assert_eq!(config.prefetch_k, 150);
        assert!((config.bm25_config.k1 - 1.4).abs() < 0.01);
        assert!((config.bm25_config.b - 0.7).abs() < 0.01);
        // "30 minutes" parsed into seconds.
        assert_eq!(config.stats_refresh_interval, 1800);
        assert!(!config.parallel_enabled);
    }

    /// Out-of-range update fields must be rejected by apply().
    #[test]
    fn test_config_update_validation() {
        let mut config = HybridCollectionConfig::new(
            1,
            "test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );

        // Invalid alpha (alpha must be within [0, 1]).
        let invalid_alpha = HybridConfigUpdate {
            default_alpha: Some(1.5),
            ..Default::default()
        };
        assert!(invalid_alpha.apply(&mut config).is_err());

        // Invalid rrf_k (must be positive).
        let invalid_rrf = HybridConfigUpdate {
            rrf_k: Some(0),
            ..Default::default()
        };
        assert!(invalid_rrf.apply(&mut config).is_err());
    }

    /// IDF values must be stable across repeated lookups (computed
    /// once, then served from cache).
    #[test]
    fn test_idf_caching() {
        let registry = HybridRegistry::new();

        let mut config = HybridCollectionConfig::new(
            1,
            "idf_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );
        // Non-zero corpus size so IDF is well-defined.
        config.corpus_stats.doc_count = 1000;
        registry.register(config).unwrap();

        // Set doc freq
        registry.set_doc_freq(1, "test_term", 100).unwrap();

        // Get IDF (should compute and cache)
        let idf1 = registry.get_idf(1, "test_term").unwrap();
        assert!(idf1 > 0.0);

        // Get again (should use cache)
        let idf2 = registry.get_idf(1, "test_term").unwrap();
        assert!((idf1 - idf2).abs() < 0.001);
    }

    /// needs_stats_refresh: false while stats are younger than the
    /// refresh interval, true once they are older.
    #[test]
    fn test_needs_refresh() {
        let mut config = HybridCollectionConfig::new(
            1,
            "refresh_test".to_string(),
            "v".to_string(),
            "f".to_string(),
            "t".to_string(),
        );

        // Set refresh interval to 1 hour
        config.stats_refresh_interval = 3600;

        // Fresh stats
        config.corpus_stats.last_update = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;

        assert!(!config.needs_stats_refresh());

        // Stale stats (2 hours old)
        config.corpus_stats.last_update -= 7200;
        assert!(config.needs_stats_refresh());
    }
}
|
||||
|
||||
#[cfg(test)]
mod integration_tests {
    //! End-to-end tests wiring registry, config updates, the executor,
    //! and the BM25 scorer together through the public `hybrid` API.

    use crate::hybrid::*;

    /// Full workflow: register a collection, push corpus stats, apply
    /// a config update, then run a hybrid query against mock
    /// vector/keyword branches.
    #[test]
    fn test_end_to_end_workflow() {
        // 1. Setup registry
        let registry = HybridRegistry::new();

        // 2. Register collection
        let config = HybridCollectionConfig::new(
            100,
            "documents".to_string(),
            "embedding".to_string(),
            "fts".to_string(),
            "content".to_string(),
        );
        registry.register(config).unwrap();

        // 3. Update with corpus stats (timestamped "now" so they are
        // considered fresh)
        let stats = CorpusStats {
            avg_doc_length: 250.0,
            doc_count: 50000,
            total_terms: 12500000,
            last_update: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs() as i64,
        };
        registry.update_stats(100, stats).unwrap();

        // 4. Configure hybrid settings
        let config_update = HybridConfigUpdate {
            default_fusion: Some("rrf".to_string()),
            rrf_k: Some(60),
            prefetch_k: Some(200),
            ..Default::default()
        };

        let mut updated_config = registry.get(100).unwrap();
        config_update.apply(&mut updated_config).unwrap();
        registry.update(updated_config.clone()).unwrap();

        // 5. Create executor with updated config
        let executor = HybridExecutor::new(updated_config.corpus_stats);

        // 6. Execute query
        let query = HybridQuery::new(
            "machine learning optimization".to_string(),
            vec![0.1; 768],
            10,
        )
        .with_fusion(FusionMethod::Rrf)
        .with_prefetch(200);

        // Mock vector branch: docs 1..=min(k, 20), distance grows with id.
        let mock_vector = |_: &[f32], k: usize| BranchResults {
            results: (1..=k.min(20) as i64).map(|i| (i, i as f32 * 0.02)).collect(),
            latency_ms: 3.0,
            candidates_evaluated: 1000,
        };

        // Mock keyword branch: descending ids (24, 23, ...) with
        // decreasing BM25 scores — partially overlapping the vector set.
        let mock_keyword = |_: &str, k: usize| BranchResults {
            results: (1..=k.min(20) as i64).map(|i| (25 - i, 15.0 - i as f32 * 0.5)).collect(),
            latency_ms: 1.5,
            candidates_evaluated: 500,
        };

        let (results, stats) = executor.execute(&query, mock_vector, mock_keyword);

        // 7. Verify results
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        assert!(stats.total_latency_ms > 0.0);

        // Top result should have high hybrid score
        assert!(results[0].hybrid_score > 0.0);
    }

    /// BM25 scorer against a hand-built document: scoring yields a
    /// positive value and the rarer term carries the higher IDF.
    #[test]
    fn test_bm25_scorer_integration() {
        let corpus_stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };

        let scorer = BM25Scorer::new(corpus_stats);

        // Set up document frequencies ("deep" is the rarest term)
        scorer.set_doc_freq("machine", 200);
        scorer.set_doc_freq("learning", 150);
        scorer.set_doc_freq("deep", 50);

        // Create test document
        let mut doc_freqs = std::collections::HashMap::new();
        doc_freqs.insert("machine".to_string(), 3);
        doc_freqs.insert("learning".to_string(), 2);
        doc_freqs.insert("deep".to_string(), 1);
        doc_freqs.insert("neural".to_string(), 2);
        doc_freqs.insert("networks".to_string(), 2);

        let term_freqs = TermFrequencies::new(doc_freqs);
        let doc = Document::new(&term_freqs);

        // Score with query
        let query_terms = vec![
            "machine".to_string(),
            "learning".to_string(),
            "deep".to_string(),
        ];

        let score = scorer.score(&doc, &query_terms);
        assert!(score > 0.0);

        // "deep" is rarer, so query with just "deep" should have higher IDF contribution
        let deep_idf = scorer.idf("deep");
        let machine_idf = scorer.idf("machine");
        assert!(deep_idf > machine_idf, "Rare term should have higher IDF");
    }
}
|
||||
Reference in New Issue
Block a user