Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,545 @@
//! # Hypergraph Support for N-ary Relationships
//!
//! Implements hypergraph structures for representing complex multi-entity relationships
//! beyond traditional pairwise similarity. Based on HyperGraphRAG (NeurIPS 2025) architecture.
use crate::error::{Result, RuvectorError};
use crate::types::{DistanceMetric, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
/// Hyperedge connecting multiple vectors with description and embedding.
///
/// Unlike a plain graph edge (exactly two endpoints), a hyperedge links an
/// arbitrary number of entities at once; the relationship itself carries a
/// description plus its embedding so it can be searched by similarity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperedge {
    /// Unique identifier for the hyperedge (a UUID v4 string assigned by `new`)
    pub id: String,
    /// Vector IDs connected by this hyperedge; its "order" is `nodes.len()`
    pub nodes: Vec<VectorId>,
    /// Natural language description of the relationship
    pub description: String,
    /// Embedding of the hyperedge description
    pub embedding: Vec<f32>,
    /// Confidence weight (0.0-1.0); `new` clamps values into this range
    pub confidence: f32,
    /// Optional metadata (starts empty; free-form key/value pairs)
    pub metadata: HashMap<String, String>,
}
/// Temporal hyperedge with time attributes.
///
/// Wraps a [`Hyperedge`] with a creation timestamp, an optional expiry, and a
/// granularity used to derive the time bucket it is indexed under.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalHyperedge {
    /// Base hyperedge
    pub hyperedge: Hyperedge,
    /// Creation timestamp (Unix epoch seconds)
    pub timestamp: u64,
    /// Optional expiration timestamp (Unix epoch seconds); `None` = never expires
    pub expires_at: Option<u64>,
    /// Temporal context (hourly, daily, monthly, yearly) controlling bucket size
    pub granularity: TemporalGranularity,
}
/// Temporal granularity for indexing.
///
/// Selects the bucket width used by `TemporalHyperedge::time_bucket`
/// (months and years are approximated as 30 and 365 days there).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TemporalGranularity {
    Hourly,
    Daily,
    Monthly,
    Yearly,
}
impl Hyperedge {
    /// Construct a hyperedge over `nodes`, assigning a fresh UUID v4 id.
    ///
    /// `confidence` is clamped into `[0.0, 1.0]`; metadata starts empty.
    pub fn new(
        nodes: Vec<VectorId>,
        description: String,
        embedding: Vec<f32>,
        confidence: f32,
    ) -> Self {
        let id = Uuid::new_v4().to_string();
        let confidence = confidence.clamp(0.0, 1.0);
        Self {
            id,
            nodes,
            description,
            embedding,
            confidence,
            metadata: HashMap::new(),
        }
    }

    /// Number of nodes this hyperedge connects (its order).
    pub fn order(&self) -> usize {
        self.nodes.len()
    }

    /// Whether `node` is one of the connected nodes.
    pub fn contains_node(&self, node: &VectorId) -> bool {
        self.nodes.iter().any(|n| n == node)
    }
}
impl TemporalHyperedge {
    /// Seconds since the Unix epoch, saturating to 0 if the system clock
    /// reports a time before the epoch (the previous `.unwrap()` panicked
    /// in that case).
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Create a new temporal hyperedge stamped with the current time and no
    /// expiration.
    pub fn new(hyperedge: Hyperedge, granularity: TemporalGranularity) -> Self {
        Self {
            hyperedge,
            timestamp: Self::now_secs(),
            expires_at: None,
            granularity,
        }
    }

    /// Check if the hyperedge is expired. Edges without `expires_at`
    /// never expire.
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(expires_at) => Self::now_secs() > expires_at,
            None => false,
        }
    }

    /// Bucket index for temporal lookup, derived from `timestamp` at this
    /// edge's granularity. Months and years are approximated as 30 and
    /// 365 days respectively.
    pub fn time_bucket(&self) -> u64 {
        let seconds_per_bucket = match self.granularity {
            TemporalGranularity::Hourly => 3600,
            TemporalGranularity::Daily => 86400,
            TemporalGranularity::Monthly => 86400 * 30,
            TemporalGranularity::Yearly => 86400 * 365,
        };
        self.timestamp / seconds_per_bucket
    }
}
/// Hypergraph index with bipartite graph storage.
///
/// Entities and hyperedges form the two sides of a bipartite incidence
/// structure; both directions are materialized so adjacency lookups are
/// single hash-map reads.
pub struct HypergraphIndex {
    /// Entity nodes (id -> embedding)
    entities: HashMap<VectorId, Vec<f32>>,
    /// Hyperedges keyed by their id
    hyperedges: HashMap<String, Hyperedge>,
    /// Temporal hyperedges indexed by time bucket (bucket -> edge ids)
    temporal_index: HashMap<u64, Vec<String>>,
    /// Bipartite graph: entity -> hyperedge IDs
    entity_to_hyperedges: HashMap<VectorId, HashSet<String>>,
    /// Bipartite graph: hyperedge -> entity IDs
    hyperedge_to_entities: HashMap<String, HashSet<VectorId>>,
    /// Distance metric for embeddings
    distance_metric: DistanceMetric,
}
impl HypergraphIndex {
    /// Create an empty hypergraph index using `distance_metric` for
    /// hyperedge-embedding comparisons.
    pub fn new(distance_metric: DistanceMetric) -> Self {
        Self {
            entities: HashMap::new(),
            hyperedges: HashMap::new(),
            temporal_index: HashMap::new(),
            entity_to_hyperedges: HashMap::new(),
            hyperedge_to_entities: HashMap::new(),
            distance_metric,
        }
    }

    /// Add an entity node, replacing any existing embedding for `id`.
    pub fn add_entity(&mut self, id: VectorId, embedding: Vec<f32>) {
        self.entities.insert(id.clone(), embedding);
        // Ensure the adjacency entry exists even before any edge references it.
        self.entity_to_hyperedges.entry(id).or_default();
    }

    /// Add a hyperedge.
    ///
    /// # Errors
    /// Returns `RuvectorError::InvalidInput` if any referenced entity has not
    /// been added; validation happens before any mutation, so a failed insert
    /// leaves the index unchanged.
    pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) -> Result<()> {
        let edge_id = hyperedge.id.clone();
        // Verify all nodes exist before touching any state.
        for node in &hyperedge.nodes {
            if !self.entities.contains_key(node) {
                return Err(RuvectorError::InvalidInput(format!(
                    "Entity {} not found in hypergraph",
                    node
                )));
            }
        }
        // Maintain both directions of the bipartite incidence structure.
        for node in &hyperedge.nodes {
            self.entity_to_hyperedges
                .entry(node.clone())
                .or_default()
                .insert(edge_id.clone());
        }
        let nodes_set: HashSet<VectorId> = hyperedge.nodes.iter().cloned().collect();
        self.hyperedge_to_entities
            .insert(edge_id.clone(), nodes_set);
        self.hyperedges.insert(edge_id, hyperedge);
        Ok(())
    }

    /// Add a temporal hyperedge and register it under its time bucket.
    ///
    /// # Errors
    /// Propagates `InvalidInput` from `add_hyperedge` when a node is unknown;
    /// the temporal index is only updated on success.
    pub fn add_temporal_hyperedge(&mut self, temporal_edge: TemporalHyperedge) -> Result<()> {
        let bucket = temporal_edge.time_bucket();
        let edge_id = temporal_edge.hyperedge.id.clone();
        self.add_hyperedge(temporal_edge.hyperedge)?;
        self.temporal_index.entry(bucket).or_default().push(edge_id);
        Ok(())
    }

    /// Brute-force search over all hyperedge embeddings; returns the `k`
    /// closest `(edge_id, distance)` pairs, nearest first.
    pub fn search_hyperedges(&self, query_embedding: &[f32], k: usize) -> Vec<(String, f32)> {
        let mut results: Vec<(String, f32)> = self
            .hyperedges
            .iter()
            .map(|(id, edge)| {
                let distance = self.compute_distance(query_embedding, &edge.embedding);
                (id.clone(), distance)
            })
            .collect();
        // `total_cmp` is NaN-safe; the previous `partial_cmp(...).unwrap()`
        // panicked when any distance was NaN (e.g. a failed metric returning
        // f32::MAX is fine, but NaN inputs were not).
        results.sort_by(|a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
        results
    }

    /// Get k-hop neighbors in the hypergraph.
    ///
    /// BFS over the bipartite structure: two entities are adjacent iff they
    /// co-occur in at least one hyperedge. The start node itself is included
    /// in the result (distance 0).
    pub fn k_hop_neighbors(&self, start_node: VectorId, k: usize) -> HashSet<VectorId> {
        let mut visited = HashSet::new();
        let mut current_layer = HashSet::new();
        current_layer.insert(start_node.clone());
        visited.insert(start_node); // Start node is at distance 0
        for _hop in 0..k {
            let mut next_layer = HashSet::new();
            for node in current_layer.iter() {
                // Expand through every hyperedge containing this node.
                if let Some(hyperedges) = self.entity_to_hyperedges.get(node) {
                    for edge_id in hyperedges {
                        if let Some(nodes) = self.hyperedge_to_entities.get(edge_id) {
                            for neighbor in nodes.iter() {
                                if !visited.contains(neighbor) {
                                    visited.insert(neighbor.clone());
                                    next_layer.insert(neighbor.clone());
                                }
                            }
                        }
                    }
                }
            }
            // No new nodes discovered: the reachable component is exhausted.
            if next_layer.is_empty() {
                break;
            }
            current_layer = next_layer;
        }
        visited
    }

    /// Query temporal hyperedges whose bucket lies in
    /// `[start_bucket, end_bucket]` (inclusive).
    pub fn query_temporal_range(&self, start_bucket: u64, end_bucket: u64) -> Vec<String> {
        let mut results = Vec::new();
        for bucket in start_bucket..=end_bucket {
            if let Some(edges) = self.temporal_index.get(&bucket) {
                results.extend(edges.iter().cloned());
            }
        }
        results
    }

    /// Get hyperedge by ID.
    pub fn get_hyperedge(&self, id: &str) -> Option<&Hyperedge> {
        self.hyperedges.get(id)
    }

    /// Get statistics: entity/edge counts and the mean number of hyperedges
    /// incident to an entity.
    pub fn stats(&self) -> HypergraphStats {
        let total_edges = self.hyperedges.len();
        let total_entities = self.entities.len();
        let avg_degree = if total_entities > 0 {
            self.entity_to_hyperedges
                .values()
                .map(|edges| edges.len())
                .sum::<usize>() as f32
                / total_entities as f32
        } else {
            0.0
        };
        HypergraphStats {
            total_entities,
            total_hyperedges: total_edges,
            avg_entity_degree: avg_degree,
        }
    }

    /// Distance between two embeddings under the configured metric;
    /// falls back to `f32::MAX` when the metric reports an error.
    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        crate::distance::distance(a, b, self.distance_metric).unwrap_or(f32::MAX)
    }
}
/// Hypergraph statistics, as produced by `HypergraphIndex::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HypergraphStats {
    pub total_entities: usize,
    pub total_hyperedges: usize,
    /// Mean number of hyperedges incident to an entity (0.0 when empty)
    pub avg_entity_degree: f32,
}
/// Causal hypergraph memory for agent reasoning.
///
/// Layers causal statistics (success counts, latencies) on top of a
/// `HypergraphIndex` and ranks edges with the utility function
/// U = alpha * similarity + beta * causal_uplift - gamma * latency.
pub struct CausalMemory {
    /// Hypergraph index holding entities and causal hyperedges
    index: HypergraphIndex,
    /// Causal relationship tracking: (cause_id, effect_id) -> success_count
    causal_counts: HashMap<(VectorId, VectorId), u32>,
    /// Action latencies: action_id -> avg_latency_ms (running average)
    latencies: HashMap<VectorId, f32>,
    /// Utility function weights
    alpha: f32, // similarity weight
    beta: f32,  // causal uplift weight
    gamma: f32, // latency penalty weight
}
impl CausalMemory {
    /// Create a new causal memory with default utility weights
    /// (alpha = 0.7, beta = 0.2, gamma = 0.1).
    pub fn new(distance_metric: DistanceMetric) -> Self {
        Self {
            index: HypergraphIndex::new(distance_metric),
            causal_counts: HashMap::new(),
            latencies: HashMap::new(),
            alpha: 0.7,
            beta: 0.2,
            gamma: 0.1,
        }
    }

    /// Builder-style override of the utility function weights.
    pub fn with_weights(mut self, alpha: f32, beta: f32, gamma: f32) -> Self {
        self.alpha = alpha;
        self.beta = beta;
        self.gamma = gamma;
        self
    }

    /// Add a causal relationship as a hyperedge over `{cause, effect}` plus
    /// `context`, bump the (cause, effect) success count, and fold
    /// `latency_ms` into the cause's running latency average.
    ///
    /// # Errors
    /// Propagates `InvalidInput` if any node is missing from the index; in
    /// that case no counts or latencies are updated.
    pub fn add_causal_edge(
        &mut self,
        cause: VectorId,
        effect: VectorId,
        context: Vec<VectorId>,
        description: String,
        embedding: Vec<f32>,
        latency_ms: f32,
    ) -> Result<()> {
        // Create hyperedge connecting cause, effect, and context.
        let mut nodes = vec![cause.clone(), effect.clone()];
        nodes.extend(context);
        let hyperedge = Hyperedge::new(nodes, description, embedding, 1.0);
        self.index.add_hyperedge(hyperedge)?;
        // Update causal counts.
        *self
            .causal_counts
            .entry((cause.clone(), effect.clone()))
            .or_insert(0) += 1;
        // BUG FIX: the first observation used to be averaged against a 0.0
        // seed, silently halving it. Seed with the first sample instead and
        // only average subsequent ones.
        self.latencies
            .entry(cause)
            .and_modify(|avg| *avg = (*avg + latency_ms) / 2.0)
            .or_insert(latency_ms);
        Ok(())
    }

    /// Query with utility function: U = α·similarity + β·causal_uplift - γ·latency.
    ///
    /// Only hyperedges containing `action_id` are considered; results are
    /// sorted by utility, best first, and truncated to `k`.
    pub fn query_with_utility(
        &self,
        query_embedding: &[f32],
        action_id: VectorId,
        k: usize,
    ) -> Vec<(String, f32)> {
        let mut results: Vec<(String, f32)> = self
            .index
            .hyperedges
            .iter()
            .filter(|(_, edge)| edge.contains_node(&action_id))
            .map(|(id, edge)| {
                // NOTE(review): `1.0 - distance` assumes the metric's distance
                // lies roughly in [0, 1] (true for cosine distance) — verify
                // before using other metrics here.
                let similarity = 1.0
                    - self
                        .index
                        .compute_distance(query_embedding, &edge.embedding);
                let causal_uplift = self.compute_causal_uplift(&edge.nodes);
                let latency = self.latencies.get(&action_id).copied().unwrap_or(0.0);
                let utility = self.alpha * similarity + self.beta * causal_uplift
                    - self.gamma * (latency / 1000.0); // Normalize latency to 0-1 range
                (id.clone(), utility)
            })
            .collect();
        // NaN-safe descending sort; the previous `partial_cmp(...).unwrap()`
        // panicked when a utility was NaN.
        results.sort_by(|a, b| b.1.total_cmp(&a.1));
        results.truncate(k);
        results
    }

    /// Mean log-scaled success count over all ordered node pairs (i, j),
    /// i < j, of a hyperedge; 0.0 when no pair has recorded successes.
    fn compute_causal_uplift(&self, nodes: &[VectorId]) -> f32 {
        if nodes.len() < 2 {
            return 0.0;
        }
        let mut total_uplift = 0.0;
        let mut count = 0;
        for i in 0..nodes.len() - 1 {
            for j in i + 1..nodes.len() {
                if let Some(&success_count) = self
                    .causal_counts
                    .get(&(nodes[i].clone(), nodes[j].clone()))
                {
                    // Log scale so a few huge counts don't dominate.
                    total_uplift += (success_count as f32).ln_1p();
                    count += 1;
                }
            }
        }
        if count > 0 {
            total_uplift / count as f32
        } else {
            0.0
        }
    }

    /// Read-only access to the underlying hypergraph index.
    pub fn index(&self) -> &HypergraphIndex {
        &self.index
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The constructor clamps confidence, assigns an id, and derives order
    // from the node list.
    #[test]
    fn test_hyperedge_creation() {
        let nodes = vec!["1".to_string(), "2".to_string(), "3".to_string()];
        let desc = "Test relationship".to_string();
        let embedding = vec![0.1, 0.2, 0.3];
        let edge = Hyperedge::new(nodes, desc, embedding, 0.95);
        assert_eq!(edge.order(), 3);
        assert!(edge.contains_node(&"1".to_string()));
        assert!(!edge.contains_node(&"4".to_string()));
        assert_eq!(edge.confidence, 0.95);
    }

    // A freshly created temporal edge has no expiry and a positive hourly
    // bucket (any real wall-clock time is well past the epoch).
    #[test]
    fn test_temporal_hyperedge() {
        let nodes = vec!["1".to_string(), "2".to_string()];
        let desc = "Temporal relationship".to_string();
        let embedding = vec![0.1, 0.2];
        let edge = Hyperedge::new(nodes, desc, embedding, 1.0);
        let temporal = TemporalHyperedge::new(edge, TemporalGranularity::Hourly);
        assert!(!temporal.is_expired());
        assert!(temporal.time_bucket() > 0);
    }

    // Entities must be registered before an edge over them is accepted;
    // stats should then reflect both sides of the bipartite structure.
    #[test]
    fn test_hypergraph_index() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);
        // Add entities
        index.add_entity("1".to_string(), vec![1.0, 0.0, 0.0]);
        index.add_entity("2".to_string(), vec![0.0, 1.0, 0.0]);
        index.add_entity("3".to_string(), vec![0.0, 0.0, 1.0]);
        // Add hyperedge
        let edge = Hyperedge::new(
            vec!["1".to_string(), "2".to_string(), "3".to_string()],
            "Triple relationship".to_string(),
            vec![0.5, 0.5, 0.5],
            0.9,
        );
        index.add_hyperedge(edge).unwrap();
        let stats = index.stats();
        assert_eq!(stats.total_entities, 3);
        assert_eq!(stats.total_hyperedges, 1);
    }

    // Chain 1-2, 2-3, 3-4: from node 1, two hops reach 3 but not 4.
    #[test]
    fn test_k_hop_neighbors() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);
        // Create a small hypergraph
        index.add_entity("1".to_string(), vec![1.0]);
        index.add_entity("2".to_string(), vec![1.0]);
        index.add_entity("3".to_string(), vec![1.0]);
        index.add_entity("4".to_string(), vec![1.0]);
        let edge1 = Hyperedge::new(
            vec!["1".to_string(), "2".to_string()],
            "e1".to_string(),
            vec![1.0],
            1.0,
        );
        let edge2 = Hyperedge::new(
            vec!["2".to_string(), "3".to_string()],
            "e2".to_string(),
            vec![1.0],
            1.0,
        );
        let edge3 = Hyperedge::new(
            vec!["3".to_string(), "4".to_string()],
            "e3".to_string(),
            vec![1.0],
            1.0,
        );
        index.add_hyperedge(edge1).unwrap();
        index.add_hyperedge(edge2).unwrap();
        index.add_hyperedge(edge3).unwrap();
        let neighbors = index.k_hop_neighbors("1".to_string(), 2);
        assert!(neighbors.contains(&"1".to_string()));
        assert!(neighbors.contains(&"2".to_string()));
        assert!(neighbors.contains(&"3".to_string()));
    }

    // A recorded causal edge containing the action node must show up in a
    // utility-ranked query for that action.
    #[test]
    fn test_causal_memory() {
        let mut memory = CausalMemory::new(DistanceMetric::Cosine);
        memory.index.add_entity("1".to_string(), vec![1.0, 0.0]);
        memory.index.add_entity("2".to_string(), vec![0.0, 1.0]);
        memory
            .add_causal_edge(
                "1".to_string(),
                "2".to_string(),
                vec![],
                "Action 1 causes effect 2".to_string(),
                vec![0.5, 0.5],
                100.0,
            )
            .unwrap();
        let results = memory.query_with_utility(&[0.6, 0.4], "1".to_string(), 5);
        assert!(!results.is_empty());
    }
}

View File

@@ -0,0 +1,441 @@
//! # Learned Index Structures
//!
//! Experimental learned indexes using neural networks to approximate data distribution.
//! Based on Recursive Model Index (RMI) concept with bounded error correction.
use crate::error::{Result, RuvectorError};
use crate::types::VectorId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trait for learned index structures.
///
/// A learned index approximates the key -> position mapping with a model,
/// then corrects mispredictions with a bounded local search.
pub trait LearnedIndex {
    /// Predict the position (array slot) for a key.
    fn predict(&self, key: &[f32]) -> Result<usize>;
    /// Insert a key-value pair.
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;
    /// Search for a key, returning its value if an exact match is found.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;
    /// Get index statistics.
    fn stats(&self) -> IndexStats;
}
/// Statistics for learned indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    pub total_entries: usize,
    /// Approximate in-memory size of the model(s), in bytes
    pub model_size_bytes: usize,
    /// Mean absolute position-prediction error over stored entries
    pub avg_error: f32,
    /// Largest observed position-prediction error
    pub max_error: usize,
}
/// Simple linear model for CDF approximation.
///
/// Predicts a position as `max(0, weights . input + bias)`. Training fits
/// only the first input dimension (see `train_simple`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f32>,
    bias: f32,
}
impl LinearModel {
    /// Zero-initialized model for `dimensions`-dimensional inputs.
    fn new(dimensions: usize) -> Self {
        Self {
            weights: vec![0.0; dimensions],
            bias: 0.0,
        }
    }

    /// `max(0, weights . input + bias)` — floored at zero so the predicted
    /// position is never negative.
    fn predict(&self, input: &[f32]) -> f32 {
        self.weights
            .iter()
            .zip(input.iter())
            .fold(self.bias, |acc, (w, x)| acc + w * x)
            .max(0.0)
    }

    /// Fit by ordinary least squares on the FIRST input dimension only;
    /// every other weight stays zero. No-op on empty data.
    fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
        if data.is_empty() {
            return;
        }
        let n = data.len() as f32;
        let dim = self.weights.len();
        // Start from a clean slate before refitting.
        self.weights.fill(0.0);
        self.bias = 0.0;
        // Per-dimension input means and the target mean.
        let mut mean_x = vec![0.0; dim];
        let mut mean_y = 0.0;
        for (x, y) in data {
            for (i, &val) in x.iter().enumerate() {
                mean_x[i] += val;
            }
            mean_y += *y as f32;
        }
        for val in mean_x.iter_mut() {
            *val /= n;
        }
        mean_y /= n;
        // Closed-form simple linear regression over dimension 0.
        if dim > 0 {
            let (numerator, denominator) =
                data.iter().fold((0.0f32, 0.0f32), |(num, den), (x, y)| {
                    let x_diff = x[0] - mean_x[0];
                    let y_diff = *y as f32 - mean_y;
                    (num + x_diff * y_diff, den + x_diff * x_diff)
                });
            // Guard against a degenerate (constant) first dimension.
            if denominator.abs() > 1e-10 {
                self.weights[0] = numerator / denominator;
            }
            self.bias = mean_y - self.weights[0] * mean_x[0];
        }
    }
}
/// Recursive Model Index (RMI).
///
/// Multi-stage models making coarse-then-fine predictions: the root model
/// routes a key to a leaf model, which predicts its position in the sorted
/// data array; mispredictions are corrected by scanning a bounded window.
pub struct RecursiveModelIndex {
    /// Root model for coarse prediction (key -> leaf model index)
    root_model: LinearModel,
    /// Second-level models for fine prediction (key -> data position)
    leaf_models: Vec<LinearModel>,
    /// Data sorted by the first key dimension; positions index into this
    data: Vec<(Vec<f32>, VectorId)>,
    /// Half-width of the search window around a prediction
    max_error: usize,
    /// Dimensions of vectors
    dimensions: usize,
}
impl RecursiveModelIndex {
/// Create a new RMI with specified number of leaf models
pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
let leaf_models = (0..num_leaf_models)
.map(|_| LinearModel::new(dimensions))
.collect();
Self {
root_model: LinearModel::new(dimensions),
leaf_models,
data: Vec::new(),
max_error: 100,
dimensions,
}
}
/// Build the index from data
pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
if data.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index from empty data".into(),
));
}
if data[0].0.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index from vectors with zero dimensions".into(),
));
}
if self.leaf_models.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index with zero leaf models".into(),
));
}
// Sort data by first dimension (simple heuristic)
data.sort_by(|a, b| {
a.0[0]
.partial_cmp(&b.0[0])
.unwrap_or(std::cmp::Ordering::Equal)
});
let n = data.len();
// Train root model to predict leaf model index
let root_training_data: Vec<(Vec<f32>, usize)> = data
.iter()
.enumerate()
.map(|(i, (key, _))| {
let leaf_idx = (i * self.leaf_models.len()) / n;
(key.clone(), leaf_idx)
})
.collect();
self.root_model.train_simple(&root_training_data);
// Train each leaf model
let num_leaf_models = self.leaf_models.len();
let chunk_size = n / num_leaf_models;
for (i, model) in self.leaf_models.iter_mut().enumerate() {
let start = i * chunk_size;
let end = if i == num_leaf_models - 1 {
n
} else {
(i + 1) * chunk_size
};
if start < n {
let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
.iter()
.enumerate()
.map(|(j, (key, _))| (key.clone(), start + j))
.collect();
model.train_simple(&leaf_data);
}
}
self.data = data;
Ok(())
}
}
impl LearnedIndex for RecursiveModelIndex {
    /// Two-stage prediction: root picks a leaf model, the leaf predicts the
    /// position; both outputs are clamped into valid ranges.
    fn predict(&self, key: &[f32]) -> Result<usize> {
        if key.len() != self.dimensions {
            return Err(RuvectorError::InvalidInput(
                "Key dimensions mismatch".into(),
            ));
        }
        if self.leaf_models.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no leaf models available".into(),
            ));
        }
        if self.data.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no data available".into(),
            ));
        }
        // Stage 1: the root model selects which leaf model to consult.
        let leaf = (self.root_model.predict(key) as usize).min(self.leaf_models.len() - 1);
        // Stage 2: the chosen leaf model predicts the array slot.
        let position = self.leaf_models[leaf].predict(key) as usize;
        Ok(position.min(self.data.len().saturating_sub(1)))
    }

    /// Append the pair; it is NOT reachable via `search` until the index is
    /// rebuilt (a production version would update incrementally).
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        self.data.push((key, value));
        Ok(())
    }

    /// Exact-match lookup: predict a position, then linearly scan the
    /// bounded error window around it.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        if self.data.is_empty() {
            return Ok(None);
        }
        let predicted = self.predict(key)?;
        let lo = predicted.saturating_sub(self.max_error);
        let hi = (predicted + self.max_error).min(self.data.len());
        let hit = self.data[lo..hi]
            .iter()
            .find(|(stored, _)| stored.as_slice() == key)
            .map(|(_, value)| value.clone());
        Ok(hit)
    }

    /// Entry count, approximate model footprint, and the mean/worst
    /// position-prediction error over all stored keys.
    fn stats(&self) -> IndexStats {
        let model_size = std::mem::size_of_val(&self.root_model)
            + self.leaf_models.len() * std::mem::size_of::<LinearModel>();
        let mut error_sum = 0.0f32;
        let mut worst_error = 0usize;
        for (true_pos, (key, _)) in self.data.iter().enumerate() {
            if let Ok(predicted) = self.predict(key) {
                let err = true_pos.abs_diff(predicted);
                error_sum += err as f32;
                worst_error = worst_error.max(err);
            }
        }
        let avg_error = if self.data.is_empty() {
            0.0
        } else {
            error_sum / self.data.len() as f32
        };
        IndexStats {
            total_entries: self.data.len(),
            model_size_bytes: model_size,
            avg_error,
            max_error: worst_error,
        }
    }
}
/// Hybrid index combining learned index for static data with dynamic updates.
///
/// Inserts land in a hash-map buffer and become part of the learned segment
/// only when `rebuild` is called (the caller decides when, typically once
/// `needs_rebuild()` turns true).
pub struct HybridIndex {
    /// Learned index for the static segment
    learned: RecursiveModelIndex,
    /// Dynamic updates buffer: bincode-encoded key -> value
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    /// Buffer size at which `needs_rebuild()` becomes true
    rebuild_threshold: usize,
}
impl HybridIndex {
    /// Create a hybrid index: a learned RMI for the static segment plus a
    /// hash-map buffer for updates.
    pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
        Self {
            learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
            dynamic_buffer: HashMap::new(),
            rebuild_threshold,
        }
    }

    /// Train the learned segment on an initial static dataset.
    pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        self.learned.build(data)
    }

    /// True once the dynamic buffer has grown to the rebuild threshold.
    pub fn needs_rebuild(&self) -> bool {
        self.dynamic_buffer.len() >= self.rebuild_threshold
    }

    /// Merge buffered updates into the learned segment and retrain it,
    /// clearing the buffer only after a successful rebuild.
    pub fn rebuild(&mut self) -> Result<()> {
        let mut merged: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();
        for (encoded_key, id) in &self.dynamic_buffer {
            // decode_from_slice yields (value, bytes_read); only the value matters.
            let (key, _bytes_read): (Vec<f32>, usize) =
                bincode::decode_from_slice(encoded_key, bincode::config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
            merged.push((key, id.clone()));
        }
        self.learned.build(merged)?;
        self.dynamic_buffer.clear();
        Ok(())
    }

    /// Encode a float key into bytes usable as a hash-map key.
    fn serialize_key(key: &[f32]) -> Vec<u8> {
        bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
    }
}
impl LearnedIndex for HybridIndex {
    /// Delegate position prediction to the learned segment.
    fn predict(&self, key: &[f32]) -> Result<usize> {
        self.learned.predict(key)
    }

    /// Buffer the pair; it is searchable immediately via the buffer and is
    /// merged into the learned segment on the next `rebuild`.
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        self.dynamic_buffer.insert(Self::serialize_key(&key), value);
        Ok(())
    }

    /// Look in the dynamic buffer first (newest data wins), then fall back
    /// to the learned segment.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        let encoded = Self::serialize_key(key);
        match self.dynamic_buffer.get(&encoded) {
            Some(id) => Ok(Some(id.clone())),
            None => self.learned.search(key),
        }
    }

    /// Learned-segment stats with buffered entries added to the total.
    fn stats(&self) -> IndexStats {
        let mut combined = self.learned.stats();
        combined.total_entries += self.dynamic_buffer.len();
        combined
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A model fit to a linear ramp should predict within the data's range.
    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let data = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];
        model.train_simple(&data);
        let pred = model.predict(&[1.5, 1.5]);
        assert!(pred >= 0.0 && pred <= 30.0);
    }

    // Building on 100 monotone keys should leave a modest average error.
    #[test]
    fn test_rmi_build() {
        let mut rmi = RecursiveModelIndex::new(2, 4);
        let data: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i.to_string())
            })
            .collect();
        rmi.build(data).unwrap();
        let stats = rmi.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0); // Should have reasonable error
    }

    // With a tiny dataset the error window covers everything, so an exact
    // match must be found.
    #[test]
    fn test_rmi_search() {
        let mut rmi = RecursiveModelIndex::new(1, 2);
        let data = vec![
            (vec![0.0], "0".to_string()),
            (vec![0.5], "1".to_string()),
            (vec![1.0], "2".to_string()),
        ];
        rmi.build(data).unwrap();
        let result = rmi.search(&[0.5]).unwrap();
        assert_eq!(result, Some("1".to_string()));
    }

    // Buffered inserts and static entries must both be searchable.
    #[test]
    fn test_hybrid_index() {
        let mut hybrid = HybridIndex::new(1, 2, 10);
        let static_data = vec![(vec![0.0], "0".to_string()), (vec![1.0], "1".to_string())];
        hybrid.build_static(static_data).unwrap();
        // Add dynamic updates
        hybrid.insert(vec![2.0], "2".to_string()).unwrap();
        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some("2".to_string()));
        assert_eq!(hybrid.search(&[0.0]).unwrap(), Some("0".to_string()));
    }
}

View File

@@ -0,0 +1,17 @@
//! # Advanced Techniques
//!
//! This module contains experimental and advanced features for next-generation vector search:
//! - **Hypergraphs**: n-ary relationships beyond pairwise similarity
//! - **Learned Indexes**: Neural network-based index structures
//! - **Neural Hashing**: Similarity-preserving binary projections
//! - **Topological Data Analysis**: Embedding quality assessment
pub mod hypergraph;
pub mod learned_index;
pub mod neural_hash;
pub mod tda;
pub use hypergraph::{CausalMemory, Hyperedge, HypergraphIndex, TemporalHyperedge};
pub use learned_index::{HybridIndex, LearnedIndex, RecursiveModelIndex};
pub use neural_hash::{DeepHashEmbedding, NeuralHash};
pub use tda::{EmbeddingQuality, TopologicalAnalyzer};

View File

@@ -0,0 +1,427 @@
//! # Neural Hash Functions
//!
//! Learn similarity-preserving binary projections for extreme compression.
//! Achieves 32-128x compression with 90-95% recall preservation.
use crate::types::VectorId;
use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Neural hash function for similarity-preserving binary codes.
///
/// Implementations map float vectors to compact bit strings such that
/// similar vectors tend to land on codes with small Hamming distance.
pub trait NeuralHash {
    /// Encode a vector to a binary code (packed bits, LSB-first per byte).
    fn encode(&self, vector: &[f32]) -> Vec<u8>;
    /// Compute the Hamming distance (number of differing bits) between two codes.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;
    /// Estimate similarity from a Hamming distance over `code_bits` total bits.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}
/// Deep hash embedding with learned projections.
///
/// A small feed-forward network (linear layers with ReLU on hidden layers)
/// whose output logits are sign-thresholded into a binary code.
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
    /// Projection matrices for each layer (shape: out_dim x in_dim)
    projections: Vec<Array2<f32>>,
    /// Biases for each layer
    biases: Vec<Array1<f32>>,
    /// Number of output bits (one per output-layer logit)
    output_bits: usize,
    /// Input dimensions expected by `encode`
    input_dims: usize,
}
impl DeepHashEmbedding {
/// Create a new deep hash embedding
pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
let mut rng = rand::thread_rng();
let mut projections = Vec::new();
let mut biases = Vec::new();
let mut layer_dims = vec![input_dims];
layer_dims.extend(&hidden_dims);
layer_dims.push(output_bits);
// Initialize random projections (Xavier initialization)
for i in 0..layer_dims.len() - 1 {
let in_dim = layer_dims[i];
let out_dim = layer_dims[i + 1];
let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
rng.gen::<f32>() * 2.0 * scale - scale
});
let bias = Array1::zeros(out_dim);
projections.push(proj);
biases.push(bias);
}
Self {
projections,
biases,
output_bits,
input_dims,
}
}
/// Forward pass through the network
fn forward(&self, input: &[f32]) -> Vec<f32> {
let mut activations = Array1::from_vec(input.to_vec());
for (proj, bias) in self.projections.iter().zip(self.biases.iter()) {
// Linear layer: y = Wx + b
activations = proj.dot(&activations) + bias;
// ReLU activation (except last layer)
if proj.nrows() != self.output_bits {
activations.mapv_inplace(|x| x.max(0.0));
}
}
activations.to_vec()
}
/// Train on pairs of similar/dissimilar examples
pub fn train(
&mut self,
positive_pairs: &[(Vec<f32>, Vec<f32>)],
negative_pairs: &[(Vec<f32>, Vec<f32>)],
learning_rate: f32,
epochs: usize,
) {
// Simplified training with contrastive loss
// Production would use proper backpropagation
for _ in 0..epochs {
// Positive pairs should have small Hamming distance
for (a, b) in positive_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
// If distance is too large, update towards similarity
if dist as f32 > self.output_bits as f32 * 0.3 {
self.update_weights(a, b, learning_rate, true);
}
}
// Negative pairs should have large Hamming distance
for (a, b) in negative_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
// If distance is too small, update towards dissimilarity
if (dist as f32) < self.output_bits as f32 * 0.6 {
self.update_weights(a, b, learning_rate, false);
}
}
}
}
fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
// Simplified gradient update (production would use proper autodiff)
let direction = if attract { 1.0 } else { -1.0 };
// Update only the last layer for simplicity
if let Some(last_proj) = self.projections.last_mut() {
let a_arr = Array1::from_vec(a.to_vec());
let b_arr = Array1::from_vec(b.to_vec());
for i in 0..last_proj.nrows() {
for j in 0..last_proj.ncols() {
let grad = direction * lr * (a_arr[j] - b_arr[j]);
last_proj[[i, j]] += grad * 0.001; // Small update
}
}
}
}
/// Get dimensions
pub fn dimensions(&self) -> (usize, usize) {
(self.input_dims, self.output_bits)
}
}
impl NeuralHash for DeepHashEmbedding {
    /// Sign-threshold the network's logits into a packed bit code
    /// (LSB-first within each byte). Wrong-dimension inputs map to the
    /// all-zero code.
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let num_bytes = self.output_bits.div_ceil(8);
        if vector.len() != self.input_dims {
            return vec![0; num_bytes];
        }
        let mut code = vec![0u8; num_bytes];
        for (bit, &logit) in self.forward(vector).iter().enumerate() {
            if logit > 0.0 {
                code[bit / 8] |= 1 << (bit % 8);
            }
        }
        code
    }

    /// Popcount of the XOR of the two codes (extra bytes in the longer code
    /// are ignored, as zip stops at the shorter one).
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .fold(0, |acc, (a, b)| acc + (a ^ b).count_ones())
    }

    /// Map normalized Hamming distance onto an approximate cosine-like
    /// similarity: 0 differing bits -> 1.0, all bits differing -> -1.0.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        let fraction = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * fraction
    }
}
/// Simple LSH (Locality Sensitive Hashing) baseline.
///
/// Random-hyperplane hashing: each of the `num_bits` rows of `projections`
/// is one hyperplane, and a vector's bit is the sign of its projection.
/// (Components are drawn uniformly from [-1, 1), not Gaussian.)
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
    /// Random projection vectors, one hyperplane per row (num_bits x input_dims)
    projections: Array2<f32>,
    /// Number of hash bits
    num_bits: usize,
}
impl SimpleLSH {
    /// Build an LSH hasher from `num_bits` random hyperplanes over
    /// `input_dims`-dimensional inputs.
    ///
    /// Hyperplane components are drawn uniformly from [-1, 1); a vector's
    /// bit i is the sign of its projection onto hyperplane i.
    pub fn new(input_dims: usize, num_bits: usize) -> Self {
        let mut rng = rand::thread_rng();
        let projections = Array2::from_shape_fn((num_bits, input_dims), |_| {
            rng.gen::<f32>() * 2.0 - 1.0
        });
        Self {
            projections,
            num_bits,
        }
    }
}
impl NeuralHash for SimpleLSH {
    /// One bit per hyperplane: set when the projection is positive
    /// (LSB-first within each byte).
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let projected = self.projections.dot(&Array1::from_vec(vector.to_vec()));
        let mut code = vec![0u8; self.num_bits.div_ceil(8)];
        for (bit, &value) in projected.iter().enumerate() {
            if value > 0.0 {
                code[bit / 8] |= 1 << (bit % 8);
            }
        }
        code
    }

    /// Popcount of the XOR of the two codes.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b.iter())
            .fold(0, |acc, (a, b)| acc + (a ^ b).count_ones())
    }

    /// Map normalized Hamming distance onto approximate similarity:
    /// 0 differing bits -> 1.0, all bits differing -> -1.0.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        let fraction = hamming_dist as f32 / code_bits as f32;
        1.0 - 2.0 * fraction
    }
}
/// Hash index for fast approximate nearest neighbor search.
///
/// Buckets vectors by their binary code, then verifies candidates against
/// the original floats — so both the compact codes and the full vectors
/// are retained.
pub struct HashIndex<H: NeuralHash + Clone> {
    /// Hash function
    hasher: H,
    /// Hash tables: binary code -> list of vector IDs sharing that code
    tables: HashMap<Vec<u8>, Vec<VectorId>>,
    /// Original vectors, kept for exact-similarity verification
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Code length in bits (used for compression accounting)
    code_bits: usize,
}
impl<H: NeuralHash + Clone> HashIndex<H> {
    /// Create a new hash index around `hasher`, whose codes are `code_bits`
    /// bits long.
    pub fn new(hasher: H, code_bits: usize) -> Self {
        Self {
            hasher,
            tables: HashMap::new(),
            vectors: HashMap::new(),
            code_bits,
        }
    }

    /// Insert a vector: bucket it under its binary code and retain the
    /// original floats for verification.
    pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
        let code = self.hasher.encode(&vector);
        self.tables.entry(code).or_default().push(id.clone());
        self.vectors.insert(id, vector);
    }

    /// Search for approximate nearest neighbors.
    ///
    /// Scans every bucket whose code is within `max_hamming` of the query's
    /// code, re-ranks candidates by exact cosine similarity, and returns the
    /// top `k` as `(id, similarity)`, best first.
    pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
        let query_code = self.hasher.encode(query);
        let mut candidates = Vec::new();
        for (code, ids) in &self.tables {
            let hamming = self.hasher.hamming_distance(&query_code, code);
            if hamming <= max_hamming {
                for id in ids {
                    if let Some(vec) = self.vectors.get(id) {
                        let similarity = cosine_similarity(query, vec);
                        candidates.push((id.clone(), similarity));
                    }
                }
            }
        }
        // NaN-safe descending sort; the previous `partial_cmp(...).unwrap()`
        // panicked if any similarity was NaN (e.g. NaN components in input).
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
        candidates.truncate(k);
        candidates
    }

    /// Ratio of raw f32 storage to code storage; 0.0 when the index is empty.
    /// Codes are counted once per bucket (they are the bucket keys).
    pub fn compression_ratio(&self) -> f32 {
        if self.vectors.is_empty() {
            return 0.0;
        }
        let original_size: usize = self
            .vectors
            .values()
            .map(|v| v.len() * std::mem::size_of::<f32>())
            .sum();
        let compressed_size = self.tables.len() * self.code_bits.div_ceil(8);
        original_size as f32 / compressed_size as f32
    }

    /// Get statistics: vector/bucket counts, mean bucket occupancy, and the
    /// compression ratio.
    pub fn stats(&self) -> HashIndexStats {
        let buckets = self.tables.len();
        let total_vectors = self.vectors.len();
        let avg_bucket_size = if buckets > 0 {
            total_vectors as f32 / buckets as f32
        } else {
            0.0
        };
        HashIndexStats {
            total_vectors,
            num_buckets: buckets,
            avg_bucket_size,
            compression_ratio: self.compression_ratio(),
        }
    }
}
/// Hash index statistics
#[derive(Debug, Clone)]
pub struct HashIndexStats {
    /// Number of vectors stored in the index
    pub total_vectors: usize,
    /// Number of occupied hash buckets (distinct binary codes)
    pub num_buckets: usize,
    /// Average number of vectors per occupied bucket
    pub avg_bucket_size: f32,
    /// Ratio of raw f32 storage size to binary-code storage size
    pub compression_ratio: f32,
}
/// Cosine similarity between two vectors; 0.0 when either has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 16-bit code must pack into exactly 2 bytes.
    #[test]
    fn test_deep_hash_encoding() {
        let hash = DeepHashEmbedding::new(4, vec![8], 16);
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let code = hash.encode(&vector);
        assert_eq!(code.len(), 2); // 16 bits = 2 bytes
    }

    /// Hamming distance counts differing bits across code bytes.
    #[test]
    fn test_hamming_distance() {
        let hash = DeepHashEmbedding::new(2, vec![], 8);
        let code_a = vec![0b10101010];
        let code_b = vec![0b11001100];
        let dist = hash.hamming_distance(&code_a, &code_b);
        assert_eq!(dist, 4); // 4 bits differ
    }

    /// LSH codes have the expected byte length and are deterministic.
    #[test]
    fn test_lsh_encoding() {
        let lsh = SimpleLSH::new(4, 16);
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let code = lsh.encode(&vector);
        assert_eq!(code.len(), 2);
        // Same vector should produce same code
        let code2 = lsh.encode(&vector);
        assert_eq!(code, code2);
    }

    /// Insert + search round-trip through the hash index.
    #[test]
    fn test_hash_index() {
        let lsh = SimpleLSH::new(3, 8);
        let mut index = HashIndex::new(lsh, 8);
        // Insert vectors
        index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
        index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
        index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);
        // Search
        let results = index.search(&[1.0, 0.0, 0.0], 2, 4);
        assert!(!results.is_empty());
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 3);
    }

    /// 128 f32s compressed to 32-bit codes should shrink the storage.
    #[test]
    fn test_compression_ratio() {
        let lsh = SimpleLSH::new(128, 32); // 128D -> 32 bits
        let mut index = HashIndex::new(lsh, 32);
        for i in 0..10 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
            index.insert(i.to_string(), vec);
        }
        let ratio = index.compression_ratio();
        assert!(ratio > 1.0); // Should have compression
    }
}

View File

@@ -0,0 +1,496 @@
//! # Topological Data Analysis (TDA)
//!
//! Basic topological analysis for embedding quality assessment.
//! Detects mode collapse, degeneracy, and topological structure.
use crate::error::{Result, RuvectorError};
use ndarray::Array2;
use serde::{Deserialize, Serialize};
/// Topological analyzer for embeddings
pub struct TopologicalAnalyzer {
    /// k for the k-nearest-neighbors graph built over the embeddings
    k_neighbors: usize,
    /// Distance threshold: a k-NN candidate is linked only if within this radius
    epsilon: f32,
}
impl TopologicalAnalyzer {
    /// Create a new topological analyzer.
    ///
    /// * `k_neighbors` - number of nearest neighbors linked per node
    /// * `epsilon` - maximum distance at which a neighbor edge is created
    pub fn new(k_neighbors: usize, epsilon: f32) -> Self {
        Self {
            k_neighbors,
            epsilon,
        }
    }

    /// Analyze embedding quality.
    ///
    /// Builds a k-NN graph over the embeddings and derives graph and
    /// statistical indicators (connectivity, clustering, mode collapse,
    /// rank degeneracy, spread, persistence), combined into a single
    /// `quality_score` in [0, 1].
    ///
    /// # Errors
    /// Returns `InvalidInput` when `embeddings` is empty.
    pub fn analyze(&self, embeddings: &[Vec<f32>]) -> Result<EmbeddingQuality> {
        if embeddings.is_empty() {
            return Err(RuvectorError::InvalidInput("Empty embeddings".into()));
        }
        let n = embeddings.len();
        let dim = embeddings[0].len();
        // Build k-NN graph
        let graph = self.build_knn_graph(embeddings);
        // Compute topological features
        let connected_components = self.count_connected_components(&graph, n);
        let clustering_coefficient = self.compute_clustering_coefficient(&graph);
        let degree_stats = self.compute_degree_statistics(&graph, n);
        // Detect mode collapse
        let mode_collapse_score = self.detect_mode_collapse(embeddings);
        // Compute embedding spread
        let spread = self.compute_spread(embeddings);
        // Detect degeneracy (vectors collapsing to a lower-dimensional manifold)
        let degeneracy_score = self.detect_degeneracy(embeddings);
        // Compute persistence features (simplified)
        let persistence_score = self.compute_persistence(&graph, embeddings);
        // Overall quality score (0-1, higher is better)
        let quality_score = self.compute_quality_score(
            mode_collapse_score,
            degeneracy_score,
            connected_components,
            clustering_coefficient,
            spread,
        );
        Ok(EmbeddingQuality {
            dimensions: dim,
            num_vectors: n,
            connected_components,
            clustering_coefficient,
            avg_degree: degree_stats.0,
            degree_std: degree_stats.1,
            mode_collapse_score,
            degeneracy_score,
            spread,
            persistence_score,
            quality_score,
        })
    }

    /// Directed k-NN graph: node i points at its k nearest neighbors that
    /// also lie within `epsilon`.
    fn build_knn_graph(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<usize>> {
        let n = embeddings.len();
        let mut graph = vec![Vec::new(); n];
        for i in 0..n {
            let mut distances: Vec<(usize, f32)> = (0..n)
                .filter(|&j| i != j)
                .map(|j| {
                    let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
                    (j, dist)
                })
                .collect();
            // total_cmp gives a total order over f32, so NaN distances
            // cannot panic the way partial_cmp().unwrap() did.
            distances.sort_by(|a, b| a.1.total_cmp(&b.1));
            // Add k nearest neighbors
            for (j, dist) in distances.iter().take(self.k_neighbors) {
                if *dist <= self.epsilon {
                    graph[i].push(*j);
                }
            }
        }
        graph
    }

    /// Number of connected components reached via DFS over outgoing edges.
    fn count_connected_components(&self, graph: &[Vec<usize>], n: usize) -> usize {
        let mut visited = vec![false; n];
        let mut components = 0;
        for i in 0..n {
            if !visited[i] {
                components += 1;
                self.dfs(i, graph, &mut visited);
            }
        }
        components
    }

    /// Depth-first traversal marking every node reachable from `node`.
    #[allow(clippy::only_used_in_recursion)]
    fn dfs(&self, node: usize, graph: &[Vec<usize>], visited: &mut [bool]) {
        visited[node] = true;
        for &neighbor in &graph[node] {
            if !visited[neighbor] {
                self.dfs(neighbor, graph, visited);
            }
        }
    }

    /// Average local clustering coefficient over nodes with >= 2 neighbors.
    fn compute_clustering_coefficient(&self, graph: &[Vec<usize>]) -> f32 {
        let mut total_coeff = 0.0;
        let mut count = 0;
        for neighbors in graph {
            if neighbors.len() < 2 {
                continue;
            }
            let k = neighbors.len();
            let mut triangles = 0;
            // Count triangles among this node's neighbors
            for i in 0..k {
                for j in i + 1..k {
                    let ni = neighbors[i];
                    let nj = neighbors[j];
                    if graph[ni].contains(&nj) {
                        triangles += 1;
                    }
                }
            }
            let possible_triangles = k * (k - 1) / 2;
            if possible_triangles > 0 {
                total_coeff += triangles as f32 / possible_triangles as f32;
                count += 1;
            }
        }
        if count > 0 {
            total_coeff / count as f32
        } else {
            0.0
        }
    }

    /// (mean, standard deviation) of the out-degree distribution.
    fn compute_degree_statistics(&self, graph: &[Vec<usize>], n: usize) -> (f32, f32) {
        let degrees: Vec<f32> = graph
            .iter()
            .map(|neighbors| neighbors.len() as f32)
            .collect();
        let avg = degrees.iter().sum::<f32>() / n as f32;
        let variance = degrees.iter().map(|&d| (d - avg).powi(2)).sum::<f32>() / n as f32;
        let std = variance.sqrt();
        (avg, std)
    }

    /// Mode-collapse score in [0, 1]: 0 means all points coincide,
    /// higher means better-separated pairwise distances.
    fn detect_mode_collapse(&self, embeddings: &[Vec<f32>]) -> f32 {
        // Compute pairwise distances
        let n = embeddings.len();
        let mut distances = Vec::new();
        for i in 0..n {
            for j in i + 1..n {
                let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
                distances.push(dist);
            }
        }
        if distances.is_empty() {
            return 0.0;
        }
        // Compute coefficient of variation
        let mean = distances.iter().sum::<f32>() / distances.len() as f32;
        let variance =
            distances.iter().map(|&d| (d - mean).powi(2)).sum::<f32>() / distances.len() as f32;
        let std = variance.sqrt();
        // High CV indicates good separation, low CV indicates collapse
        let cv = if mean > 0.0 { std / mean } else { 0.0 };
        // Normalize to 0-1, where 0 is collapsed, 1 is good
        (cv * 2.0).min(1.0)
    }

    /// Average Euclidean distance of the embeddings from their centroid.
    fn compute_spread(&self, embeddings: &[Vec<f32>]) -> f32 {
        if embeddings.is_empty() {
            return 0.0;
        }
        let dim = embeddings[0].len();
        // Compute mean
        let mut mean = vec![0.0; dim];
        for emb in embeddings {
            for (i, &val) in emb.iter().enumerate() {
                mean[i] += val;
            }
        }
        for val in mean.iter_mut() {
            *val /= embeddings.len() as f32;
        }
        // Compute average distance from mean
        let mut total_dist = 0.0;
        for emb in embeddings {
            let dist = euclidean_distance(emb, &mean);
            total_dist += dist;
        }
        total_dist / embeddings.len() as f32
    }

    /// Degeneracy score in [0, 1]: 0 = full rank, 1 = fully collapsed onto
    /// a lower-dimensional manifold (estimated from the covariance rank).
    fn detect_degeneracy(&self, embeddings: &[Vec<f32>]) -> f32 {
        if embeddings.is_empty() || embeddings[0].is_empty() {
            return 1.0; // Fully degenerate
        }
        let n = embeddings.len();
        let dim = embeddings[0].len();
        if n < dim {
            return 0.0; // Cannot determine
        }
        // Compute covariance matrix
        let cov = self.compute_covariance_matrix(embeddings);
        // Estimate rank by counting significant singular values
        let singular_values = self.approximate_singular_values(&cov);
        let significant = singular_values.iter().filter(|&&sv| sv > 1e-6).count();
        // Degeneracy score: 0 = full rank, 1 = rank-1 (collapsed)
        1.0 - (significant as f32 / dim as f32)
    }

    /// Sample covariance matrix (dim x dim) of the embeddings.
    fn compute_covariance_matrix(&self, embeddings: &[Vec<f32>]) -> Array2<f32> {
        let n = embeddings.len();
        let dim = embeddings[0].len();
        // Compute mean
        let mut mean = vec![0.0; dim];
        for emb in embeddings {
            for (i, &val) in emb.iter().enumerate() {
                mean[i] += val;
            }
        }
        for val in mean.iter_mut() {
            *val /= n as f32;
        }
        // Accumulate outer products of the centered vectors
        let mut cov = Array2::zeros((dim, dim));
        for emb in embeddings {
            for i in 0..dim {
                for j in 0..dim {
                    cov[[i, j]] += (emb[i] - mean[i]) * (emb[j] - mean[j]);
                }
            }
        }
        // Normalize to the unbiased (n - 1) sample covariance. The original
        // code called `cov.mapv(...)` and discarded its return value, which
        // left the matrix unnormalized; `mapv_inplace` applies the scaling.
        // Guard n > 1 to avoid dividing by zero when n == dim == 1.
        if n > 1 {
            cov.mapv_inplace(|x| x / (n - 1) as f32);
        }
        cov
    }

    /// Crude singular-value proxy: |diagonal| entries, sorted descending.
    fn approximate_singular_values(&self, matrix: &Array2<f32>) -> Vec<f32> {
        // Power iteration for largest singular values (simplified)
        let dim = matrix.nrows();
        let mut values = Vec::new();
        // Just return diagonal for approximation
        for i in 0..dim {
            values.push(matrix[[i, i]].abs());
        }
        // total_cmp avoids the NaN panic of partial_cmp().unwrap().
        values.sort_by(|a, b| b.total_cmp(a));
        values
    }

    /// Simplified persistence: variation in connected-component count as
    /// the edge distance threshold is swept over several scales.
    fn compute_persistence(&self, _graph: &[Vec<usize>], embeddings: &[Vec<f32>]) -> f32 {
        let scales = vec![0.1, 0.5, 1.0, 2.0, 5.0];
        let mut component_counts = Vec::new();
        for &scale in &scales {
            let scaled_analyzer = TopologicalAnalyzer::new(self.k_neighbors, scale);
            let scaled_graph = scaled_analyzer.build_knn_graph(embeddings);
            let components =
                scaled_analyzer.count_connected_components(&scaled_graph, embeddings.len());
            component_counts.push(components);
        }
        // Persistence is the variation in component count across scales
        let max_components = *component_counts.iter().max().unwrap_or(&1);
        let min_components = *component_counts.iter().min().unwrap_or(&1);
        (max_components - min_components) as f32 / max_components as f32
    }

    /// Weighted combination of the individual metrics, clamped to [0, 1].
    fn compute_quality_score(
        &self,
        mode_collapse: f32,
        degeneracy: f32,
        components: usize,
        clustering: f32,
        spread: f32,
    ) -> f32 {
        // Weighted combination of metrics
        let collapse_score = mode_collapse; // Higher is better
        let degeneracy_score = 1.0 - degeneracy; // Lower degeneracy is better
        let component_score = if components == 1 { 1.0 } else { 0.5 }; // Single component is good
        let clustering_score = clustering; // Higher clustering is good
        let spread_score = (spread / 10.0).min(1.0); // Reasonable spread
        (collapse_score * 0.3
            + degeneracy_score * 0.3
            + component_score * 0.2
            + clustering_score * 0.1
            + spread_score * 0.1)
            .clamp(0.0, 1.0)
    }
}
/// Embedding quality metrics
///
/// Produced by [`TopologicalAnalyzer::analyze`]; individual indicators plus
/// an aggregated `quality_score`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingQuality {
    /// Embedding dimensions
    pub dimensions: usize,
    /// Number of vectors
    pub num_vectors: usize,
    /// Number of connected components in the k-NN graph
    pub connected_components: usize,
    /// Clustering coefficient (0-1)
    pub clustering_coefficient: f32,
    /// Average node degree in the k-NN graph
    pub avg_degree: f32,
    /// Degree standard deviation
    pub degree_std: f32,
    /// Mode collapse score (0=collapsed, 1=good)
    pub mode_collapse_score: f32,
    /// Degeneracy score (0=full rank, 1=degenerate)
    pub degeneracy_score: f32,
    /// Average spread from centroid
    pub spread: f32,
    /// Topological persistence score
    pub persistence_score: f32,
    /// Overall quality (0-1, higher is better)
    pub quality_score: f32,
}
impl EmbeddingQuality {
    /// True when the mode-collapse score indicates collapsed embeddings.
    pub fn has_mode_collapse(&self) -> bool {
        self.mode_collapse_score < 0.3
    }

    /// True when the rank-deficiency score indicates degenerate embeddings.
    pub fn is_degenerate(&self) -> bool {
        self.degeneracy_score > 0.7
    }

    /// True when the overall quality score clears the "good" threshold.
    pub fn is_good_quality(&self) -> bool {
        self.quality_score > 0.7
    }

    /// Human-readable label derived from the overall quality score.
    pub fn assessment(&self) -> &str {
        match self.quality_score {
            s if s > 0.8 => "Excellent",
            s if s > 0.6 => "Good",
            s if s > 0.4 => "Fair",
            _ => "Poor",
        }
    }
}
/// Euclidean (L2) distance between two vectors of equal length.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end analysis over two well-separated clusters.
    #[test]
    fn test_embedding_analysis() {
        let analyzer = TopologicalAnalyzer::new(3, 5.0);
        // Create well-separated embeddings
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0],
            vec![5.1, 5.1],
        ];
        let quality = analyzer.analyze(&embeddings).unwrap();
        assert_eq!(quality.dimensions, 2);
        assert_eq!(quality.num_vectors, 5);
        assert!(quality.quality_score > 0.0);
    }

    /// Collapsed (identical) points must score 0; separated points higher.
    #[test]
    fn test_mode_collapse_detection() {
        let analyzer = TopologicalAnalyzer::new(2, 10.0);
        // Well-separated embeddings (high CV should give high score)
        let good = vec![vec![0.0, 0.0], vec![5.0, 5.0], vec![10.0, 10.0]];
        let score_good = analyzer.detect_mode_collapse(&good);
        // Collapsed embeddings (all identical, CV = 0)
        let collapsed = vec![vec![1.0, 1.0], vec![1.0, 1.0], vec![1.0, 1.0]];
        let score_collapsed = analyzer.detect_mode_collapse(&collapsed);
        // Identical vectors should have score 0 (distances all same = CV 0)
        assert_eq!(score_collapsed, 0.0);
        // Well-separated should have higher score
        assert!(score_good > score_collapsed);
    }

    /// Two clusters far beyond epsilon must form separate components.
    #[test]
    fn test_connected_components() {
        let analyzer = TopologicalAnalyzer::new(1, 1.0);
        // Two separate clusters
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.5, 0.5],
            vec![10.0, 10.0],
            vec![10.5, 10.5],
        ];
        let graph = analyzer.build_knn_graph(&embeddings);
        let components = analyzer.count_connected_components(&graph, embeddings.len());
        assert!(components >= 2); // Should have at least 2 components
    }

    /// Threshold helpers and the textual assessment for a "Good" profile.
    #[test]
    fn test_quality_assessment() {
        let quality = EmbeddingQuality {
            dimensions: 128,
            num_vectors: 1000,
            connected_components: 1,
            clustering_coefficient: 0.6,
            avg_degree: 5.0,
            degree_std: 1.2,
            mode_collapse_score: 0.8,
            degeneracy_score: 0.2,
            spread: 3.5,
            persistence_score: 0.4,
            quality_score: 0.75,
        };
        assert!(!quality.has_mode_collapse());
        assert!(!quality.is_degenerate());
        assert!(quality.is_good_quality());
        assert_eq!(quality.assessment(), "Good");
    }
}

View File

@@ -0,0 +1,23 @@
//! Advanced Features for Ruvector
//!
//! This module provides advanced vector database capabilities:
//! - Enhanced Product Quantization with precomputed lookup tables
//! - Filtered Search with automatic strategy selection
//! - MMR (Maximal Marginal Relevance) for diversity
//! - Hybrid Search combining vector and keyword matching
//! - Conformal Prediction for uncertainty quantification
pub mod conformal_prediction;
pub mod filtered_search;
pub mod hybrid_search;
pub mod mmr;
pub mod product_quantization;
// Re-exports
pub use conformal_prediction::{
ConformalConfig, ConformalPredictor, NonconformityMeasure, PredictionSet,
};
pub use filtered_search::{FilterExpression, FilterStrategy, FilteredSearch};
pub use hybrid_search::{HybridConfig, HybridSearch, NormalizationStrategy, BM25};
pub use mmr::{MMRConfig, MMRSearch};
pub use product_quantization::{EnhancedPQ, LookupTable, PQConfig};

View File

@@ -0,0 +1,503 @@
//! Conformal Prediction for Uncertainty Quantification
//!
//! Implements conformal prediction to provide statistically valid uncertainty estimates
//! and prediction sets with guaranteed coverage (1-α).
use crate::error::{Result, RuvectorError};
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
/// Configuration for conformal prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConformalConfig {
    /// Significance level (alpha) - typically 0.05 or 0.10; target coverage is 1 - alpha
    pub alpha: f32,
    /// Size of calibration set (as fraction of total data)
    pub calibration_fraction: f32,
    /// Non-conformity measure type used for scoring and thresholding
    pub nonconformity_measure: NonconformityMeasure,
}
impl Default for ConformalConfig {
fn default() -> Self {
Self {
alpha: 0.1, // 90% coverage
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::Distance,
}
}
}
/// Type of non-conformity measure
///
/// Determines how a result's "strangeness" is scored during calibration
/// and how the threshold is applied at prediction time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NonconformityMeasure {
    /// Use distance score as non-conformity
    Distance,
    /// Use inverse rank (1 / (rank + 1)) as non-conformity
    InverseRank,
    /// Use normalized distance (distance / avg_distance)
    NormalizedDistance,
}
/// Prediction set with conformal guarantees
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionSet {
    /// Results in the prediction set
    pub results: Vec<SearchResult>,
    /// Conformal threshold used to admit results
    pub threshold: f32,
    /// Confidence level (1 - alpha)
    pub confidence: f32,
    /// Marginal coverage guarantee (1 - alpha)
    pub coverage_guarantee: f32,
}
/// Conformal predictor for vector search
///
/// Must be calibrated (via `calibrate`) before `predict` can be called.
#[derive(Debug, Clone)]
pub struct ConformalPredictor {
    /// Configuration
    pub config: ConformalConfig,
    /// Calibration set: non-conformity scores collected during `calibrate`
    pub calibration_scores: Vec<f32>,
    /// Conformal threshold (quantile of calibration scores); `None` until calibrated
    pub threshold: Option<f32>,
}
impl ConformalPredictor {
    /// Create a new conformal predictor.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when `alpha` or `calibration_fraction`
    /// fall outside [0, 1].
    pub fn new(config: ConformalConfig) -> Result<Self> {
        if !(0.0..=1.0).contains(&config.alpha) {
            return Err(RuvectorError::InvalidParameter(format!(
                "Alpha must be in [0, 1], got {}",
                config.alpha
            )));
        }
        if !(0.0..=1.0).contains(&config.calibration_fraction) {
            return Err(RuvectorError::InvalidParameter(format!(
                "Calibration fraction must be in [0, 1], got {}",
                config.calibration_fraction
            )));
        }
        Ok(Self {
            config,
            calibration_scores: Vec::new(),
            threshold: None,
        })
    }

    /// Calibrate on a set of validation examples.
    ///
    /// # Arguments
    /// * `validation_queries` - Query vectors for calibration
    /// * `true_neighbors` - Ground truth neighbors for each query
    /// * `search_fn` - Function to perform search
    ///
    /// # Errors
    /// Fails when inputs are empty/mismatched, or when a ground-truth
    /// neighbor is absent from the top search results.
    pub fn calibrate<F>(
        &mut self,
        validation_queries: &[Vec<f32>],
        true_neighbors: &[Vec<VectorId>],
        search_fn: F,
    ) -> Result<()>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        if validation_queries.len() != true_neighbors.len() {
            return Err(RuvectorError::InvalidParameter(
                "Number of queries must match number of true neighbor sets".to_string(),
            ));
        }
        if validation_queries.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Calibration set cannot be empty".to_string(),
            ));
        }
        let mut all_scores = Vec::new();
        // Compute non-conformity scores for calibration set
        for (query, true_ids) in validation_queries.iter().zip(true_neighbors) {
            // Search for neighbors (fetch generously so true neighbors appear)
            let results = search_fn(query, 100)?;
            // Compute non-conformity scores for true neighbors
            for true_id in true_ids {
                let score = self.compute_nonconformity_score(&results, true_id)?;
                all_scores.push(score);
            }
        }
        self.calibration_scores = all_scores;
        // Compute threshold as (1 - alpha) quantile
        self.compute_threshold()?;
        Ok(())
    }

    /// Compute the conformal threshold as the (conservative) (1 - alpha)
    /// quantile of the calibration scores.
    fn compute_threshold(&mut self) -> Result<()> {
        if self.calibration_scores.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "No calibration scores available".to_string(),
            ));
        }
        let mut sorted_scores = self.calibration_scores.clone();
        // total_cmp is a total order over f32, so a NaN score cannot panic
        // the way partial_cmp().unwrap() did.
        sorted_scores.sort_by(|a, b| a.total_cmp(b));
        // Compute (1 - alpha) quantile, using the ceil((1-alpha)(n+1))
        // finite-sample correction, clamped to the last index.
        let n = sorted_scores.len();
        let quantile_index = ((1.0 - self.config.alpha) * (n as f32 + 1.0)).ceil() as usize;
        let quantile_index = quantile_index.min(n - 1);
        self.threshold = Some(sorted_scores[quantile_index]);
        Ok(())
    }

    /// Compute the non-conformity score of `target_id` within `results`,
    /// according to the configured measure.
    ///
    /// # Errors
    /// Returns `VectorNotFound` when `target_id` is not among `results`.
    fn compute_nonconformity_score(
        &self,
        results: &[SearchResult],
        target_id: &VectorId,
    ) -> Result<f32> {
        match self.config.nonconformity_measure {
            NonconformityMeasure::Distance => {
                // Use distance score directly
                results
                    .iter()
                    .find(|r| &r.id == target_id)
                    .map(|r| r.score)
                    .ok_or_else(|| {
                        RuvectorError::VectorNotFound(format!(
                            "Target {} not in results",
                            target_id
                        ))
                    })
            }
            NonconformityMeasure::InverseRank => {
                // Use inverse rank: 1 / (rank + 1)
                let rank = results
                    .iter()
                    .position(|r| &r.id == target_id)
                    .ok_or_else(|| {
                        RuvectorError::VectorNotFound(format!(
                            "Target {} not in results",
                            target_id
                        ))
                    })?;
                Ok(1.0 / (rank as f32 + 1.0))
            }
            NonconformityMeasure::NormalizedDistance => {
                // Normalize by average distance
                let target_score = results
                    .iter()
                    .find(|r| &r.id == target_id)
                    .map(|r| r.score)
                    .ok_or_else(|| {
                        RuvectorError::VectorNotFound(format!(
                            "Target {} not in results",
                            target_id
                        ))
                    })?;
                // Guard against empty results
                if results.is_empty() {
                    return Ok(target_score);
                }
                let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
                Ok(if avg_score > 0.0 {
                    target_score / avg_score
                } else {
                    target_score
                })
            }
        }
    }

    /// Make prediction with conformal guarantee
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `search_fn` - Function to perform search
    ///
    /// # Returns
    /// Prediction set with coverage guarantee
    ///
    /// # Errors
    /// Fails if the predictor has not been calibrated.
    pub fn predict<F>(&self, query: &[f32], search_fn: F) -> Result<PredictionSet>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        let threshold = self.threshold.ok_or_else(|| {
            RuvectorError::InvalidParameter("Predictor not calibrated yet".to_string())
        })?;
        // Perform search with large k so the threshold can do the trimming
        let results = search_fn(query, 1000)?;
        // Select results based on non-conformity threshold
        let prediction_set: Vec<SearchResult> = match self.config.nonconformity_measure {
            NonconformityMeasure::Distance => {
                // Include all results with distance <= threshold
                results
                    .into_iter()
                    .filter(|r| r.score <= threshold)
                    .collect()
            }
            NonconformityMeasure::InverseRank => {
                // Include top-k results where k is determined by threshold
                let k = (1.0 / threshold).ceil() as usize;
                results.into_iter().take(k).collect()
            }
            NonconformityMeasure::NormalizedDistance => {
                // Guard against empty results
                if results.is_empty() {
                    return Ok(PredictionSet {
                        results: vec![],
                        threshold,
                        confidence: 1.0 - self.config.alpha,
                        coverage_guarantee: 1.0 - self.config.alpha,
                    });
                }
                let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
                let adjusted_threshold = threshold * avg_score;
                results
                    .into_iter()
                    .filter(|r| r.score <= adjusted_threshold)
                    .collect()
            }
        };
        Ok(PredictionSet {
            results: prediction_set,
            threshold,
            confidence: 1.0 - self.config.alpha,
            coverage_guarantee: 1.0 - self.config.alpha,
        })
    }

    /// Compute adaptive top-k based on uncertainty
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `search_fn` - Function to perform search
    ///
    /// # Returns
    /// Number of results to return based on uncertainty
    pub fn adaptive_top_k<F>(&self, query: &[f32], search_fn: F) -> Result<usize>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        let prediction_set = self.predict(query, search_fn)?;
        Ok(prediction_set.results.len())
    }

    /// Get calibration statistics, or `None` when not yet calibrated.
    pub fn get_statistics(&self) -> Option<CalibrationStats> {
        if self.calibration_scores.is_empty() {
            return None;
        }
        let n = self.calibration_scores.len() as f32;
        let mean = self.calibration_scores.iter().sum::<f32>() / n;
        let variance = self
            .calibration_scores
            .iter()
            .map(|&s| (s - mean).powi(2))
            .sum::<f32>()
            / n;
        let std = variance.sqrt();
        let mut sorted = self.calibration_scores.clone();
        // NaN-safe total order (see compute_threshold).
        sorted.sort_by(|a, b| a.total_cmp(b));
        Some(CalibrationStats {
            num_samples: self.calibration_scores.len(),
            mean,
            std,
            min: sorted.first().copied().unwrap(),
            max: sorted.last().copied().unwrap(),
            median: sorted[sorted.len() / 2],
            threshold: self.threshold,
        })
    }
}
/// Calibration statistics
///
/// Summary of the non-conformity scores collected during calibration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationStats {
    /// Number of calibration samples
    pub num_samples: usize,
    /// Mean non-conformity score
    pub mean: f32,
    /// Standard deviation of the scores
    pub std: f32,
    /// Minimum score
    pub min: f32,
    /// Maximum score
    pub max: f32,
    /// Median score
    pub median: f32,
    /// Conformal threshold (`None` if not yet computed)
    pub threshold: Option<f32>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal SearchResult fixture with the given id and score.
    fn create_search_result(id: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vec![0.0; 10]),
            metadata: None,
        }
    }

    /// Deterministic mock search: ids doc_0..doc_{k-1} with scores 0.0, 0.1, ...
    fn mock_search_fn(_query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| create_search_result(&format!("doc_{}", i), i as f32 * 0.1))
            .collect())
    }

    /// Valid alpha is accepted; alpha outside [0, 1] is rejected.
    #[test]
    fn test_conformal_config_validation() {
        let config = ConformalConfig {
            alpha: 0.1,
            ..Default::default()
        };
        assert!(ConformalPredictor::new(config).is_ok());
        let invalid_config = ConformalConfig {
            alpha: 1.5,
            ..Default::default()
        };
        assert!(ConformalPredictor::new(invalid_config).is_err());
    }

    /// Calibration collects scores and produces a threshold.
    #[test]
    fn test_conformal_calibration() {
        let config = ConformalConfig::default();
        let mut predictor = ConformalPredictor::new(config).unwrap();
        // Create calibration data
        let queries = vec![vec![1.0; 10], vec![2.0; 10], vec![3.0; 10]];
        let true_neighbors = vec![
            vec!["doc_0".to_string(), "doc_1".to_string()],
            vec!["doc_0".to_string()],
            vec!["doc_1".to_string(), "doc_2".to_string()],
        ];
        predictor
            .calibrate(&queries, &true_neighbors, mock_search_fn)
            .unwrap();
        assert!(!predictor.calibration_scores.is_empty());
        assert!(predictor.threshold.is_some());
    }

    /// End-to-end predict after calibration with the Distance measure.
    #[test]
    fn test_conformal_prediction() {
        let config = ConformalConfig {
            alpha: 0.1,
            calibration_fraction: 0.2,
            nonconformity_measure: NonconformityMeasure::Distance,
        };
        let mut predictor = ConformalPredictor::new(config).unwrap();
        // Calibrate
        let queries = vec![vec![1.0; 10], vec![2.0; 10]];
        let true_neighbors = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
        predictor
            .calibrate(&queries, &true_neighbors, mock_search_fn)
            .unwrap();
        // Make prediction
        let query = vec![1.5; 10];
        let prediction_set = predictor.predict(&query, mock_search_fn).unwrap();
        assert!(!prediction_set.results.is_empty());
        assert_eq!(prediction_set.confidence, 0.9);
        assert!(prediction_set.threshold > 0.0);
    }

    /// Distance measure returns the raw score of the target result.
    #[test]
    fn test_nonconformity_distance() {
        let config = ConformalConfig {
            nonconformity_measure: NonconformityMeasure::Distance,
            ..Default::default()
        };
        let predictor = ConformalPredictor::new(config).unwrap();
        let results = vec![
            create_search_result("doc_0", 0.1),
            create_search_result("doc_1", 0.3),
            create_search_result("doc_2", 0.5),
        ];
        let score = predictor
            .compute_nonconformity_score(&results, &"doc_1".to_string())
            .unwrap();
        assert!((score - 0.3).abs() < 0.01);
    }

    /// InverseRank measure returns 1 / (rank + 1) of the target result.
    #[test]
    fn test_nonconformity_inverse_rank() {
        let config = ConformalConfig {
            nonconformity_measure: NonconformityMeasure::InverseRank,
            ..Default::default()
        };
        let predictor = ConformalPredictor::new(config).unwrap();
        let results = vec![
            create_search_result("doc_0", 0.1),
            create_search_result("doc_1", 0.3),
            create_search_result("doc_2", 0.5),
        ];
        let score = predictor
            .compute_nonconformity_score(&results, &"doc_1".to_string())
            .unwrap();
        assert!((score - 0.5).abs() < 0.01); // 1 / (1 + 1) = 0.5
    }

    /// Summary statistics over hand-seeded calibration scores.
    #[test]
    fn test_calibration_stats() {
        let config = ConformalConfig::default();
        let mut predictor = ConformalPredictor::new(config).unwrap();
        predictor.calibration_scores = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        predictor.threshold = Some(0.4);
        let stats = predictor.get_statistics().unwrap();
        assert_eq!(stats.num_samples, 5);
        assert!((stats.mean - 0.3).abs() < 0.01);
        assert!((stats.min - 0.1).abs() < 0.01);
        assert!((stats.max - 0.5).abs() < 0.01);
    }

    /// adaptive_top_k yields a positive k after calibration.
    #[test]
    fn test_adaptive_top_k() {
        let config = ConformalConfig::default();
        let mut predictor = ConformalPredictor::new(config).unwrap();
        // Calibrate
        let queries = vec![vec![1.0; 10], vec![2.0; 10]];
        let true_neighbors = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
        predictor
            .calibrate(&queries, &true_neighbors, mock_search_fn)
            .unwrap();
        // Test adaptive k
        let query = vec![1.5; 10];
        let k = predictor.adaptive_top_k(&query, mock_search_fn).unwrap();
        assert!(k > 0);
    }
}

View File

@@ -0,0 +1,363 @@
//! Filtered Search with Automatic Strategy Selection
//!
//! Supports two filtering strategies:
//! - Pre-filtering: Apply metadata filters before graph traversal
//! - Post-filtering: Traverse graph then apply filters
//! - Automatic strategy selection based on filter selectivity
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Filter strategy selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FilterStrategy {
    /// Apply filters before search (efficient for highly selective filters)
    PreFilter,
    /// Apply filters after search (efficient for low selectivity)
    PostFilter,
    /// Automatically select strategy based on estimated selectivity
    Auto,
}
/// Filter expression for metadata
///
/// A composable predicate tree over per-vector metadata; evaluated by
/// [`FilterExpression::evaluate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
    /// Equality check: field == value
    Eq(String, serde_json::Value),
    /// Not equal: field != value (also true when the field is missing)
    Ne(String, serde_json::Value),
    /// Greater than: field > value
    Gt(String, serde_json::Value),
    /// Greater than or equal: field >= value
    Gte(String, serde_json::Value),
    /// Less than: field < value
    Lt(String, serde_json::Value),
    /// Less than or equal: field <= value
    Lte(String, serde_json::Value),
    /// In list: field in [values]
    In(String, Vec<serde_json::Value>),
    /// Not in list: field not in [values] (also true when the field is missing)
    NotIn(String, Vec<serde_json::Value>),
    /// Range check: min <= field <= max
    Range(String, serde_json::Value, serde_json::Value),
    /// Logical AND
    And(Vec<FilterExpression>),
    /// Logical OR
    Or(Vec<FilterExpression>),
    /// Logical NOT
    Not(Box<FilterExpression>),
}
impl FilterExpression {
/// Evaluate filter against metadata
pub fn evaluate(&self, metadata: &HashMap<String, serde_json::Value>) -> bool {
match self {
FilterExpression::Eq(field, value) => metadata.get(field) == Some(value),
FilterExpression::Ne(field, value) => metadata.get(field) != Some(value),
FilterExpression::Gt(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) > 0
} else {
false
}
}
FilterExpression::Gte(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) >= 0
} else {
false
}
}
FilterExpression::Lt(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) < 0
} else {
false
}
}
FilterExpression::Lte(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) <= 0
} else {
false
}
}
FilterExpression::In(field, values) => {
if let Some(field_value) = metadata.get(field) {
values.contains(field_value)
} else {
false
}
}
FilterExpression::NotIn(field, values) => {
if let Some(field_value) = metadata.get(field) {
!values.contains(field_value)
} else {
true
}
}
FilterExpression::Range(field, min, max) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, min) >= 0 && compare_values(field_value, max) <= 0
} else {
false
}
}
FilterExpression::And(exprs) => exprs.iter().all(|e| e.evaluate(metadata)),
FilterExpression::Or(exprs) => exprs.iter().any(|e| e.evaluate(metadata)),
FilterExpression::Not(expr) => !expr.evaluate(metadata),
}
}
/// Estimate selectivity of filter (0.0 = very selective, 1.0 = not selective)
#[allow(clippy::only_used_in_recursion)]
pub fn estimate_selectivity(&self, total_vectors: usize) -> f32 {
match self {
FilterExpression::Eq(_, _) => 0.1, // Equality is typically selective
FilterExpression::Ne(_, _) => 0.9, // Not equal is less selective
FilterExpression::In(_, values) => (values.len() as f32) / 100.0,
FilterExpression::NotIn(_, values) => 1.0 - (values.len() as f32) / 100.0,
FilterExpression::Range(_, _, _) => 0.3, // Ranges are moderately selective
FilterExpression::Gt(_, _) | FilterExpression::Gte(_, _) => 0.5,
FilterExpression::Lt(_, _) | FilterExpression::Lte(_, _) => 0.5,
FilterExpression::And(exprs) => {
// AND is more selective (multiply selectivities)
exprs
.iter()
.map(|e| e.estimate_selectivity(total_vectors))
.product()
}
FilterExpression::Or(exprs) => {
// OR is less selective (sum selectivities, capped at 1.0)
exprs
.iter()
.map(|e| e.estimate_selectivity(total_vectors))
.sum::<f32>()
.min(1.0)
}
FilterExpression::Not(expr) => 1.0 - expr.estimate_selectivity(total_vectors),
}
}
}
/// Filtered search implementation
#[derive(Debug, Clone)]
pub struct FilteredSearch {
/// Filter expression
pub filter: FilterExpression,
/// Strategy for applying filter
pub strategy: FilterStrategy,
/// Metadata store: id -> metadata
pub metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
}
impl FilteredSearch {
/// Create a new filtered search instance
pub fn new(
filter: FilterExpression,
strategy: FilterStrategy,
metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
) -> Self {
Self {
filter,
strategy,
metadata_store,
}
}
/// Automatically select strategy based on filter selectivity
pub fn auto_select_strategy(&self) -> FilterStrategy {
let selectivity = self.filter.estimate_selectivity(self.metadata_store.len());
// If filter is highly selective (< 20%), use pre-filtering
// Otherwise use post-filtering
if selectivity < 0.2 {
FilterStrategy::PreFilter
} else {
FilterStrategy::PostFilter
}
}
/// Get list of vector IDs that pass the filter (for pre-filtering)
pub fn get_filtered_ids(&self) -> Vec<VectorId> {
self.metadata_store
.iter()
.filter(|(_, metadata)| self.filter.evaluate(metadata))
.map(|(id, _)| id.clone())
.collect()
}
/// Apply filter to search results (for post-filtering)
pub fn filter_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
results
.into_iter()
.filter(|result| {
if let Some(metadata) = result.metadata.as_ref() {
self.filter.evaluate(metadata)
} else {
false
}
})
.collect()
}
/// Apply filtered search with automatic strategy selection
pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
where
F: Fn(&[f32], usize, Option<&[VectorId]>) -> Result<Vec<SearchResult>>,
{
let strategy = match self.strategy {
FilterStrategy::Auto => self.auto_select_strategy(),
other => other,
};
match strategy {
FilterStrategy::PreFilter => {
// Get filtered IDs first
let filtered_ids = self.get_filtered_ids();
if filtered_ids.is_empty() {
return Ok(Vec::new());
}
// Search only within filtered IDs
// We may need to fetch more results to get k after filtering
let fetch_k = (k as f32 * 1.5).ceil() as usize;
search_fn(query, fetch_k, Some(&filtered_ids))
}
FilterStrategy::PostFilter => {
// Search first, then filter
// Fetch more results to ensure we get k after filtering
let fetch_k = (k as f32 * 2.0).ceil() as usize;
let results = search_fn(query, fetch_k, None)?;
// Apply filter
let filtered = self.filter_results(results);
// Return top-k
Ok(filtered.into_iter().take(k).collect())
}
FilterStrategy::Auto => unreachable!(),
}
}
}
// Helper: three-way comparison of two JSON scalars, returning -1 / 0 / 1.
// Mismatched or unsupported value types compare as equal (0).
fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> i32 {
    use serde_json::Value;
    use std::cmp::Ordering;
    match (a, b) {
        (Value::Number(x), Value::Number(y)) => {
            // Non-representable numbers fall back to 0.0, matching the
            // lenient treatment of other unsupported inputs.
            let x = x.as_f64().unwrap_or(0.0);
            let y = y.as_f64().unwrap_or(0.0);
            match x.partial_cmp(&y) {
                Some(Ordering::Less) => -1,
                Some(Ordering::Greater) => 1,
                _ => 0,
            }
        }
        // Ordering's discriminants are -1 / 0 / 1, so the cast is direct.
        (Value::String(x), Value::String(y)) => x.cmp(y) as i32,
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y) as i32,
        _ => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // Eq: matches only when the field value is exactly equal.
    #[test]
    fn test_filter_eq() {
        let mut metadata = HashMap::new();
        metadata.insert("category".to_string(), json!("electronics"));
        let filter = FilterExpression::Eq("category".to_string(), json!("electronics"));
        assert!(filter.evaluate(&metadata));
        let filter = FilterExpression::Eq("category".to_string(), json!("books"));
        assert!(!filter.evaluate(&metadata));
    }
    // Range: inclusive on both bounds (see FilterExpression::Range evaluation).
    #[test]
    fn test_filter_range() {
        let mut metadata = HashMap::new();
        metadata.insert("price".to_string(), json!(50.0));
        let filter = FilterExpression::Range("price".to_string(), json!(10.0), json!(100.0));
        assert!(filter.evaluate(&metadata));
        let filter = FilterExpression::Range("price".to_string(), json!(60.0), json!(100.0));
        assert!(!filter.evaluate(&metadata));
    }
    // And: every sub-expression must pass.
    #[test]
    fn test_filter_and() {
        let mut metadata = HashMap::new();
        metadata.insert("category".to_string(), json!("electronics"));
        metadata.insert("price".to_string(), json!(50.0));
        let filter = FilterExpression::And(vec![
            FilterExpression::Eq("category".to_string(), json!("electronics")),
            FilterExpression::Lt("price".to_string(), json!(100.0)),
        ]);
        assert!(filter.evaluate(&metadata));
    }
    // Or: at least one sub-expression must pass.
    #[test]
    fn test_filter_or() {
        let mut metadata = HashMap::new();
        metadata.insert("category".to_string(), json!("electronics"));
        let filter = FilterExpression::Or(vec![
            FilterExpression::Eq("category".to_string(), json!("books")),
            FilterExpression::Eq("category".to_string(), json!("electronics")),
        ]);
        assert!(filter.evaluate(&metadata));
    }
    // In: field value must be one of the listed values.
    #[test]
    fn test_filter_in() {
        let mut metadata = HashMap::new();
        metadata.insert("tag".to_string(), json!("popular"));
        let filter = FilterExpression::In(
            "tag".to_string(),
            vec![json!("popular"), json!("trending"), json!("new")],
        );
        assert!(filter.evaluate(&metadata));
    }
    // Selectivity heuristics: equality is selective, inequality is not.
    #[test]
    fn test_selectivity_estimation() {
        let filter_eq = FilterExpression::Eq("field".to_string(), json!("value"));
        assert!(filter_eq.estimate_selectivity(1000) < 0.5);
        let filter_ne = FilterExpression::Ne("field".to_string(), json!("value"));
        assert!(filter_ne.estimate_selectivity(1000) > 0.5);
    }
    // Auto strategy: selective filters pre-filter, broad filters post-filter
    // (threshold is selectivity < 0.2 in auto_select_strategy).
    #[test]
    fn test_auto_strategy_selection() {
        let mut metadata_store = HashMap::new();
        for i in 0..100 {
            let mut metadata = HashMap::new();
            metadata.insert("id".to_string(), json!(i));
            metadata_store.insert(format!("vec_{}", i), metadata);
        }
        // Highly selective filter should choose pre-filter
        let filter = FilterExpression::Eq("id".to_string(), json!(42));
        let search = FilteredSearch::new(filter, FilterStrategy::Auto, metadata_store.clone());
        assert_eq!(search.auto_select_strategy(), FilterStrategy::PreFilter);
        // Less selective filter should choose post-filter
        let filter = FilterExpression::Gte("id".to_string(), json!(0));
        let search = FilteredSearch::new(filter, FilterStrategy::Auto, metadata_store);
        assert_eq!(search.auto_select_strategy(), FilterStrategy::PostFilter);
    }
}

View File

@@ -0,0 +1,444 @@
//! Hybrid Search: Combining Vector Similarity and Keyword Matching
//!
//! Implements hybrid search by combining:
//! - Vector similarity search (semantic)
//! - BM25 keyword matching (lexical)
//! - Weighted combination of scores
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Configuration for hybrid search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
    /// Weight for vector similarity (alpha); presumably intended to sum to
    /// 1.0 with `keyword_weight`, but this is not enforced anywhere here.
    pub vector_weight: f32,
    /// Weight for keyword matching (beta)
    pub keyword_weight: f32,
    /// Normalization strategy applied to each score list before combining
    pub normalization: NormalizationStrategy,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
vector_weight: 0.7,
keyword_weight: 0.3,
normalization: NormalizationStrategy::MinMax,
}
}
}
/// Score normalization strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationStrategy {
    /// Min-max normalization: (x - min) / (max - min); maps scores into [0, 1]
    MinMax,
    /// Z-score normalization: (x - mean) / std; can yield negative values
    ZScore,
    /// No normalization (raw scores combined as-is)
    None,
}
/// Simple BM25 implementation for keyword matching
#[derive(Debug, Clone)]
pub struct BM25 {
    /// IDF scores for terms (populated by `build_idf`)
    pub idf: HashMap<String, f32>,
    /// Average document length in tokens (0.0 until a document is indexed)
    pub avg_doc_len: f32,
    /// Document lengths (token counts per document)
    pub doc_lengths: HashMap<VectorId, usize>,
    /// Inverted index: term -> set of doc IDs containing that term
    pub inverted_index: HashMap<String, HashSet<VectorId>>,
    /// BM25 term-frequency saturation parameter (k1)
    pub k1: f32,
    /// BM25 document-length normalization parameter (b), expected in [0, 1]
    pub b: f32,
}
impl BM25 {
    /// Create a new BM25 instance with the given term-frequency saturation
    /// (`k1`) and length-normalization (`b`) parameters.
    pub fn new(k1: f32, b: f32) -> Self {
        Self {
            idf: HashMap::new(),
            avg_doc_len: 0.0,
            doc_lengths: HashMap::new(),
            inverted_index: HashMap::new(),
            k1,
            b,
        }
    }
    /// Index a document: record its token length and add its terms to the
    /// inverted index.
    ///
    /// NOTE(review): the average length is recomputed from scratch on every
    /// insert (O(total docs) per call); fine for moderate corpora, but an
    /// incremental running total would need an extra struct field.
    pub fn index_document(&mut self, doc_id: VectorId, text: &str) {
        let terms = tokenize(text);
        self.doc_lengths.insert(doc_id.clone(), terms.len());
        // Update inverted index
        for term in terms {
            self.inverted_index
                .entry(term)
                .or_default()
                .insert(doc_id.clone());
        }
        // Update average document length
        let total_len: usize = self.doc_lengths.values().sum();
        self.avg_doc_len = total_len as f32 / self.doc_lengths.len() as f32;
    }
    /// Build IDF scores after indexing all documents (smoothed IDF variant).
    pub fn build_idf(&mut self) {
        let num_docs = self.doc_lengths.len() as f32;
        for (term, doc_set) in &self.inverted_index {
            let doc_freq = doc_set.len() as f32;
            let idf = ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
            self.idf.insert(term.clone(), idf);
        }
    }
    /// Compute the BM25 score of `query` against a document.
    ///
    /// `doc_text` is re-tokenized on every call to obtain term frequencies;
    /// terms with no IDF entry contribute 0.
    pub fn score(&self, query: &str, doc_id: &VectorId, doc_text: &str) -> f32 {
        let query_terms = tokenize(query);
        let doc_terms = tokenize(doc_text);
        let doc_len = self.doc_lengths.get(doc_id).copied().unwrap_or(0) as f32;
        // Count term frequencies in document
        let mut term_freq: HashMap<String, f32> = HashMap::new();
        for term in doc_terms {
            *term_freq.entry(term).or_insert(0.0) += 1.0;
        }
        // Guard the length-normalization ratio: before any document is
        // indexed avg_doc_len is 0.0 and the division would produce NaN,
        // silently poisoning every score.
        let len_ratio = if self.avg_doc_len > 0.0 {
            doc_len / self.avg_doc_len
        } else {
            0.0
        };
        // Calculate BM25 score
        let mut score = 0.0;
        for term in query_terms {
            let idf = self.idf.get(&term).copied().unwrap_or(0.0);
            let tf = term_freq.get(&term).copied().unwrap_or(0.0);
            let numerator = tf * (self.k1 + 1.0);
            let denominator = tf + self.k1 * (1.0 - self.b + self.b * len_ratio);
            score += idf * (numerator / denominator);
        }
        score
    }
    /// Get all documents containing at least one query term.
    pub fn get_candidate_docs(&self, query: &str) -> HashSet<VectorId> {
        let query_terms = tokenize(query);
        let mut candidates = HashSet::new();
        for term in query_terms {
            if let Some(doc_set) = self.inverted_index.get(&term) {
                candidates.extend(doc_set.iter().cloned());
            }
        }
        candidates
    }
}
/// Hybrid search combining vector and keyword matching
#[derive(Debug, Clone)]
pub struct HybridSearch {
    /// Configuration (weights and normalization strategy)
    pub config: HybridConfig,
    /// BM25 index for keyword matching
    pub bm25: BM25,
    /// Document texts kept for BM25 scoring at query time
    pub doc_texts: HashMap<VectorId, String>,
}
impl HybridSearch {
    /// Create a new hybrid search instance with standard BM25 parameters
    /// (k1 = 1.5, b = 0.75).
    pub fn new(config: HybridConfig) -> Self {
        Self {
            config,
            bm25: BM25::new(1.5, 0.75),
            doc_texts: HashMap::new(),
        }
    }
    /// Index a document for keyword matching and retain its text for scoring.
    pub fn index_document(&mut self, doc_id: VectorId, text: String) {
        self.bm25.index_document(doc_id.clone(), &text);
        self.doc_texts.insert(doc_id, text);
    }
    /// Finalize indexing (build IDF scores). Call once after all documents
    /// are indexed and before searching.
    pub fn finalize_indexing(&mut self) {
        self.bm25.build_idf();
    }
    /// Perform hybrid search
    ///
    /// # Arguments
    /// * `query_vector` - Query vector for semantic search
    /// * `query_text` - Query text for keyword matching
    /// * `k` - Number of results to return
    /// * `vector_search_fn` - Function to perform vector similarity search
    ///
    /// # Returns
    /// Combined and reranked search results (top-k by combined score)
    pub fn search<F>(
        &self,
        query_vector: &[f32],
        query_text: &str,
        k: usize,
        vector_search_fn: F,
    ) -> Result<Vec<SearchResult>>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        // Over-fetch on the vector side so merging still yields k results.
        let vector_results = vector_search_fn(query_vector, k * 2)?;
        // BM25 scores for every document sharing at least one query term.
        let keyword_candidates = self.bm25.get_candidate_docs(query_text);
        let mut bm25_scores: HashMap<VectorId, f32> = HashMap::new();
        for doc_id in &keyword_candidates {
            if let Some(doc_text) = self.doc_texts.get(doc_id) {
                let score = self.bm25.score(query_text, doc_id, doc_text);
                bm25_scores.insert(doc_id.clone(), score);
            }
        }
        // Merge both result sets, keyed by document id.
        let mut combined_results: HashMap<VectorId, CombinedScore> = HashMap::new();
        // Add vector results
        for result in vector_results {
            combined_results.insert(
                result.id.clone(),
                CombinedScore {
                    id: result.id.clone(),
                    vector_score: Some(result.score),
                    keyword_score: bm25_scores.get(&result.id).copied(),
                    vector: result.vector,
                    metadata: result.metadata,
                },
            );
        }
        // Add keyword-only results
        for (doc_id, bm25_score) in bm25_scores {
            combined_results
                .entry(doc_id.clone())
                .or_insert(CombinedScore {
                    id: doc_id,
                    vector_score: None,
                    keyword_score: Some(bm25_score),
                    vector: None,
                    metadata: None,
                });
        }
        // Normalize per-modality scores and combine with configured weights.
        let mut sorted_results =
            self.normalize_and_combine(combined_results.into_values().collect())?;
        // total_cmp is a total order over f32, so a NaN score cannot panic
        // the sort (it orders after all real numbers).
        sorted_results.sort_by(|a, b| b.score.total_cmp(&a.score));
        // Return top-k
        Ok(sorted_results.into_iter().take(k).collect())
    }
    /// Normalize vector and keyword scores independently, then combine them
    /// as `vector_weight * v + keyword_weight * kw`; a missing score of
    /// either kind contributes 0.
    fn normalize_and_combine(&self, results: Vec<CombinedScore>) -> Result<Vec<SearchResult>> {
        let mut vector_scores: Vec<f32> = results.iter().filter_map(|r| r.vector_score).collect();
        let mut keyword_scores: Vec<f32> =
            results.iter().filter_map(|r| r.keyword_score).collect();
        // Normalize scores
        normalize_scores(&mut vector_scores, self.config.normalization);
        normalize_scores(&mut keyword_scores, self.config.normalization);
        // Map normalized scores back to their owning results. The normalized
        // vectors are *compacted* (one entry per present score), so walk them
        // with independent cursors: zipping `results` directly against the
        // compacted vectors mis-aligns as soon as any result lacks a score of
        // that kind (e.g. keyword-only hits), assigning scores to the wrong
        // documents.
        let mut vector_map: HashMap<VectorId, f32> = HashMap::new();
        let mut keyword_map: HashMap<VectorId, f32> = HashMap::new();
        let mut v_idx = 0;
        let mut k_idx = 0;
        for result in &results {
            if result.vector_score.is_some() {
                vector_map.insert(result.id.clone(), vector_scores[v_idx]);
                v_idx += 1;
            }
            if result.keyword_score.is_some() {
                keyword_map.insert(result.id.clone(), keyword_scores[k_idx]);
                k_idx += 1;
            }
        }
        // Combine scores
        let combined: Vec<SearchResult> = results
            .into_iter()
            .map(|result| {
                let vector_norm = vector_map.get(&result.id).copied().unwrap_or(0.0);
                let keyword_norm = keyword_map.get(&result.id).copied().unwrap_or(0.0);
                let combined_score = self.config.vector_weight * vector_norm
                    + self.config.keyword_weight * keyword_norm;
                SearchResult {
                    id: result.id,
                    score: combined_score,
                    vector: result.vector,
                    metadata: result.metadata,
                }
            })
            .collect();
        Ok(combined)
    }
}
/// Combined score holder: intermediate record pairing a document id with its
/// (optional) raw vector and keyword scores before normalization/weighting.
#[derive(Debug, Clone)]
struct CombinedScore {
    // Document / vector id
    id: VectorId,
    // Raw vector-similarity score, present only if the doc appeared in the
    // vector search results
    vector_score: Option<f32>,
    // Raw BM25 score, present only if the doc matched a query term
    keyword_score: Option<f32>,
    // Stored vector, when the vector search returned one
    vector: Option<Vec<f32>>,
    // Stored metadata, when the vector search returned it
    metadata: Option<HashMap<String, serde_json::Value>>,
}
// Helper functions
//
// Lowercase, whitespace-split tokenizer that strips surrounding punctuation
// and drops tokens shorter than 3 characters.
fn tokenize(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        // Strip punctuation BEFORE the length filter; the old order let
        // punctuation-padded short tokens such as "a!!" survive as "a",
        // defeating the minimum-length rule.
        .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|s| s.len() > 2) // also removes tokens trimmed to nothing
        .map(str::to_string)
        .collect()
}
// Helper: normalize a score slice in place according to the chosen strategy.
// An empty slice is left untouched; so is a constant slice (zero range/std).
fn normalize_scores(scores: &mut [f32], strategy: NormalizationStrategy) {
    if scores.is_empty() {
        return;
    }
    match strategy {
        NormalizationStrategy::MinMax => {
            // Rescale to [0, 1] using a single pass to find the extremes.
            let mut min = f32::INFINITY;
            let mut max = f32::NEG_INFINITY;
            for &s in scores.iter() {
                min = min.min(s);
                max = max.max(s);
            }
            let range = max - min;
            if range > 0.0 {
                scores.iter_mut().for_each(|s| *s = (*s - min) / range);
            }
        }
        NormalizationStrategy::ZScore => {
            // Center on the mean and scale by the population std deviation.
            let n = scores.len() as f32;
            let mean = scores.iter().sum::<f32>() / n;
            let variance = scores.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / n;
            let std = variance.sqrt();
            if std > 0.0 {
                scores.iter_mut().for_each(|s| *s = (*s - mean) / std);
            }
        }
        NormalizationStrategy::None => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Tokenizer keeps alphanumeric tokens of length > 2 after trimming.
    #[test]
    fn test_tokenize() {
        let text = "The quick brown fox jumps over the lazy dog!";
        let tokens = tokenize(text);
        assert!(tokens.contains(&"quick".to_string()));
        assert!(tokens.contains(&"brown".to_string()));
        assert!(tokens.contains(&"the".to_string())); // "the" is 3 chars, passes > 2 filter
        assert!(!tokens.contains(&"a".to_string())); // 1 char, too short
    }
    // Indexing populates doc lengths and the IDF table.
    #[test]
    fn test_bm25_indexing() {
        let mut bm25 = BM25::new(1.5, 0.75);
        bm25.index_document("doc1".to_string(), "rust programming language");
        bm25.index_document("doc2".to_string(), "python programming tutorial");
        bm25.build_idf();
        assert_eq!(bm25.doc_lengths.len(), 2);
        assert!(bm25.idf.contains_key("rust"));
        assert!(bm25.idf.contains_key("programming"));
    }
    // A document matching more query terms should score higher.
    #[test]
    fn test_bm25_scoring() {
        let mut bm25 = BM25::new(1.5, 0.75);
        bm25.index_document("doc1".to_string(), "rust programming language");
        bm25.index_document("doc2".to_string(), "python programming tutorial");
        bm25.index_document("doc3".to_string(), "rust systems programming");
        bm25.build_idf();
        let score1 = bm25.score(
            "rust programming",
            &"doc1".to_string(),
            "rust programming language",
        );
        let score2 = bm25.score(
            "rust programming",
            &"doc2".to_string(),
            "python programming tutorial",
        );
        // doc1 should score higher (contains both terms)
        assert!(score1 > score2);
    }
    // Both the BM25 index and the text store receive every document.
    #[test]
    fn test_hybrid_search_initialization() {
        let config = HybridConfig::default();
        let mut hybrid = HybridSearch::new(config);
        hybrid.index_document("doc1".to_string(), "rust vector database".to_string());
        hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
        hybrid.finalize_indexing();
        assert_eq!(hybrid.doc_texts.len(), 2);
        assert_eq!(hybrid.bm25.doc_lengths.len(), 2);
    }
    // Min-max maps the extremes to 0 and 1 and midpoints proportionally.
    #[test]
    fn test_normalize_minmax() {
        let mut scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        normalize_scores(&mut scores, NormalizationStrategy::MinMax);
        assert!((scores[0] - 0.0).abs() < 0.01);
        assert!((scores[4] - 1.0).abs() < 0.01);
        assert!((scores[2] - 0.5).abs() < 0.01);
    }
    // Candidates are any docs sharing at least one query term.
    #[test]
    fn test_bm25_candidate_retrieval() {
        let mut bm25 = BM25::new(1.5, 0.75);
        bm25.index_document("doc1".to_string(), "rust programming");
        bm25.index_document("doc2".to_string(), "python programming");
        bm25.index_document("doc3".to_string(), "java development");
        bm25.build_idf();
        let candidates = bm25.get_candidate_docs("rust programming");
        assert!(candidates.contains(&"doc1".to_string()));
        assert!(candidates.contains(&"doc2".to_string())); // Contains "programming"
        assert!(!candidates.contains(&"doc3".to_string()));
    }
}

View File

@@ -0,0 +1,336 @@
//! Maximal Marginal Relevance (MMR) for Diversity-Aware Search
//!
//! Implements MMR algorithm to balance relevance and diversity in search results:
//! MMR = λ × Similarity(query, doc) - (1-λ) × max Similarity(doc, selected_docs)
use crate::error::{Result, RuvectorError};
use crate::types::{DistanceMetric, SearchResult};
use serde::{Deserialize, Serialize};
/// Configuration for MMR search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MMRConfig {
    /// Lambda parameter: balance between relevance (1.0) and diversity (0.0)
    /// - λ = 1.0: Pure relevance (standard similarity search)
    /// - λ = 0.5: Equal balance
    /// - λ = 0.0: Pure diversity
    ///
    /// Validated to lie in [0, 1] by `MMRSearch::new`.
    pub lambda: f32,
    /// Distance metric for similarity computation
    pub metric: DistanceMetric,
    /// Fetch multiplier for initial candidates (fetch k * multiplier results
    /// before reranking, so MMR has room to diversify)
    pub fetch_multiplier: f32,
}
impl Default for MMRConfig {
fn default() -> Self {
Self {
lambda: 0.5,
metric: DistanceMetric::Cosine,
fetch_multiplier: 2.0,
}
}
}
/// MMR search implementation
///
/// Greedy reranker balancing query relevance against redundancy with
/// already-selected results.
#[derive(Debug, Clone)]
pub struct MMRSearch {
    /// Configuration (lambda, metric, over-fetch multiplier)
    pub config: MMRConfig,
}
impl MMRSearch {
    /// Create a new MMR search instance.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `lambda` lies outside [0, 1].
    pub fn new(config: MMRConfig) -> Result<Self> {
        if !(0.0..=1.0).contains(&config.lambda) {
            return Err(RuvectorError::InvalidParameter(format!(
                "Lambda must be in [0, 1], got {}",
                config.lambda
            )));
        }
        Ok(Self { config })
    }
    /// Perform MMR-based reranking of search results.
    ///
    /// Greedily selects, up to `k` times, the candidate maximizing
    /// `λ · relevance − (1−λ) · max-similarity-to-selected`.
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `candidates` - Initial search results (sorted by relevance)
    /// * `k` - Number of diverse results to return
    ///
    /// # Returns
    /// Reranked results optimizing for both relevance and diversity. When
    /// `k >= candidates.len()` the candidates are returned unchanged.
    ///
    /// # Errors
    /// Fails if any examined candidate has no stored vector.
    pub fn rerank(
        &self,
        query: &[f32],
        candidates: Vec<SearchResult>,
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        if candidates.is_empty() || k == 0 {
            return Ok(Vec::new());
        }
        if k >= candidates.len() {
            return Ok(candidates);
        }
        let mut selected: Vec<SearchResult> = Vec::with_capacity(k);
        let mut remaining = candidates;
        // Iteratively select documents maximizing MMR
        for _ in 0..k {
            if remaining.is_empty() {
                break;
            }
            // Compute MMR score for each remaining candidate
            let mut best_idx = 0;
            let mut best_mmr = f32::NEG_INFINITY;
            for (idx, candidate) in remaining.iter().enumerate() {
                let mmr_score = self.compute_mmr_score(query, candidate, &selected)?;
                if mmr_score > best_mmr {
                    best_mmr = mmr_score;
                    best_idx = idx;
                }
            }
            // Move best candidate to selected set
            let best = remaining.remove(best_idx);
            selected.push(best);
        }
        Ok(selected)
    }
    /// Compute the MMR score for a candidate.
    ///
    /// Relevance comes from the candidate's stored search score (the query
    /// vector itself is not re-compared, hence the unused `_query`); the
    /// diversity penalty is the maximum similarity to any selected document.
    fn compute_mmr_score(
        &self,
        _query: &[f32],
        candidate: &SearchResult,
        selected: &[SearchResult],
    ) -> Result<f32> {
        let candidate_vec = candidate.vector.as_ref().ok_or_else(|| {
            RuvectorError::InvalidParameter("Candidate vector not available".to_string())
        })?;
        // Relevance: similarity to query (convert distance to similarity)
        let relevance = self.distance_to_similarity(candidate.score);
        // Diversity: max similarity to already-selected documents. A manual
        // fold replaces max_by(partial_cmp().unwrap()), which panicked if a
        // distance ever came out NaN; an empty selected set yields 0.0.
        let max_similarity = selected
            .iter()
            .filter_map(|s| s.vector.as_ref())
            .map(|selected_vec| {
                let dist = compute_distance(candidate_vec, selected_vec, self.config.metric);
                self.distance_to_similarity(dist)
            })
            .fold(None::<f32>, |acc, sim| {
                Some(acc.map_or(sim, |best| best.max(sim)))
            })
            .unwrap_or(0.0);
        // MMR = λ × relevance - (1-λ) × max_similarity
        let mmr = self.config.lambda * relevance - (1.0 - self.config.lambda) * max_similarity;
        Ok(mmr)
    }
    /// Convert a distance to a similarity (higher is better).
    fn distance_to_similarity(&self, distance: f32) -> f32 {
        match self.config.metric {
            // Cosine distance here is defined as 1 - cosine similarity.
            DistanceMetric::Cosine => 1.0 - distance,
            DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
            DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
            // Dot-product "distance" is the negated dot product.
            DistanceMetric::DotProduct => -distance,
        }
    }
    /// Perform end-to-end MMR search
    ///
    /// # Arguments
    /// * `query` - Query vector
    /// * `k` - Number of diverse results to return
    /// * `search_fn` - Function to perform initial similarity search
    ///
    /// # Returns
    /// Diverse search results (at most `k`)
    pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
    where
        F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
    {
        // Fetch more candidates than needed so MMR has room to diversify
        let fetch_k = (k as f32 * self.config.fetch_multiplier).ceil() as usize;
        let candidates = search_fn(query, fetch_k)?;
        // Rerank using MMR
        self.rerank(query, candidates, k)
    }
}
// Helper function
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
match metric {
DistanceMetric::Euclidean => euclidean_distance(a, b),
DistanceMetric::Cosine => cosine_distance(a, b),
DistanceMetric::Manhattan => manhattan_distance(a, b),
DistanceMetric::DotProduct => dot_product_distance(a, b),
}
}
// Euclidean (L2) distance: sqrt of the sum of squared componentwise deltas.
// Extra components of the longer slice are ignored (zip stops at the shorter).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_sq.sqrt()
}
// Cosine distance = 1 - cosine similarity. A zero-norm input yields the
// maximum distance of 1.0 rather than dividing by zero.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        1.0
    } else {
        1.0 - dot / (norm_a * norm_b)
    }
}
// Manhattan (L1) distance: sum of absolute componentwise deltas.
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0;
    for (x, y) in a.iter().zip(b) {
        total += (x - y).abs();
    }
    total
}
// Negated dot product: higher similarity (larger dot) -> lower "distance",
// so this plugs into the same ascending-sort machinery as the real metrics.
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>()
}
#[cfg(test)]
mod tests {
    use super::*;
    // Convenience constructor for a SearchResult with an attached vector
    // (MMR requires candidate vectors to be present).
    fn create_search_result(id: &str, score: f32, vector: Vec<f32>) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vector),
            metadata: None,
        }
    }
    // Lambda outside [0, 1] must be rejected at construction time.
    #[test]
    fn test_mmr_config_validation() {
        let config = MMRConfig {
            lambda: 0.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(config).is_ok());
        let invalid_config = MMRConfig {
            lambda: 1.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(invalid_config).is_err());
    }
    // Balanced lambda: the most relevant doc is picked first, and at least
    // one dissimilar doc is promoted for diversity.
    #[test]
    fn test_mmr_reranking() {
        let config = MMRConfig {
            lambda: 0.5,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        // Create candidates with varying similarity
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]), // Very similar to query
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Similar to doc1 and query
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.5]), // Different from doc1
            create_search_result("doc4", 0.6, vec![0.0, 1.0, 0.0]), // Very different
        ];
        let results = mmr.rerank(&query, candidates, 3).unwrap();
        assert_eq!(results.len(), 3);
        // First result should be most relevant
        assert_eq!(results[0].id, "doc1");
        // MMR should promote diversity, so doc3 or doc4 should appear
        assert!(results.iter().any(|r| r.id == "doc3" || r.id == "doc4"));
    }
    // Lambda = 1.0 disables the diversity penalty entirely.
    #[test]
    fn test_mmr_pure_relevance() {
        let config = MMRConfig {
            lambda: 1.0, // Pure relevance
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.85, 0.1, 0.05]),
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.0]),
        ];
        let results = mmr.rerank(&query, candidates, 2).unwrap();
        // With lambda=1.0, should just preserve relevance order
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
    }
    // Lambda = 0.0 ignores relevance and avoids near-duplicate selections.
    #[test]
    fn test_mmr_pure_diversity() {
        let config = MMRConfig {
            lambda: 0.0, // Pure diversity
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Very similar to doc1
            create_search_result("doc3", 0.5, vec![0.0, 1.0, 0.0]), // Very different
        ];
        let results = mmr.rerank(&query, candidates, 2).unwrap();
        // With lambda=0.0, should maximize diversity
        assert_eq!(results.len(), 2);
        // Should not select both doc1 and doc2 (they're too similar)
        let has_both_similar =
            results.iter().any(|r| r.id == "doc1") && results.iter().any(|r| r.id == "doc2");
        assert!(!has_both_similar);
    }
    // An empty candidate set reranks to an empty result, without error.
    #[test]
    fn test_mmr_empty_candidates() {
        let config = MMRConfig::default();
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];
        let results = mmr.rerank(&query, Vec::new(), 5).unwrap();
        assert!(results.is_empty());
    }
}

View File

@@ -0,0 +1,549 @@
//! Enhanced Product Quantization with Precomputed Lookup Tables
//!
//! Provides 8-16x compression with 90-95% recall through:
//! - K-means clustering for codebook training
//! - Precomputed lookup tables for fast distance calculation
//! - Asymmetric distance computation (ADC)
use crate::error::{Result, RuvectorError};
use crate::types::DistanceMetric;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Configuration for Enhanced Product Quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
    /// Number of subspaces to split vectors into (must evenly divide the
    /// vector dimensionality; 1 byte of code per subspace)
    pub num_subspaces: usize,
    /// Codebook size per subspace (typically 256; must fit in u8 codes,
    /// i.e. at most 256 -- enforced by `validate`)
    pub codebook_size: usize,
    /// Number of k-means iterations for codebook training
    pub num_iterations: usize,
    /// Distance metric for codebook training
    pub metric: DistanceMetric,
}
impl Default for PQConfig {
fn default() -> Self {
Self {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 20,
metric: DistanceMetric::Euclidean,
}
}
}
impl PQConfig {
    /// Validate the configuration.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when the codebook size cannot be indexed by
    /// a `u8` code (must be 1..=256, since 256 centroids use indices 0-255)
    /// or when there are no subspaces.
    pub fn validate(&self) -> Result<()> {
        if self.codebook_size > 256 {
            return Err(RuvectorError::InvalidParameter(format!(
                "Codebook size {} exceeds u8 maximum of 256",
                self.codebook_size
            )));
        }
        // A zero-sized codebook previously passed validation and only failed
        // later inside k-means / encoding with no usable centroid.
        if self.codebook_size == 0 {
            return Err(RuvectorError::InvalidParameter(
                "Codebook size must be greater than 0".to_string(),
            ));
        }
        if self.num_subspaces == 0 {
            return Err(RuvectorError::InvalidParameter(
                "Number of subspaces must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
/// Precomputed lookup table for fast distance computation (ADC): built once
/// per query, then each encoded vector's distance is a table-sum over codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LookupTable {
    /// Table: [subspace][centroid] -> distance to the query subvector
    pub tables: Vec<Vec<f32>>,
}
impl LookupTable {
    /// Build the per-subspace distance tables for `query` against the
    /// trained codebooks.
    ///
    /// `query.len()` is assumed divisible by `codebooks.len()` (enforced by
    /// `EnhancedPQ::new` when reached through the PQ API).
    pub fn new(query: &[f32], codebooks: &[Vec<Vec<f32>>], metric: DistanceMetric) -> Self {
        let num_subspaces = codebooks.len();
        if num_subspaces == 0 {
            // No trained codebooks: empty table (distance() over empty codes is 0).
            return Self { tables: Vec::new() };
        }
        // Loop-invariant: every subspace has the same dimensionality.
        let subspace_dim = query.len() / num_subspaces;
        let tables = codebooks
            .iter()
            .enumerate()
            .map(|(subspace_idx, codebook)| {
                let start = subspace_idx * subspace_dim;
                let query_subvector = &query[start..start + subspace_dim];
                // Distance from the query subvector to each centroid.
                codebook
                    .iter()
                    .map(|centroid| compute_distance(query_subvector, centroid, metric))
                    .collect()
            })
            .collect();
        Self { tables }
    }
    /// Compute the (asymmetric) distance to a quantized vector: the sum of
    /// the precomputed per-subspace centroid distances selected by `codes`.
    #[inline]
    pub fn distance(&self, codes: &[u8]) -> f32 {
        codes
            .iter()
            .enumerate()
            .map(|(subspace_idx, &code)| self.tables[subspace_idx][code as usize])
            .sum()
    }
}
/// Enhanced Product Quantization with lookup tables
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedPQ {
    /// Configuration (subspace count, codebook size, training settings)
    pub config: PQConfig,
    /// Trained codebooks: [subspace][centroid_id][dimensions]; empty until
    /// `train` has been called
    pub codebooks: Vec<Vec<Vec<f32>>>,
    /// Dimensionality of the original (uncompressed) vectors
    pub dimensions: usize,
    /// Quantized vector storage: id -> PQ codes (one u8 per subspace)
    pub quantized_vectors: HashMap<String, Vec<u8>>,
}
impl EnhancedPQ {
    /// Create a new Enhanced PQ instance.
    ///
    /// # Errors
    /// Fails when the config is invalid, `dimensions` is 0, or `dimensions`
    /// is not divisible by `num_subspaces` (subspaces must be equal-sized).
    pub fn new(dimensions: usize, config: PQConfig) -> Result<Self> {
        config.validate()?;
        if dimensions == 0 {
            return Err(RuvectorError::InvalidParameter(
                "Dimensions must be greater than 0".to_string(),
            ));
        }
        if dimensions % config.num_subspaces != 0 {
            return Err(RuvectorError::InvalidParameter(format!(
                "Dimensions {} must be divisible by num_subspaces {}",
                dimensions, config.num_subspaces
            )));
        }
        Ok(Self {
            config,
            codebooks: Vec::new(),
            dimensions,
            quantized_vectors: HashMap::new(),
        })
    }
    /// Train one codebook per subspace on `training_vectors` using k-means.
    ///
    /// # Errors
    /// Fails on an empty training set, zero-dimensional vectors, or any
    /// vector whose length differs from `self.dimensions`.
    pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Training set cannot be empty".to_string(),
            ));
        }
        if training_vectors[0].is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Training vectors cannot have zero dimensions".to_string(),
            ));
        }
        // Validate dimensions
        for vec in training_vectors {
            if vec.len() != self.dimensions {
                return Err(RuvectorError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: vec.len(),
                });
            }
        }
        let subspace_dim = self.dimensions / self.config.num_subspaces;
        let mut codebooks = Vec::with_capacity(self.config.num_subspaces);
        // Train a codebook for each subspace independently.
        for subspace_idx in 0..self.config.num_subspaces {
            let start = subspace_idx * subspace_dim;
            let end = start + subspace_dim;
            // Extract subspace vectors
            let subspace_vectors: Vec<Vec<f32>> = training_vectors
                .iter()
                .map(|v| v[start..end].to_vec())
                .collect();
            // Run k-means clustering
            let codebook = kmeans_clustering(
                &subspace_vectors,
                self.config.codebook_size,
                self.config.num_iterations,
                self.config.metric,
            )?;
            codebooks.push(codebook);
        }
        self.codebooks = codebooks;
        Ok(())
    }
    /// Encode a vector into PQ codes (one nearest-centroid index per subspace).
    ///
    /// # Errors
    /// Fails on a dimension mismatch or if `train` has not been called.
    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
        if vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: vector.len(),
            });
        }
        if self.codebooks.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Codebooks not trained yet".to_string(),
            ));
        }
        let subspace_dim = self.dimensions / self.config.num_subspaces;
        let mut codes = Vec::with_capacity(self.config.num_subspaces);
        for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
            let start = subspace_idx * subspace_dim;
            let end = start + subspace_dim;
            let subvector = &vector[start..end];
            // Find nearest centroid (quantization)
            let code = find_nearest_centroid(subvector, codebook, self.config.metric)?;
            codes.push(code);
        }
        Ok(codes)
    }
    /// Encode `vector` and store the codes under `id` (overwrites any
    /// previous entry for that id).
    pub fn add_quantized(&mut self, id: String, vector: &[f32]) -> Result<()> {
        let codes = self.encode(vector)?;
        self.quantized_vectors.insert(id, codes);
        Ok(())
    }
    /// Create a per-query lookup table for fast asymmetric distances.
    ///
    /// # Errors
    /// Fails on a dimension mismatch or if `train` has not been called.
    pub fn create_lookup_table(&self, query: &[f32]) -> Result<LookupTable> {
        if query.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: query.len(),
            });
        }
        if self.codebooks.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Codebooks not trained yet".to_string(),
            ));
        }
        Ok(LookupTable::new(query, &self.codebooks, self.config.metric))
    }
    /// Search for the `k` nearest stored vectors using ADC (Asymmetric
    /// Distance Computation) over the precomputed lookup table.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
        let lookup_table = self.create_lookup_table(query)?;
        // Compute distances using lookup table
        let mut distances: Vec<(String, f32)> = self
            .quantized_vectors
            .iter()
            .map(|(id, codes)| (id.clone(), lookup_table.distance(codes)))
            .collect();
        // Sort ascending; total_cmp is a total order over f32, so a NaN
        // distance cannot panic the sort (NaN orders after real numbers).
        distances.sort_by(|a, b| a.1.total_cmp(&b.1));
        // Return top-k
        Ok(distances.into_iter().take(k).collect())
    }
    /// Reconstruct an approximate vector by concatenating the centroids the
    /// codes point at.
    ///
    /// # Errors
    /// Fails if `codes.len()` differs from the configured subspace count.
    pub fn reconstruct(&self, codes: &[u8]) -> Result<Vec<f32>> {
        if codes.len() != self.config.num_subspaces {
            return Err(RuvectorError::InvalidParameter(format!(
                "Expected {} codes, got {}",
                self.config.num_subspaces,
                codes.len()
            )));
        }
        let mut result = Vec::with_capacity(self.dimensions);
        for (subspace_idx, &code) in codes.iter().enumerate() {
            let centroid = &self.codebooks[subspace_idx][code as usize];
            result.extend_from_slice(centroid);
        }
        Ok(result)
    }
    /// Ratio of original storage (4 bytes per f32 dimension) to compressed
    /// storage (1 byte per subspace).
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.dimensions * 4; // f32 = 4 bytes
        let compressed_bytes = self.config.num_subspaces; // 1 byte per subspace
        original_bytes as f32 / compressed_bytes as f32
    }
}
// Helper functions
/// Distance between `a` and `b` under `metric`, oriented so that smaller
/// is always better (dot product is negated for minimization).
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    // Inner product, computed lazily only for the metrics that need it.
    let dot = || -> f32 { a.iter().zip(b).map(|(x, y)| x * y).sum() };
    match metric {
        DistanceMetric::Euclidean => euclidean_squared(a, b).sqrt(),
        DistanceMetric::Cosine => {
            let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
            let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
            // A zero vector has undefined cosine similarity; treat it as
            // maximally distant (1.0) rather than dividing by zero.
            if norm_a == 0.0 || norm_b == 0.0 {
                1.0
            } else {
                1.0 - dot() / (norm_a * norm_b)
            }
        }
        DistanceMetric::DotProduct => -dot(),
        DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
    }
}
/// Sum of squared component differences (no square root).
///
/// Iteration stops at the shorter of the two slices.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (x, y)| acc + (x - y) * (x - y))
}
/// Index (as `u8`) of the codebook centroid nearest to `vector` under
/// `metric`.
///
/// Distances are computed once per centroid instead of twice per
/// comparison (the previous version re-evaluated both sides inside the
/// comparator), and `total_cmp` avoids the NaN panic of
/// `partial_cmp().unwrap()`.
///
/// # Errors
///
/// Returns `RuvectorError::Internal` when the codebook is empty.
fn find_nearest_centroid(
    vector: &[f32],
    codebook: &[Vec<f32>],
    metric: DistanceMetric,
) -> Result<u8> {
    codebook
        .iter()
        .map(|centroid| compute_distance(vector, centroid, metric))
        .enumerate()
        .min_by(|(_, d1), (_, d2)| d1.total_cmp(d2))
        .map(|(idx, _)| idx as u8)
        .ok_or_else(|| RuvectorError::Internal("Empty codebook".to_string()))
}
/// k-means clustering with k-means++ initialization followed by Lloyd's
/// algorithm.
///
/// Returns `k` centroids with the same dimensionality as the inputs.
///
/// # Errors
///
/// Fails on an empty input set, zero-dimensional vectors, `k` larger than
/// the number of vectors, or `k > 256` (codes must fit in a `u8`).
fn kmeans_clustering(
    vectors: &[Vec<f32>],
    k: usize,
    iterations: usize,
    metric: DistanceMetric,
) -> Result<Vec<Vec<f32>>> {
    use rand::seq::SliceRandom;
    use rand::thread_rng;
    if vectors.is_empty() {
        return Err(RuvectorError::InvalidParameter(
            "Cannot cluster empty vector set".to_string(),
        ));
    }
    if vectors[0].is_empty() {
        return Err(RuvectorError::InvalidParameter(
            "Cannot cluster vectors with zero dimensions".to_string(),
        ));
    }
    if k > vectors.len() {
        return Err(RuvectorError::InvalidParameter(format!(
            "k ({}) cannot be larger than number of vectors ({})",
            k,
            vectors.len()
        )));
    }
    if k > 256 {
        return Err(RuvectorError::InvalidParameter(format!(
            "k ({}) exceeds u8 maximum of 256 for codebook size",
            k
        )));
    }
    let mut rng = thread_rng();
    let dim = vectors[0].len();
    // --- k-means++ initialization ---
    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
    centroids.push(vectors.choose(&mut rng).unwrap().clone());
    while centroids.len() < k {
        // Distance from every point to its nearest existing centroid.
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .map(|c| compute_distance(v, c, metric))
                    .min_by(|a, b| a.partial_cmp(b).unwrap())
                    .unwrap_or(f32::MAX)
            })
            .collect();
        // Sample the next centroid with probability proportional to its
        // distance from the current centroid set.
        let total: f32 = distances.iter().sum();
        let before = centroids.len();
        let mut rand_val = rand::random::<f32>() * total;
        for (i, &dist) in distances.iter().enumerate() {
            rand_val -= dist;
            if rand_val <= 0.0 {
                centroids.push(vectors[i].clone());
                break;
            }
        }
        // BUGFIX: fall back to a uniform random pick ONLY when the weighted
        // draw selected nothing (e.g. all remaining distances are zero).
        // The previous condition `centroids.len() == centroids.len()` was a
        // tautology, so an extra random centroid was pushed even when the
        // weighted draw had already succeeded, degrading the k-means++
        // seeding.
        if centroids.len() == before {
            centroids.push(vectors.choose(&mut rng).unwrap().clone());
        }
    }
    // --- Lloyd's algorithm ---
    for _ in 0..iterations {
        let mut assignments: Vec<Vec<Vec<f32>>> = vec![Vec::new(); k];
        // Assignment step: each vector joins its nearest centroid's cluster.
        for vector in vectors {
            let nearest = centroids
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let dist_a = compute_distance(vector, a, metric);
                    let dist_b = compute_distance(vector, b, metric);
                    dist_a.partial_cmp(&dist_b).unwrap()
                })
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            assignments[nearest].push(vector.clone());
        }
        // Update step: move each centroid to the mean of its cluster;
        // empty clusters keep their previous centroid.
        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
            if !assigned.is_empty() {
                *centroid = vec![0.0; dim];
                for vector in assigned {
                    for (i, &v) in vector.iter().enumerate() {
                        centroid[i] += v;
                    }
                }
                let count = assigned.len() as f32;
                for v in centroid.iter_mut() {
                    *v /= count;
                }
            }
        }
    }
    Ok(centroids)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Defaults documented for the PQ config: 8 subspaces x 256 centroids.
    #[test]
    fn test_pq_config_default() {
        let config = PQConfig::default();
        assert_eq!(config.num_subspaces, 8);
        assert_eq!(config.codebook_size, 256);
    }
    // Construction should record dimensions and config verbatim.
    #[test]
    fn test_enhanced_pq_creation() {
        let config = PQConfig {
            num_subspaces: 4,
            codebook_size: 16,
            num_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };
        let pq = EnhancedPQ::new(128, config).unwrap();
        assert_eq!(pq.dimensions, 128);
        assert_eq!(pq.config.num_subspaces, 4);
    }
    // Train on 4-dim vectors split into two 2-dim subspaces, then verify
    // encoding produces exactly one code per subspace.
    #[test]
    fn test_pq_training_and_encoding() {
        let config = PQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut pq = EnhancedPQ::new(4, config).unwrap();
        // Generate training data
        let training_data = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0, 3.0, 4.0, 5.0],
            vec![3.0, 4.0, 5.0, 6.0],
            vec![4.0, 5.0, 6.0, 7.0],
            vec![5.0, 6.0, 7.0, 8.0],
        ];
        pq.train(&training_data).unwrap();
        assert_eq!(pq.codebooks.len(), 2);
        // Test encoding
        let vector = vec![2.5, 3.5, 4.5, 5.5];
        let codes = pq.encode(&vector).unwrap();
        assert_eq!(codes.len(), 2);
    }
    // The ADC lookup table should hold one sub-table per subspace, each
    // with one precomputed distance per centroid.
    #[test]
    fn test_lookup_table_creation() {
        let config = PQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut pq = EnhancedPQ::new(4, config).unwrap();
        let training_data = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0, 3.0, 4.0, 5.0],
            vec![3.0, 4.0, 5.0, 6.0],
            vec![4.0, 5.0, 6.0, 7.0],
        ];
        pq.train(&training_data).unwrap();
        let query = vec![2.5, 3.5, 4.5, 5.5];
        let lookup_table = pq.create_lookup_table(&query).unwrap();
        assert_eq!(lookup_table.tables.len(), 2);
        assert_eq!(lookup_table.tables[0].len(), 4);
    }
    // 128 f32 components (512 bytes) compressed to 8 u8 codes -> 64:1.
    #[test]
    fn test_compression_ratio() {
        let config = PQConfig {
            num_subspaces: 8,
            codebook_size: 256,
            num_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };
        let pq = EnhancedPQ::new(128, config).unwrap();
        let ratio = pq.compression_ratio();
        assert_eq!(ratio, 64.0); // 128 * 4 / 8 = 64
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,704 @@
//! Arena allocator for batch operations
//!
//! This module provides arena-based memory allocation to reduce allocation
//! overhead in hot paths and improve memory locality.
//!
//! ## Features (ADR-001)
//!
//! - **Cache-aligned allocations**: All allocations are aligned to cache line boundaries (64 bytes)
//! - **Bump allocation**: O(1) allocation with minimal overhead
//! - **Batch deallocation**: Free all allocations at once via `reset()`
//! - **Thread-local arenas**: Per-thread allocation without synchronization
use std::alloc::{alloc, dealloc, Layout};
use std::cell::RefCell;
use std::ptr;
/// Cache line size (typically 64 bytes on modern CPUs)
pub const CACHE_LINE_SIZE: usize = 64;
/// Arena allocator for temporary allocations
///
/// Use this for batch operations where many temporary allocations
/// are needed and can be freed all at once.
pub struct Arena {
    // Growable list of bump-allocated chunks; RefCell allows allocation
    // through a shared reference (&self).
    chunks: RefCell<Vec<Chunk>>,
    // Minimum size in bytes of each newly allocated chunk.
    chunk_size: usize,
}
// A single bump-allocated region owned by the arena.
struct Chunk {
    data: *mut u8,   // start of the chunk's backing allocation
    capacity: usize, // total bytes available in this chunk
    used: usize,     // bytes handed out so far (bump-pointer offset)
}
impl Arena {
    /// Create a new arena with the specified chunk size
    pub fn new(chunk_size: usize) -> Self {
        Self {
            chunks: RefCell::new(Vec::new()),
            chunk_size,
        }
    }
    /// Create an arena with a default 1MB chunk size
    pub fn with_default_chunk_size() -> Self {
        Self::new(1024 * 1024) // 1MB
    }
    /// Allocate an uninitialized buffer with room for `count` values of `T`.
    ///
    /// The returned `ArenaVec` points into arena memory; it is invalidated
    /// by `reset()` and by dropping the arena.
    pub fn alloc_vec<T>(&self, count: usize) -> ArenaVec<T> {
        let size = count * std::mem::size_of::<T>();
        let align = std::mem::align_of::<T>();
        let ptr = self.alloc_raw(size, align);
        ArenaVec {
            ptr: ptr as *mut T,
            len: 0,
            capacity: count,
            _phantom: std::marker::PhantomData,
        }
    }
    /// Bump-allocate `size` bytes with the given alignment, growing the
    /// arena with a fresh chunk when the current one is exhausted.
    fn alloc_raw(&self, size: usize, align: usize) -> *mut u8 {
        // SECURITY: Validate alignment is a power of 2 and size is reasonable
        assert!(
            align > 0 && align.is_power_of_two(),
            "Alignment must be a power of 2"
        );
        // Chunk backing stores are allocated with 64-byte alignment, so
        // larger alignments cannot be honored; reject them up front rather
        // than silently handing out a misaligned pointer.
        assert!(align <= 64, "Alignment above 64 bytes is not supported");
        assert!(size > 0, "Cannot allocate zero bytes");
        assert!(size <= isize::MAX as usize, "Allocation size too large");
        let mut chunks = self.chunks.borrow_mut();
        // Fast path: bump-allocate out of the most recent chunk.
        if let Some(chunk) = chunks.last_mut() {
            // Round the bump pointer up to the requested alignment.
            let current = chunk.used;
            let aligned = (current + align - 1) & !(align - 1);
            // SECURITY: Check for overflow in alignment calculation
            if aligned < current {
                panic!("Alignment calculation overflow");
            }
            let needed = aligned
                .checked_add(size)
                .expect("Arena allocation size overflow");
            if needed <= chunk.capacity {
                chunk.used = needed;
                return unsafe {
                    // SECURITY: Verify pointer arithmetic doesn't overflow
                    let ptr = chunk.data.add(aligned);
                    debug_assert!(ptr as usize >= chunk.data as usize, "Pointer underflow");
                    ptr
                };
            }
        }
        // Slow path: the request does not fit, so allocate a new chunk large
        // enough for at least this request plus alignment slack.
        let chunk_size = self
            .chunk_size
            .max(size.checked_add(align).expect("Arena chunk size overflow"));
        let layout = Layout::from_size_align(chunk_size, 64).unwrap();
        let data = unsafe { alloc(layout) };
        // BUGFIX: `alloc` returns null on allocation failure; offsetting or
        // writing through a null pointer would be undefined behavior.
        if data.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        // Start the allocation at offset `align`: the chunk base is 64-byte
        // aligned, so `base + align` remains `align`-aligned for any
        // power-of-two alignment <= 64 (enforced above).
        let aligned = align;
        let chunk = Chunk {
            data,
            capacity: chunk_size,
            used: aligned + size,
        };
        let ptr = unsafe { data.add(aligned) };
        chunks.push(chunk);
        ptr
    }
    /// Reset the arena, allowing reuse of allocated memory
    ///
    /// Existing `ArenaVec`s must not be used after a reset.
    pub fn reset(&self) {
        let mut chunks = self.chunks.borrow_mut();
        for chunk in chunks.iter_mut() {
            chunk.used = 0;
        }
    }
    /// Get total allocated bytes
    pub fn allocated_bytes(&self) -> usize {
        let chunks = self.chunks.borrow();
        chunks.iter().map(|c| c.capacity).sum()
    }
    /// Get used bytes
    pub fn used_bytes(&self) -> usize {
        let chunks = self.chunks.borrow();
        chunks.iter().map(|c| c.used).sum()
    }
}
impl Drop for Arena {
    /// Deallocate every chunk using the same size/alignment layout it was
    /// allocated with (all chunks use 64-byte alignment).
    fn drop(&mut self) {
        let chunks = self.chunks.borrow();
        for chunk in chunks.iter() {
            let layout = Layout::from_size_align(chunk.capacity, 64).unwrap();
            unsafe {
                dealloc(chunk.data, layout);
            }
        }
    }
}
/// Vector allocated from an arena
///
/// Fixed-capacity: `push` panics once `capacity` elements are stored.
/// The backing memory lives in the arena, not in this struct.
pub struct ArenaVec<T> {
    ptr: *mut T,     // arena-owned storage for up to `capacity` elements
    len: usize,      // number of initialized elements
    capacity: usize, // maximum elements this buffer can hold
    _phantom: std::marker::PhantomData<T>, // marks logical ownership of T values
}
impl<T> ArenaVec<T> {
    /// Append `value` in place.
    ///
    /// Panics when the fixed capacity is exhausted or when the backing
    /// pointer is null.
    pub fn push(&mut self, value: T) {
        // SECURITY: validate state before any pointer arithmetic.
        assert!(self.len < self.capacity, "ArenaVec capacity exceeded");
        assert!(!self.ptr.is_null(), "ArenaVec pointer is null");
        unsafe {
            let slot = self.ptr.add(self.len);
            debug_assert!(
                slot as usize >= self.ptr as usize,
                "Pointer arithmetic overflow"
            );
            ptr::write(slot, value);
        }
        self.len += 1;
    }
    /// Number of initialized elements.
    pub fn len(&self) -> usize {
        self.len
    }
    /// True when no elements have been pushed.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Maximum number of elements this buffer can hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
    /// View the initialized prefix as a shared slice.
    pub fn as_slice(&self) -> &[T] {
        // SECURITY: validate invariants before constructing the slice.
        assert!(self.len <= self.capacity, "Length exceeds capacity");
        assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
    /// View the initialized prefix as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SECURITY: validate invariants before constructing the slice.
        assert!(self.len <= self.capacity, "Length exceeds capacity");
        assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}
// Allow an ArenaVec to be used anywhere a slice is expected.
impl<T> std::ops::Deref for ArenaVec<T> {
    type Target = [T];
    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}
impl<T> std::ops::DerefMut for ArenaVec<T> {
    fn deref_mut(&mut self) -> &mut [T] {
        self.as_mut_slice()
    }
}
// Thread-local arena for per-thread allocations
// (one arena per OS thread, so no synchronization is required).
thread_local! {
    static THREAD_ARENA: RefCell<Arena> = RefCell::new(Arena::with_default_chunk_size());
}
// Get the thread-local arena
// Note: Commented out due to lifetime issues with RefCell::borrow() escaping closure
// Use THREAD_ARENA.with(|arena| { ... }) directly instead
/*
pub fn thread_arena() -> impl std::ops::Deref<Target = Arena> {
    THREAD_ARENA.with(|arena| {
        arena.borrow()
    })
}
*/
/// Cache-aligned vector storage for SIMD operations (ADR-001)
///
/// Ensures vectors are aligned to cache line boundaries (64 bytes) for
/// optimal SIMD operations and minimal cache misses.
#[repr(C, align(64))]
pub struct CacheAlignedVec {
    data: *mut f32,  // 64-byte-aligned buffer; null when capacity == 0
    len: usize,      // number of initialized elements
    capacity: usize, // allocated element capacity (fixed after creation)
}
impl CacheAlignedVec {
    /// Create a new cache-aligned vector with the given capacity
    ///
    /// # Panics
    ///
    /// Panics if memory allocation fails. For fallible allocation,
    /// use `try_with_capacity`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::try_with_capacity(capacity).expect("Failed to allocate cache-aligned memory")
    }
    /// Try to create a new cache-aligned vector with the given capacity
    ///
    /// Returns `None` if memory allocation fails or the requested byte
    /// size would overflow `usize`.
    pub fn try_with_capacity(capacity: usize) -> Option<Self> {
        // Handle zero capacity case
        if capacity == 0 {
            return Some(Self {
                data: std::ptr::null_mut(),
                len: 0,
                capacity: 0,
            });
        }
        // BUGFIX: compute the byte length with checked arithmetic; the
        // previous unchecked `capacity * size_of::<f32>()` could wrap in
        // release builds and silently under-allocate.
        let byte_len = capacity.checked_mul(std::mem::size_of::<f32>())?;
        // Allocate cache-line aligned memory
        let layout = Layout::from_size_align(byte_len, CACHE_LINE_SIZE).ok()?;
        let data = unsafe { alloc(layout) as *mut f32 };
        // SECURITY: Check for allocation failure
        if data.is_null() {
            return None;
        }
        Some(Self {
            data,
            len: 0,
            capacity,
        })
    }
    /// Create from an existing slice, copying data to cache-aligned storage
    ///
    /// # Panics
    ///
    /// Panics if memory allocation fails. For fallible allocation,
    /// use `try_from_slice`.
    pub fn from_slice(slice: &[f32]) -> Self {
        Self::try_from_slice(slice).expect("Failed to allocate cache-aligned memory for slice")
    }
    /// Try to create from an existing slice, copying data to cache-aligned storage
    ///
    /// Returns `None` if memory allocation fails.
    pub fn try_from_slice(slice: &[f32]) -> Option<Self> {
        let mut vec = Self::try_with_capacity(slice.len())?;
        if !slice.is_empty() {
            unsafe {
                ptr::copy_nonoverlapping(slice.as_ptr(), vec.data, slice.len());
            }
        }
        vec.len = slice.len();
        Some(vec)
    }
    /// Push an element
    ///
    /// # Panics
    ///
    /// Panics if capacity is exceeded or if the vector has zero capacity.
    pub fn push(&mut self, value: f32) {
        assert!(
            self.len < self.capacity,
            "CacheAlignedVec capacity exceeded"
        );
        assert!(
            !self.data.is_null(),
            "Cannot push to zero-capacity CacheAlignedVec"
        );
        unsafe {
            *self.data.add(self.len) = value;
        }
        self.len += 1;
    }
    /// Get length
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }
    /// Check if empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Get capacity
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }
    /// Get as slice
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        if self.len == 0 {
            // SAFETY: Empty slice doesn't require valid pointer
            return &[];
        }
        // SAFETY: data is valid for len elements when len > 0
        unsafe { std::slice::from_raw_parts(self.data, self.len) }
    }
    /// Get as mutable slice
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        if self.len == 0 {
            // SAFETY: Empty slice doesn't require valid pointer
            return &mut [];
        }
        // SAFETY: data is valid for len elements when len > 0
        unsafe { std::slice::from_raw_parts_mut(self.data, self.len) }
    }
    /// Get raw pointer (for SIMD operations)
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data
    }
    /// Get mutable raw pointer (for SIMD operations)
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data
    }
    /// Check if properly aligned for SIMD
    ///
    /// Returns `true` for zero-capacity vectors (considered trivially aligned).
    #[inline]
    pub fn is_aligned(&self) -> bool {
        if self.data.is_null() {
            // Zero-capacity vectors are considered aligned
            return self.capacity == 0;
        }
        (self.data as usize) % CACHE_LINE_SIZE == 0
    }
    /// Clear the vector (sets len to 0, doesn't deallocate)
    pub fn clear(&mut self) {
        self.len = 0;
    }
}
impl Drop for CacheAlignedVec {
    fn drop(&mut self) {
        // Rebuild the exact layout used at allocation time; the null/zero
        // guard skips deallocation for zero-capacity instances.
        if !self.data.is_null() && self.capacity > 0 {
            let layout = Layout::from_size_align(
                self.capacity * std::mem::size_of::<f32>(),
                CACHE_LINE_SIZE,
            )
            .expect("Invalid layout");
            unsafe {
                dealloc(self.data as *mut u8, layout);
            }
        }
    }
}
// Allow CacheAlignedVec to be used anywhere a &[f32] is expected.
impl std::ops::Deref for CacheAlignedVec {
    type Target = [f32];
    fn deref(&self) -> &[f32] {
        self.as_slice()
    }
}
impl std::ops::DerefMut for CacheAlignedVec {
    fn deref_mut(&mut self) -> &mut [f32] {
        self.as_mut_slice()
    }
}
// Safety: The raw pointer is owned and not shared
unsafe impl Send for CacheAlignedVec {}
unsafe impl Sync for CacheAlignedVec {}
/// Batch vector allocator for processing multiple vectors (ADR-001)
///
/// Allocates contiguous, cache-aligned storage for a batch of vectors,
/// enabling efficient SIMD processing and minimal cache misses.
pub struct BatchVectorAllocator {
    data: *mut f32,   // contiguous buffer of capacity * dimensions floats
    dimensions: usize, // components per vector
    capacity: usize,   // maximum number of vectors (fixed after creation)
    count: usize,      // vectors currently stored
}
impl BatchVectorAllocator {
    /// Create allocator for vectors of given dimensions
    ///
    /// # Panics
    ///
    /// Panics if memory allocation fails. For fallible allocation,
    /// use `try_new`.
    pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
        Self::try_new(dimensions, initial_capacity)
            .expect("Failed to allocate batch vector storage")
    }
    /// Try to create allocator for vectors of given dimensions
    ///
    /// Returns `None` if memory allocation fails or the requested size
    /// would overflow `usize`.
    pub fn try_new(dimensions: usize, initial_capacity: usize) -> Option<Self> {
        // Handle zero capacity case
        if dimensions == 0 || initial_capacity == 0 {
            return Some(Self {
                data: std::ptr::null_mut(),
                dimensions,
                capacity: initial_capacity,
                count: 0,
            });
        }
        // BUGFIX: both multiplications are now checked; the previous
        // unchecked `dimensions * initial_capacity` could wrap and
        // silently under-allocate.
        let total_floats = dimensions.checked_mul(initial_capacity)?;
        let byte_len = total_floats.checked_mul(std::mem::size_of::<f32>())?;
        let layout = Layout::from_size_align(byte_len, CACHE_LINE_SIZE).ok()?;
        let data = unsafe { alloc(layout) as *mut f32 };
        // SECURITY: Check for allocation failure
        if data.is_null() {
            return None;
        }
        Some(Self {
            data,
            dimensions,
            capacity: initial_capacity,
            count: 0,
        })
    }
    /// Add a vector, returns its index
    ///
    /// # Panics
    ///
    /// Panics if the allocator is full, dimensions mismatch, or allocator has zero capacity.
    pub fn add(&mut self, vector: &[f32]) -> usize {
        assert_eq!(vector.len(), self.dimensions, "Vector dimension mismatch");
        assert!(self.count < self.capacity, "Batch allocator full");
        assert!(
            !self.data.is_null(),
            "Cannot add to zero-capacity BatchVectorAllocator"
        );
        let offset = self.count * self.dimensions;
        unsafe {
            ptr::copy_nonoverlapping(vector.as_ptr(), self.data.add(offset), self.dimensions);
        }
        let index = self.count;
        self.count += 1;
        index
    }
    /// Get a vector by index
    ///
    /// # Panics
    ///
    /// Panics if `index >= len()`.
    pub fn get(&self, index: usize) -> &[f32] {
        assert!(index < self.count, "Index out of bounds");
        let offset = index * self.dimensions;
        unsafe { std::slice::from_raw_parts(self.data.add(offset), self.dimensions) }
    }
    /// Get mutable vector by index
    ///
    /// # Panics
    ///
    /// Panics if `index >= len()`.
    pub fn get_mut(&mut self, index: usize) -> &mut [f32] {
        assert!(index < self.count, "Index out of bounds");
        let offset = index * self.dimensions;
        unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.dimensions) }
    }
    /// Get raw pointer to vector at index (for SIMD)
    #[inline]
    pub fn ptr_at(&self, index: usize) -> *const f32 {
        assert!(index < self.count, "Index out of bounds");
        let offset = index * self.dimensions;
        unsafe { self.data.add(offset) }
    }
    /// Number of vectors stored
    #[inline]
    pub fn len(&self) -> usize {
        self.count
    }
    /// Check if empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
    /// Dimensions per vector
    #[inline]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
    /// Reset allocator (keeps memory)
    pub fn clear(&mut self) {
        self.count = 0;
    }
}
impl Drop for BatchVectorAllocator {
    fn drop(&mut self) {
        // Zero-capacity allocators hold a null pointer and own no memory.
        if !self.data.is_null() {
            let layout = Layout::from_size_align(
                self.dimensions * self.capacity * std::mem::size_of::<f32>(),
                CACHE_LINE_SIZE,
            )
            .expect("Invalid layout");
            unsafe {
                dealloc(self.data as *mut u8, layout);
            }
        }
    }
}
// Safety: The raw pointer is owned and not shared
unsafe impl Send for BatchVectorAllocator {}
unsafe impl Sync for BatchVectorAllocator {}
#[cfg(test)]
mod tests {
    use super::*;
    // Basic bump allocation and element access through ArenaVec.
    #[test]
    fn test_arena_alloc() {
        let arena = Arena::new(1024);
        let mut vec1 = arena.alloc_vec::<f32>(10);
        vec1.push(1.0);
        vec1.push(2.0);
        vec1.push(3.0);
        assert_eq!(vec1.len(), 3);
        assert_eq!(vec1[0], 1.0);
        assert_eq!(vec1[1], 2.0);
        assert_eq!(vec1[2], 3.0);
    }
    // Multiple allocations of different element types from one arena.
    #[test]
    fn test_arena_multiple_allocs() {
        let arena = Arena::new(1024);
        let vec1 = arena.alloc_vec::<u32>(100);
        let vec2 = arena.alloc_vec::<u64>(50);
        let vec3 = arena.alloc_vec::<f32>(200);
        assert_eq!(vec1.capacity(), 100);
        assert_eq!(vec2.capacity(), 50);
        assert_eq!(vec3.capacity(), 200);
    }
    // reset() rewinds the bump pointers without freeing chunk memory.
    #[test]
    fn test_arena_reset() {
        let arena = Arena::new(1024);
        {
            let _vec1 = arena.alloc_vec::<f32>(100);
            let _vec2 = arena.alloc_vec::<f32>(100);
        }
        let used_before = arena.used_bytes();
        arena.reset();
        let used_after = arena.used_bytes();
        assert!(used_after < used_before);
    }
    #[test]
    fn test_cache_aligned_vec() {
        let mut vec = CacheAlignedVec::with_capacity(100);
        // Check alignment
        assert!(vec.is_aligned(), "Vector should be cache-aligned");
        // Test push
        for i in 0..50 {
            vec.push(i as f32);
        }
        assert_eq!(vec.len(), 50);
        // Test slice access
        let slice = vec.as_slice();
        assert_eq!(slice[0], 0.0);
        assert_eq!(slice[49], 49.0);
    }
    // Round-trip: copying a slice in must preserve content and alignment.
    #[test]
    fn test_cache_aligned_vec_from_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let aligned = CacheAlignedVec::from_slice(&data);
        assert!(aligned.is_aligned());
        assert_eq!(aligned.len(), 5);
        assert_eq!(aligned.as_slice(), &data[..]);
    }
    #[test]
    fn test_batch_vector_allocator() {
        let mut allocator = BatchVectorAllocator::new(4, 10);
        let v1 = vec![1.0, 2.0, 3.0, 4.0];
        let v2 = vec![5.0, 6.0, 7.0, 8.0];
        let idx1 = allocator.add(&v1);
        let idx2 = allocator.add(&v2);
        assert_eq!(idx1, 0);
        assert_eq!(idx2, 1);
        assert_eq!(allocator.len(), 2);
        // Test retrieval
        assert_eq!(allocator.get(0), &v1[..]);
        assert_eq!(allocator.get(1), &v2[..]);
    }
    // clear() resets the count but keeps the buffer usable.
    #[test]
    fn test_batch_allocator_clear() {
        let mut allocator = BatchVectorAllocator::new(3, 5);
        allocator.add(&[1.0, 2.0, 3.0]);
        allocator.add(&[4.0, 5.0, 6.0]);
        assert_eq!(allocator.len(), 2);
        allocator.clear();
        assert_eq!(allocator.len(), 0);
        // Should be able to add again
        allocator.add(&[7.0, 8.0, 9.0]);
        assert_eq!(allocator.len(), 1);
    }
}

View File

@@ -0,0 +1,436 @@
//! Cache-optimized data structures using Structure-of-Arrays (SoA) layout
//!
//! This module provides cache-friendly layouts for vector storage to minimize
//! cache misses and improve memory access patterns.
use std::alloc::{alloc, dealloc, Layout};
use std::ptr;
/// Cache line size (typically 64 bytes on modern CPUs)
const CACHE_LINE_SIZE: usize = 64;
/// Structure-of-Arrays layout for vectors
///
/// Instead of storing vectors as Vec<Vec<f32>>, we store all components
/// separately to improve cache locality during SIMD operations.
#[repr(align(64))] // Align to cache line boundary
pub struct SoAVectorStorage {
    /// Number of vectors
    count: usize,
    /// Dimensions per vector
    dimensions: usize,
    /// Capacity (allocated vectors)
    /// Always a power of two (rounded up by the constructor).
    capacity: usize,
    /// Storage for each dimension separately
    /// Layout: [dim0_vec0, dim0_vec1, ..., dim0_vecN, dim1_vec0, ...]
    /// One contiguous, cache-line-aligned buffer of dimensions * capacity floats.
    data: *mut f32,
}
impl SoAVectorStorage {
    /// Maximum allowed dimensions to prevent overflow
    const MAX_DIMENSIONS: usize = 65536;
    /// Maximum allowed capacity to prevent overflow
    const MAX_CAPACITY: usize = 1 << 24; // ~16M vectors
    /// Create a new SoA vector storage
    ///
    /// # Panics
    /// Panics if dimensions or capacity exceed safe limits, the size
    /// computation would overflow, or memory allocation fails.
    pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
        // Security: Validate inputs to prevent integer overflow
        assert!(
            dimensions > 0 && dimensions <= Self::MAX_DIMENSIONS,
            "dimensions must be between 1 and {}",
            Self::MAX_DIMENSIONS
        );
        assert!(
            initial_capacity <= Self::MAX_CAPACITY,
            "initial_capacity exceeds maximum of {}",
            Self::MAX_CAPACITY
        );
        let capacity = initial_capacity.next_power_of_two();
        // Security: Use checked arithmetic to prevent overflow
        let total_elements = dimensions
            .checked_mul(capacity)
            .expect("dimensions * capacity overflow");
        let total_bytes = total_elements
            .checked_mul(std::mem::size_of::<f32>())
            .expect("total size overflow");
        let layout =
            Layout::from_size_align(total_bytes, CACHE_LINE_SIZE).expect("invalid memory layout");
        let data = unsafe { alloc(layout) as *mut f32 };
        // BUGFIX: `alloc` returns null on failure; zero-initializing through
        // a null pointer would be undefined behavior.
        if data.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        // Zero initialize
        unsafe {
            ptr::write_bytes(data, 0, total_elements);
        }
        Self {
            count: 0,
            dimensions,
            capacity,
            data,
        }
    }
    /// Add a vector to the storage
    ///
    /// # Panics
    /// Panics if `vector.len() != self.dimensions`.
    pub fn push(&mut self, vector: &[f32]) {
        assert_eq!(vector.len(), self.dimensions);
        if self.count >= self.capacity {
            self.grow();
        }
        // Store each dimension separately (column-major per dimension).
        for (dim_idx, &value) in vector.iter().enumerate() {
            let offset = dim_idx * self.capacity + self.count;
            unsafe {
                *self.data.add(offset) = value;
            }
        }
        self.count += 1;
    }
    /// Get a vector by index (copies to output buffer)
    ///
    /// # Panics
    /// Panics on out-of-bounds index or mismatched output length.
    pub fn get(&self, index: usize, output: &mut [f32]) {
        assert!(index < self.count);
        assert_eq!(output.len(), self.dimensions);
        for (dim_idx, out) in output.iter_mut().enumerate().take(self.dimensions) {
            let offset = dim_idx * self.capacity + index;
            *out = unsafe { *self.data.add(offset) };
        }
    }
    /// Get a slice of a specific dimension across all vectors
    /// This allows efficient SIMD operations on a single dimension
    pub fn dimension_slice(&self, dim_idx: usize) -> &[f32] {
        assert!(dim_idx < self.dimensions);
        let offset = dim_idx * self.capacity;
        unsafe { std::slice::from_raw_parts(self.data.add(offset), self.count) }
    }
    /// Get a mutable slice of a specific dimension
    pub fn dimension_slice_mut(&mut self, dim_idx: usize) -> &mut [f32] {
        assert!(dim_idx < self.dimensions);
        let offset = dim_idx * self.capacity;
        unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.count) }
    }
    /// Number of vectors stored
    pub fn len(&self) -> usize {
        self.count
    }
    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
    /// Dimensions per vector
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
    /// Grow the storage capacity (doubles capacity, re-lays-out per dimension)
    fn grow(&mut self) {
        let new_capacity = self.capacity * 2;
        // Security: Use checked arithmetic to prevent overflow
        let new_total_elements = self
            .dimensions
            .checked_mul(new_capacity)
            .expect("dimensions * new_capacity overflow");
        let new_total_bytes = new_total_elements
            .checked_mul(std::mem::size_of::<f32>())
            .expect("total size overflow in grow");
        let new_layout = Layout::from_size_align(new_total_bytes, CACHE_LINE_SIZE)
            .expect("invalid memory layout in grow");
        let new_data = unsafe { alloc(new_layout) as *mut f32 };
        // BUGFIX: `alloc` returns null on failure; copying into a null
        // pointer would be undefined behavior.
        if new_data.is_null() {
            std::alloc::handle_alloc_error(new_layout);
        }
        // Copy old data dimension by dimension (the per-dimension stride
        // changes from self.capacity to new_capacity).
        for dim_idx in 0..self.dimensions {
            let old_offset = dim_idx * self.capacity;
            let new_offset = dim_idx * new_capacity;
            unsafe {
                ptr::copy_nonoverlapping(
                    self.data.add(old_offset),
                    new_data.add(new_offset),
                    self.count,
                );
            }
        }
        // Deallocate old data
        let old_layout = Layout::from_size_align(
            self.dimensions * self.capacity * std::mem::size_of::<f32>(),
            CACHE_LINE_SIZE,
        )
        .unwrap();
        unsafe {
            dealloc(self.data as *mut u8, old_layout);
        }
        self.data = new_data;
        self.capacity = new_capacity;
    }
    /// Compute distance from query to all stored vectors using dimension-wise operations
    /// This takes advantage of the SoA layout for better cache utilization
    #[inline(always)]
    pub fn batch_euclidean_distances(&self, query: &[f32], output: &mut [f32]) {
        assert_eq!(query.len(), self.dimensions);
        assert_eq!(output.len(), self.count);
        // Use SIMD-optimized version for larger batches
        #[cfg(target_arch = "aarch64")]
        {
            if self.count >= 16 {
                unsafe { self.batch_euclidean_distances_neon(query, output) };
                return;
            }
        }
        #[cfg(target_arch = "x86_64")]
        {
            if self.count >= 32 && is_x86_feature_detected!("avx2") {
                unsafe { self.batch_euclidean_distances_avx2(query, output) };
                return;
            }
        }
        // Scalar fallback
        self.batch_euclidean_distances_scalar(query, output);
    }
    /// Scalar implementation of batch euclidean distances
    #[inline(always)]
    fn batch_euclidean_distances_scalar(&self, query: &[f32], output: &mut [f32]) {
        // Initialize output with zeros
        output.fill(0.0);
        // Process dimension by dimension for cache-friendly access
        for dim_idx in 0..self.dimensions {
            let dim_slice = self.dimension_slice(dim_idx);
            // Safety: dim_idx is bounded by self.dimensions which is validated in constructor
            let query_val = unsafe { *query.get_unchecked(dim_idx) };
            // Compute squared differences for this dimension
            // Use unchecked access since vec_idx is bounded by self.count
            for vec_idx in 0..self.count {
                let diff = unsafe { *dim_slice.get_unchecked(vec_idx) } - query_val;
                unsafe { *output.get_unchecked_mut(vec_idx) += diff * diff };
            }
        }
        // Take square root
        for distance in output.iter_mut() {
            *distance = distance.sqrt();
        }
    }
    /// NEON-optimized batch euclidean distances
    ///
    /// # Safety
    /// Caller must ensure query.len() == self.dimensions and output.len() == self.count
    #[cfg(target_arch = "aarch64")]
    #[inline(always)]
    unsafe fn batch_euclidean_distances_neon(&self, query: &[f32], output: &mut [f32]) {
        use std::arch::aarch64::*;
        let out_ptr = output.as_mut_ptr();
        let query_ptr = query.as_ptr();
        // 4-lane chunks; the tail is handled with scalar code below.
        let chunks = self.count / 4;
        // Zero initialize using SIMD
        let zero = vdupq_n_f32(0.0);
        for i in 0..chunks {
            let idx = i * 4;
            vst1q_f32(out_ptr.add(idx), zero);
        }
        for i in (chunks * 4)..self.count {
            *output.get_unchecked_mut(i) = 0.0;
        }
        // Process dimension by dimension for cache-friendly access
        for dim_idx in 0..self.dimensions {
            let dim_slice = self.dimension_slice(dim_idx);
            let dim_ptr = dim_slice.as_ptr();
            let query_val = vdupq_n_f32(*query_ptr.add(dim_idx));
            // SIMD processing of 4 vectors at a time (fused multiply-add
            // accumulates the squared differences).
            for i in 0..chunks {
                let idx = i * 4;
                let dim_vals = vld1q_f32(dim_ptr.add(idx));
                let out_vals = vld1q_f32(out_ptr.add(idx));
                let diff = vsubq_f32(dim_vals, query_val);
                let result = vfmaq_f32(out_vals, diff, diff);
                vst1q_f32(out_ptr.add(idx), result);
            }
            // Handle remainder with bounds-check elimination
            let query_val_scalar = *query_ptr.add(dim_idx);
            for i in (chunks * 4)..self.count {
                let diff = *dim_slice.get_unchecked(i) - query_val_scalar;
                *output.get_unchecked_mut(i) += diff * diff;
            }
        }
        // Take square root using SIMD vsqrtq_f32
        for i in 0..chunks {
            let idx = i * 4;
            let vals = vld1q_f32(out_ptr.add(idx));
            let sqrt_vals = vsqrtq_f32(vals);
            vst1q_f32(out_ptr.add(idx), sqrt_vals);
        }
        for i in (chunks * 4)..self.count {
            *output.get_unchecked_mut(i) = output.get_unchecked(i).sqrt();
        }
    }
    /// AVX2-optimized batch euclidean distances
    ///
    /// # Safety
    /// Caller must ensure query.len() == self.dimensions, output.len() ==
    /// self.count, and that the CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_euclidean_distances_avx2(&self, query: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;
        // 8-lane chunks; the tail is handled with scalar code below.
        let chunks = self.count / 8;
        // Zero initialize using SIMD
        let zero = _mm256_setzero_ps();
        for i in 0..chunks {
            let idx = i * 8;
            _mm256_storeu_ps(output.as_mut_ptr().add(idx), zero);
        }
        for out in output.iter_mut().take(self.count).skip(chunks * 8) {
            *out = 0.0;
        }
        // Process dimension by dimension
        for (dim_idx, &q_val) in query.iter().enumerate().take(self.dimensions) {
            let dim_slice = self.dimension_slice(dim_idx);
            let query_val = _mm256_set1_ps(q_val);
            // SIMD processing of 8 vectors at a time
            for i in 0..chunks {
                let idx = i * 8;
                let dim_vals = _mm256_loadu_ps(dim_slice.as_ptr().add(idx));
                let out_vals = _mm256_loadu_ps(output.as_ptr().add(idx));
                let diff = _mm256_sub_ps(dim_vals, query_val);
                let sq = _mm256_mul_ps(diff, diff);
                let result = _mm256_add_ps(out_vals, sq);
                _mm256_storeu_ps(output.as_mut_ptr().add(idx), result);
            }
            // Handle remainder
            for i in (chunks * 8)..self.count {
                let diff = dim_slice[i] - query[dim_idx];
                output[i] += diff * diff;
            }
        }
        // Take square root (no SIMD sqrt in basic AVX2, use scalar)
        for distance in output.iter_mut() {
            *distance = distance.sqrt();
        }
    }
}
/// Runtime CPU-feature probe used when picking a SIMD kernel on x86_64.
/// Only AVX2 is dispatched on today; any other feature name reports `false`.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn is_x86_feature_detected_helper(feature: &str) -> bool {
    if feature == "avx2" {
        is_x86_feature_detected!("avx2")
    } else {
        false
    }
}
impl Drop for SoAVectorStorage {
    /// Releases the manually allocated, cache-line-aligned backing buffer.
    fn drop(&mut self) {
        // The layout must be byte-identical to the one used at allocation time
        // (`dimensions * capacity` f32s, CACHE_LINE_SIZE alignment) — assumed
        // to be what the constructor/grow path uses; dealloc with a mismatched
        // layout is undefined behavior. TODO(review): confirm `dimensions` and
        // `capacity` can never change without reallocating.
        let layout = Layout::from_size_align(
            self.dimensions * self.capacity * std::mem::size_of::<f32>(),
            CACHE_LINE_SIZE,
        )
        .unwrap();
        // SAFETY: `self.data` is presumed to have been allocated with exactly
        // this layout and is freed nowhere else; after drop the storage is
        // never touched again.
        unsafe {
            dealloc(self.data as *mut u8, layout);
        }
    }
}
// SAFETY(review): the raw `data` pointer suppresses the auto impls. These are
// sound only if the buffer is uniquely owned by this struct and all mutation
// goes through `&mut self` — that appears to hold for the methods visible in
// this file; confirm no API hands out aliasing raw pointers.
unsafe impl Send for SoAVectorStorage {}
unsafe impl Sync for SoAVectorStorage {}
#[cfg(test)]
mod tests {
    use super::*;

    /// push/get round-trips whole rows.
    #[test]
    fn test_soa_storage() {
        let mut storage = SoAVectorStorage::new(3, 4);
        storage.push(&[1.0, 2.0, 3.0]);
        storage.push(&[4.0, 5.0, 6.0]);
        assert_eq!(storage.len(), 2);
        let expected = [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let mut row = vec![0.0; 3];
        for (i, want) in expected.iter().enumerate() {
            storage.get(i, &mut row);
            assert_eq!(&row[..], &want[..]);
        }
    }

    /// `dimension_slice(d)` returns dimension d of every stored vector.
    #[test]
    fn test_dimension_slice() {
        let mut storage = SoAVectorStorage::new(3, 4);
        for v in [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
            storage.push(&v);
        }
        assert_eq!(storage.dimension_slice(0), &[1.0, 4.0, 7.0]);
        assert_eq!(storage.dimension_slice(1), &[2.0, 5.0, 8.0]);
    }

    /// Distances from e1 to the three basis vectors: 0, sqrt(2), sqrt(2).
    #[test]
    fn test_batch_distances() {
        let mut storage = SoAVectorStorage::new(3, 4);
        for v in [[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] {
            storage.push(&v);
        }
        let mut distances = vec![0.0; 3];
        storage.batch_euclidean_distances(&[1.0, 0.0, 0.0], &mut distances);
        let expected = [0.0f32, std::f32::consts::SQRT_2, std::f32::consts::SQRT_2];
        for (got, want) in distances.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.01);
        }
    }
}

View File

@@ -0,0 +1,167 @@
//! SIMD-optimized distance metrics
//! Uses SimSIMD when available (native), falls back to pure Rust for WASM
use crate::error::{Result, RuvectorError};
use crate::types::DistanceMetric;
/// Calculate distance between two vectors using the specified metric.
///
/// Returns [`RuvectorError::DimensionMismatch`] when the slices differ in
/// length; otherwise dispatches to the metric-specific helper.
#[inline]
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> Result<f32> {
    if a.len() != b.len() {
        return Err(RuvectorError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }
    let value = match metric {
        DistanceMetric::Euclidean => euclidean_distance(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
        DistanceMetric::DotProduct => dot_product_distance(a, b),
        DistanceMetric::Manhattan => manhattan_distance(a, b),
    };
    Ok(value)
}
/// Euclidean (L2) distance.
///
/// SimSIMD on native builds with the `simd` feature; otherwise a plain Rust
/// loop (also the WASM path).
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        (simsimd::SpatialSimilarity::sqeuclidean(a, b)
            .expect("SimSIMD euclidean failed")
            .sqrt()) as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Pure Rust fallback: sum of squared component deltas, then sqrt.
        let sum_sq = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
            let d = x - y;
            acc + d * d
        });
        sum_sq.sqrt()
    }
}
/// Cosine distance (`1 - cosine_similarity`).
///
/// In the pure-Rust path, near-zero-norm inputs yield the maximum distance
/// of 1.0 rather than dividing by zero.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        simsimd::SpatialSimilarity::cosine(a, b).expect("SimSIMD cosine failed") as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Pure Rust fallback: one pass accumulates the dot product and both
        // squared norms.
        let (mut dot, mut aa, mut bb) = (0.0f32, 0.0f32, 0.0f32);
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            aa += x * x;
            bb += y * y;
        }
        let (norm_a, norm_b) = (aa.sqrt(), bb.sqrt());
        if norm_a > 1e-8 && norm_b > 1e-8 {
            1.0 - dot / (norm_a * norm_b)
        } else {
            1.0
        }
    }
}
/// Dot product distance: the inner product negated, so that larger dot
/// products (better matches) sort first as smaller "distances".
#[inline]
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        let dot = simsimd::SpatialSimilarity::dot(a, b).expect("SimSIMD dot product failed");
        (-dot) as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Pure Rust fallback: negate the accumulated inner product.
        -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>()
    }
}
/// Manhattan (L1) distance: sum of absolute component differences.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (x, y)| acc + (x - y).abs())
}
/// Batch distance calculation: Rayon-parallel on native builds with the
/// `parallel` feature, sequential otherwise (including WASM).
///
/// Short-circuits with the first error (e.g. a dimension mismatch).
pub fn batch_distances(
    query: &[f32],
    vectors: &[Vec<f32>],
    metric: DistanceMetric,
) -> Result<Vec<f32>> {
    #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
    {
        use rayon::prelude::*;
        vectors
            .par_iter()
            .map(|candidate| distance(query, candidate, metric))
            .collect()
    }
    #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
    {
        // Sequential fallback for WASM
        vectors
            .iter()
            .map(|candidate| distance(query, candidate, metric))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// sqrt(3^2 + 3^2 + 3^2) = sqrt(27) ~= 5.196
    #[test]
    fn test_euclidean_distance() {
        let d = euclidean_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d - 5.196).abs() < 0.01);
    }

    #[test]
    fn test_cosine_distance() {
        // Identical direction => distance ~0.
        let same = cosine_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(
            same < 0.01,
            "Identical vectors should have ~0 distance, got {}",
            same
        );
        // Opposite direction => distance near 2, well above 1.5.
        let opposite = cosine_distance(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!(
            opposite > 1.5,
            "Opposite vectors should have high distance, got {}",
            opposite
        );
    }

    /// -(1*4 + 2*5 + 3*6) = -32
    #[test]
    fn test_dot_product_distance() {
        let d = dot_product_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d + 32.0).abs() < 0.01);
    }

    /// |1-4| + |2-5| + |3-6| = 9
    #[test]
    fn test_manhattan_distance() {
        let d = manhattan_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d - 9.0).abs() < 0.01);
    }

    /// Mismatched lengths must surface as an error, not a panic.
    #[test]
    fn test_dimension_mismatch() {
        let result = distance(&[1.0, 2.0], &[1.0, 2.0, 3.0], DistanceMetric::Euclidean);
        assert!(result.is_err());
    }
}

View File

@@ -0,0 +1,415 @@
//! Text Embedding Providers
//!
//! This module provides a pluggable embedding system for AgenticDB.
//!
//! ## Available Providers
//!
//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
//!
//! ## Usage
//!
//! ```rust,no_run
//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
//! use ruvector_core::AgenticDB;
//!
//! // Default: Hash-based (fast, but not semantic)
//! let hash_provider = HashEmbedding::new(384);
//! let embedding = hash_provider.embed("hello world")?;
//!
//! // API-based (requires API key)
//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
//! let embedding = api_provider.embed("hello world")?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
use crate::error::Result;
#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
use crate::error::RuvectorError;
use std::sync::Arc;
/// Trait for text embedding providers.
///
/// Implementations must be `Send + Sync` so a single provider can be shared
/// across threads behind an `Arc` (see `BoxedEmbeddingProvider`).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding vector for the given text.
    ///
    /// The returned vector is expected to have `dimensions()` elements.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
    /// Get the dimensionality of embeddings produced by this provider
    fn dimensions(&self) -> usize;
    /// Get a description of this provider (for logging/debugging)
    fn name(&self) -> &str;
}
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    /// Length of every vector produced by `embed`.
    dimensions: usize,
}
impl HashEmbedding {
/// Create a new hash-based embedding provider
pub fn new(dimensions: usize) -> Self {
Self { dimensions }
}
}
impl EmbeddingProvider for HashEmbedding {
    /// Folds the text's bytes into the vector (byte value scaled to [0, 1],
    /// bucketed by byte position modulo `dimensions`), then L2-normalizes.
    ///
    /// Not semantic: any permutation of the same bytes maps to the same
    /// embedding.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Guard the degenerate configuration up front: with zero dimensions
        // the `i % self.dimensions` below would panic (remainder by zero).
        if self.dimensions == 0 {
            return Ok(Vec::new());
        }
        let mut embedding = vec![0.0; self.dimensions];
        let bytes = text.as_bytes();
        for (i, byte) in bytes.iter().enumerate() {
            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
        }
        // Normalize to unit L2 norm so downstream cosine/dot comparisons are
        // well-behaved; the all-zero vector (empty text) is left untouched.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }
        Ok(embedding)
    }
    fn dimensions(&self) -> usize {
        self.dimensions
    }
    fn name(&self) -> &str {
        "HashEmbedding (placeholder)"
    }
}
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;
    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    pub struct CandleEmbedding {
        /// NOTE(review): never populated — `from_pretrained` always errors
        /// before constructing a value, so both fields are currently unused.
        dimensions: usize,
        /// Identifier of the model this instance would wrap.
        model_id: String,
    }
    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!**
        /// For production, implement with actual model loading.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            // Unconditional error: this stub exists only to document the
            // intended structure and point users at working alternatives.
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }
    impl EmbeddingProvider for CandleEmbedding {
        /// Always errors — the stub cannot run inference.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }
        fn dimensions(&self) -> usize {
            self.dimensions
        }
        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Full URL of the embeddings endpoint.
    endpoint: String,
    /// Model identifier included in every request body.
    model: String,
    /// Embedding width this endpoint/model is expected to return.
    dimensions: usize,
    /// Blocking HTTP client reused across requests.
    client: reqwest::blocking::Client,
}
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Build a provider against an arbitrary embeddings endpoint.
    ///
    /// # Arguments
    /// * `api_key` - bearer token used for authentication
    /// * `endpoint` - full URL of the embeddings API
    /// * `model` - model identifier sent with each request
    /// * `dimensions` - embedding width the endpoint is expected to return
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }
    /// OpenAI embeddings provider.
    ///
    /// Known models:
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        // Only the "3-large" model differs in width; everything else is 1536.
        let dimensions = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };
        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
    /// Cohere embeddings provider (both v3.0 models are 1024-dimensional).
    ///
    /// Known models:
    /// - `embed-english-v3.0`
    /// - `embed-multilingual-v3.0`
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            1024,
        )
    }
    /// Voyage AI embeddings provider.
    ///
    /// Known models:
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        // "large" models are 1536-wide, the rest 1024.
        let dimensions = if model.contains("large") { 1536 } else { 1024 };
        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POSTs `text` to the configured endpoint and extracts the embedding.
    ///
    /// Blocking call (`reqwest::blocking`); do not invoke from an async
    /// context. Two response shapes are handled, checked in this order:
    /// OpenAI-style `{"data":[{"embedding":[..]}]}` first, then Cohere-style
    /// `{"embeddings":[[..]]}`. Only the first embedding in the response is
    /// used.
    ///
    /// NOTE(review): the returned vector's length is not validated against
    /// `self.dimensions` — confirm callers tolerate a mismatch.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });
        // Bearer-token auth, JSON in / JSON out.
        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;
        // Surface non-2xx responses with the server's error body attached.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }
        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;
        // Handle different API response formats
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };
        // Convert JSON numbers to f32, failing on any non-numeric element.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();
        embedding_vec
    }
    fn dimensions(&self) -> usize {
        self.dimensions
    }
    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
/// Type-erased embedding provider for dynamic dispatch.
///
/// `Arc` rather than `Box` so one provider can be stored by multiple owners
/// and shared across threads.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Same text must produce the same, unit-norm embedding.
    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);
        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();
        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");
        // Check normalization
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }
    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);
        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();
        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }
    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore] // Requires model download
    fn test_candle_embedding() {
        let provider =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
                .unwrap();
        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);
        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }
    // BUG FIX: this test references `ApiEmbedding`, which only exists behind
    // the `api-embeddings` feature. Without the matching cfg attribute the
    // whole test build failed whenever that feature was disabled.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");
        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}

View File

@@ -0,0 +1,113 @@
//! Error types for Ruvector
use thiserror::Error;
/// Result type alias for Ruvector operations
pub type Result<T> = std::result::Result<T, RuvectorError>;
/// Main error type for Ruvector
///
/// Most variants wrap a human-readable message string rather than a
/// structured payload; `IoError` is the exception and converts automatically
/// via `#[from]`.
// NOTE(review): consider `#[non_exhaustive]` so adding a variant is not a
// breaking change for downstream exhaustive matches.
#[derive(Error, Debug)]
pub enum RuvectorError {
    /// Vector dimension mismatch
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },
    /// Vector not found
    #[error("Vector not found: {0}")]
    VectorNotFound(String),
    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// Invalid input
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Invalid dimension
    #[error("Invalid dimension: {0}")]
    InvalidDimension(String),
    /// Storage error
    #[error("Storage error: {0}")]
    StorageError(String),
    /// Model loading error
    #[error("Model loading error: {0}")]
    ModelLoadError(String),
    /// Model inference error
    #[error("Model inference error: {0}")]
    ModelInferenceError(String),
    /// Index error
    #[error("Index error: {0}")]
    IndexError(String),
    /// Serialization error
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// IO error (auto-converted from `std::io::Error`)
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// Database error
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// Invalid path error
    #[error("Invalid path: {0}")]
    InvalidPath(String),
    /// Other errors
    #[error("Internal error: {0}")]
    Internal(String),
}
// Every redb error flavor collapses into `RuvectorError::DatabaseError` by
// stringifying; a macro keeps the six identical conversions in one place.
#[cfg(feature = "storage")]
macro_rules! redb_error_to_database_error {
    ($($source:ty),+ $(,)?) => {
        $(
            impl From<$source> for RuvectorError {
                fn from(err: $source) -> Self {
                    RuvectorError::DatabaseError(err.to_string())
                }
            }
        )+
    };
}
#[cfg(feature = "storage")]
redb_error_to_database_error!(
    redb::Error,
    redb::DatabaseError,
    redb::StorageError,
    redb::TableError,
    redb::TransactionError,
    redb::CommitError,
);

View File

@@ -0,0 +1,36 @@
//! Index structures for efficient vector search
pub mod flat;
#[cfg(feature = "hnsw")]
pub mod hnsw;
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
/// Trait for vector index implementations.
pub trait VectorIndex: Send + Sync {
    /// Add a vector to the index.
    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
    /// Add multiple vectors in batch.
    ///
    /// Default implementation inserts one-by-one and stops at the first
    /// failure; indexes with a cheaper bulk path should override it.
    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
        entries
            .into_iter()
            .try_for_each(|(id, vector)| self.add(id, vector))
    }
    /// Search for k nearest neighbors.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
    /// Remove a vector from the index; returns whether it was present.
    fn remove(&mut self, id: &VectorId) -> Result<bool>;
    /// Get the number of vectors in the index.
    fn len(&self) -> usize;
    /// Check if the index is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

View File

@@ -0,0 +1,108 @@
//! Flat (brute-force) index for baseline and small datasets
use crate::distance::distance;
use crate::error::Result;
use crate::index::VectorIndex;
use crate::types::{DistanceMetric, SearchResult, VectorId};
use dashmap::DashMap;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
use rayon::prelude::*;
/// Flat index using brute-force search
pub struct FlatIndex {
    /// id -> vector storage; `DashMap` allows concurrent reads during search.
    vectors: DashMap<VectorId, Vec<f32>>,
    /// Metric applied between the query and every stored vector.
    metric: DistanceMetric,
    /// Declared dimensionality; not read in this file (inserts are not
    /// validated against it) — kept, presumably, for future validation.
    _dimensions: usize,
}
impl FlatIndex {
/// Create a new flat index
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
Self {
vectors: DashMap::new(),
metric,
_dimensions: dimensions,
}
}
}
impl VectorIndex for FlatIndex {
    /// Insert (or overwrite) a vector under `id`.
    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
        self.vectors.insert(id, vector);
        Ok(())
    }
    /// Brute-force k-NN: score every stored vector, sort by distance, keep k.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        // Distance calculation - parallel on native, sequential on WASM
        #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
        let mut results: Vec<_> = self
            .vectors
            .iter()
            .par_bridge()
            .map(|entry| {
                let id = entry.key().clone();
                let vector = entry.value();
                let dist = distance(query, vector, self.metric)?;
                Ok((id, dist))
            })
            .collect::<Result<Vec<_>>>()?;
        #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
        let mut results: Vec<_> = self
            .vectors
            .iter()
            .map(|entry| {
                let id = entry.key().clone();
                let vector = entry.value();
                let dist = distance(query, vector, self.metric)?;
                Ok((id, dist))
            })
            .collect::<Result<Vec<_>>>()?;
        // Sort ascending by distance and take top k.
        // FIX: `total_cmp` instead of `partial_cmp(..).unwrap()` — a NaN
        // distance (pathological input values) must not panic the search; NaN
        // sorts last under IEEE 754 total ordering. `sort_unstable_by` avoids
        // the stable sort's allocation; ties may reorder but scores are equal.
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
        Ok(results
            .into_iter()
            .map(|(id, score)| SearchResult {
                id,
                score,
                vector: None,
                metadata: None,
            })
            .collect())
    }
    /// Remove by id; returns whether the id was present.
    fn remove(&mut self, id: &VectorId) -> Result<bool> {
        Ok(self.vectors.remove(id).is_some())
    }
    fn len(&self) -> usize {
        self.vectors.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The nearest neighbor of a stored basis vector is itself, at ~0 distance.
    #[test]
    fn test_flat_index() -> Result<()> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
        let basis = [
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0]),
        ];
        for (id, v) in basis {
            index.add(id.to_string(), v.to_vec())?;
        }
        let results = index.search(&[1.0, 0.0, 0.0], 2)?;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score < 0.01);
        Ok(())
    }
}

View File

@@ -0,0 +1,481 @@
//! HNSW (Hierarchical Navigable Small World) index implementation
use crate::distance::distance;
use crate::error::{Result, RuvectorError};
use crate::index::VectorIndex;
use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
use bincode::{Decode, Encode};
use dashmap::DashMap;
use hnsw_rs::prelude::*;
use parking_lot::RwLock;
use std::sync::Arc;
/// Distance function wrapper for hnsw_rs
struct DistanceFn {
    /// Metric forwarded to the crate-level `distance` helper.
    metric: DistanceMetric,
}
impl DistanceFn {
    fn new(metric: DistanceMetric) -> Self {
        Self { metric }
    }
}
impl Distance<f32> for DistanceFn {
    /// Adapter to hnsw_rs's `Distance` trait.
    ///
    /// `distance` only fails on a length mismatch; mapping that to `f32::MAX`
    /// pushes malformed pairs to the end of any neighbor ranking instead of
    /// panicking inside the graph traversal.
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        distance(a, b, self.metric).unwrap_or(f32::MAX)
    }
}
/// HNSW index wrapper
pub struct HnswIndex {
    /// Graph plus bookkeeping behind a single lock: inserts/removals take the
    /// write guard, searches only the read guard.
    inner: Arc<RwLock<HnswInner>>,
    config: HnswConfig,
    metric: DistanceMetric,
    /// Required length of every stored and query vector.
    dimensions: usize,
}
/// Mutable state guarded by `HnswIndex::inner`.
struct HnswInner {
    hnsw: Hnsw<'static, f32, DistanceFn>,
    /// Canonical id -> vector storage (source of truth for `len`/serialize).
    vectors: DashMap<VectorId, Vec<f32>>,
    /// External string id -> internal hnsw_rs point index.
    id_to_idx: DashMap<VectorId, usize>,
    /// Internal point index -> external string id (labels search results).
    idx_to_id: DashMap<usize, VectorId>,
    /// Next internal index to hand out; monotonically increasing, never reused.
    next_idx: usize,
}
/// Serializable HNSW index state
///
/// Only the raw vectors, id mappings, config and metric are persisted; the
/// graph itself is rebuilt from them in `HnswIndex::deserialize`.
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
    vectors: Vec<(String, Vec<f32>)>,
    id_to_idx: Vec<(String, usize)>,
    idx_to_id: Vec<(usize, String)>,
    next_idx: usize,
    config: SerializableHnswConfig,
    dimensions: usize,
    metric: SerializableDistanceMetric,
}
/// Mirror of `HnswConfig` carrying bincode derives.
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
    m: usize,
    ef_construction: usize,
    ef_search: usize,
    max_elements: usize,
}
/// Mirror of `DistanceMetric` carrying bincode derives.
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
    Euclidean,
    Cosine,
    DotProduct,
    Manhattan,
}
impl From<DistanceMetric> for SerializableDistanceMetric {
fn from(metric: DistanceMetric) -> Self {
match metric {
DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
}
}
}
impl From<SerializableDistanceMetric> for DistanceMetric {
fn from(metric: SerializableDistanceMetric) -> Self {
match metric {
SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
}
}
}
impl HnswIndex {
    /// Create a new, empty HNSW index with the given dimensionality, metric
    /// and construction parameters.
    pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
        let distance_fn = DistanceFn::new(metric);
        // Create HNSW with configured parameters
        let hnsw = Hnsw::<f32, DistanceFn>::new(
            config.m,
            config.max_elements,
            dimensions,
            config.ef_construction,
            distance_fn,
        );
        Ok(Self {
            inner: Arc::new(RwLock::new(HnswInner {
                hnsw,
                vectors: DashMap::new(),
                id_to_idx: DashMap::new(),
                idx_to_id: DashMap::new(),
                next_idx: 0,
            })),
            config,
            metric,
            dimensions,
        })
    }
    /// Get configuration
    pub fn config(&self) -> &HnswConfig {
        &self.config
    }
    /// Set efSearch parameter for query-time accuracy tuning.
    ///
    /// Currently a no-op: hnsw_rs takes the ef value per search call, so use
    /// [`Self::search_with_ef`] to override the configured default.
    pub fn set_ef_search(&mut self, _ef_search: usize) {
        // Note: hnsw_rs controls ef_search via the search method's knbn parameter
        // We store it in config and use it in search_with_ef
    }
    /// Serialize the index to bytes using bincode.
    ///
    /// Persists the vectors, id mappings, config and metric; the graph itself
    /// is reconstructed in [`Self::deserialize`].
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let inner = self.inner.read();
        let state = HnswState {
            vectors: inner
                .vectors
                .iter()
                .map(|entry| (entry.key().clone(), entry.value().clone()))
                .collect(),
            id_to_idx: inner
                .id_to_idx
                .iter()
                .map(|entry| (entry.key().clone(), *entry.value()))
                .collect(),
            idx_to_id: inner
                .idx_to_id
                .iter()
                .map(|entry| (*entry.key(), entry.value().clone()))
                .collect(),
            next_idx: inner.next_idx,
            config: SerializableHnswConfig {
                m: self.config.m,
                ef_construction: self.config.ef_construction,
                ef_search: self.config.ef_search,
                max_elements: self.config.max_elements,
            },
            dimensions: self.dimensions,
            metric: self.metric.into(),
        };
        bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
            RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
        })
    }
    /// Deserialize the index from bytes using bincode.
    ///
    /// Rebuilds the HNSW graph by re-inserting every persisted vector under
    /// its original internal index, so the idx <-> id mappings stay valid.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let (state, _): (HnswState, usize) =
            bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
                RuvectorError::SerializationError(format!(
                    "Failed to deserialize HNSW index: {}",
                    e
                ))
            })?;
        let config = HnswConfig {
            m: state.config.m,
            ef_construction: state.config.ef_construction,
            ef_search: state.config.ef_search,
            max_elements: state.config.max_elements,
        };
        let dimensions = state.dimensions;
        let metric: DistanceMetric = state.metric.into();
        let distance_fn = DistanceFn::new(metric);
        let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
            config.m,
            config.max_elements,
            dimensions,
            config.ef_construction,
            distance_fn,
        );
        let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
        let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
        // PERF FIX: index vectors by id once (O(n)) instead of linearly
        // scanning `state.vectors` for every entry, which made rebuilding
        // the graph O(n^2) in the number of stored vectors.
        use std::collections::HashMap;
        let by_id: HashMap<&String, &Vec<f32>> =
            state.vectors.iter().map(|(vid, v)| (vid, v)).collect();
        // Re-insert every vector under its original internal index.
        for entry in idx_to_id.iter() {
            let idx = *entry.key();
            let id = entry.value();
            if let Some(vector) = by_id.get(id).copied() {
                // Use insert_data method with slice and idx
                hnsw.insert_data(vector, idx);
            }
        }
        drop(by_id);
        let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
        Ok(Self {
            inner: Arc::new(RwLock::new(HnswInner {
                hnsw,
                vectors: vectors_map,
                id_to_idx,
                idx_to_id,
                next_idx: state.next_idx,
            })),
            config,
            metric,
            dimensions,
        })
    }
    /// Search with custom efSearch parameter.
    ///
    /// A larger `ef_search` trades latency for recall. Errors on a query of
    /// the wrong dimensionality; neighbors whose internal index no longer
    /// maps to an id (removed entries) are filtered out of the results.
    pub fn search_with_ef(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<SearchResult>> {
        if query.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: query.len(),
            });
        }
        let inner = self.inner.read();
        // Use HNSW search with custom ef parameter (knbn)
        let neighbors = inner.hnsw.search(query, k, ef_search);
        Ok(neighbors
            .into_iter()
            .filter_map(|neighbor| {
                inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
                    id: id.clone(),
                    score: neighbor.distance,
                    vector: None,
                    metadata: None,
                })
            })
            .collect())
    }
}
impl VectorIndex for HnswIndex {
    /// Insert a single vector; errors on dimension mismatch.
    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
        if vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: vector.len(),
            });
        }
        let mut inner = self.inner.write();
        // Assign the next internal graph index to this id.
        let idx = inner.next_idx;
        inner.next_idx += 1;
        // Insert into HNSW graph using insert_data
        inner.hnsw.insert_data(&vector, idx);
        // Store mappings
        inner.vectors.insert(id.clone(), vector);
        inner.id_to_idx.insert(id.clone(), idx);
        inner.idx_to_id.insert(idx, id);
        Ok(())
    }
    /// Batch insert under a single write lock.
    ///
    /// All dimensions are validated up front so bad input rejects the whole
    /// batch before anything is inserted. PERF FIX: `entries` is consumed
    /// directly instead of being copied into an intermediate Vec, which
    /// previously cloned every id and vector one extra time.
    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
        // Validate all dimensions first
        for (_, vector) in &entries {
            if vector.len() != self.dimensions {
                return Err(RuvectorError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: vector.len(),
                });
            }
        }
        let mut inner = self.inner.write();
        // Reserve a contiguous range of internal indices for the batch.
        let first_idx = inner.next_idx;
        inner.next_idx += entries.len();
        // Sequential insertion to avoid Send requirements with the RwLock
        // guard; for large batches, consider restructuring to use
        // hnsw_rs parallel_insert.
        for (offset, (id, vector)) in entries.into_iter().enumerate() {
            let idx = first_idx + offset;
            inner.hnsw.insert_data(&vector, idx);
            inner.vectors.insert(id.clone(), vector);
            inner.id_to_idx.insert(id.clone(), idx);
            inner.idx_to_id.insert(idx, id);
        }
        Ok(())
    }
    /// k-NN search using the configured `ef_search`.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        // Use configured ef_search
        self.search_with_ef(query, k, self.config.ef_search)
    }
    /// Remove an id from the mappings.
    ///
    /// hnsw_rs does not support true deletion: the point stays in the graph
    /// but can no longer be resolved to an id, so searches filter it out.
    /// This is a known limitation of HNSW.
    fn remove(&mut self, id: &VectorId) -> Result<bool> {
        let inner = self.inner.write();
        let removed = inner.vectors.remove(id).is_some();
        if removed {
            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
                inner.idx_to_id.remove(&idx);
            }
        }
        Ok(removed)
    }
    fn len(&self) -> usize {
        self.inner.read().vectors.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Build `count` vectors with `dimensions` uniform random components in [0, 1).
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        (0..count)
            .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }
    /// L2-normalize a vector; a zero vector is returned unchanged to avoid
    /// division by zero.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }
    /// A freshly created index is empty.
    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let config = HnswConfig::default();
        let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }
    /// Inserting vectors and querying an already-indexed vector should return
    /// that vector as the top hit (exact match has minimal distance).
    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        // Insert a few vectors
        let vectors = generate_random_vectors(100, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }
        assert_eq!(index.len(), 100);
        // Search for the first vector
        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");
        Ok(())
    }
    /// Batch insertion must account for every entry.
    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        let vectors = generate_random_vectors(100, 128);
        let entries: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), normalize_vector(v)))
            .collect();
        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);
        Ok(())
    }
    /// Round-trip through serialize/deserialize preserves count and remains
    /// searchable.
    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        // Insert vectors
        let vectors = generate_random_vectors(50, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }
        // Serialize
        let bytes = index.serialize()?;
        // Deserialize
        let restored_index = HnswIndex::deserialize(&bytes)?;
        assert_eq!(restored_index.len(), 50);
        // Test search on restored index
        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;
        assert!(!results.is_empty());
        Ok(())
    }
    /// Adding a vector of the wrong dimensionality is rejected.
    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());
        Ok(())
    }
}

View File

@@ -0,0 +1,142 @@
//! # Ruvector Core
//!
//! High-performance Rust-native vector database with HNSW indexing and SIMD-optimized operations.
//!
//! ## Working Features (Tested & Benchmarked)
//!
//! - **HNSW Indexing**: Approximate nearest neighbor search with O(log n) complexity
//! - **SIMD Distance**: SimSIMD-powered distance calculations (~16M ops/sec for 512-dim)
//! - **Quantization**: Scalar (4x), Int4 (8x), Product (8-16x), and binary (32x) compression with distance support
//! - **Persistence**: REDB-based storage with config persistence
//! - **Search**: ~2.5K queries/sec on 10K vectors (benchmarked)
//!
//! ## ⚠️ Experimental/Incomplete Features - READ BEFORE USE
//!
//! - **AgenticDB**: ⚠️⚠️⚠️ **CRITICAL WARNING** ⚠️⚠️⚠️
//! - Uses PLACEHOLDER hash-based embeddings, NOT real semantic embeddings
//! - "dog" and "cat" will NOT be similar (different characters)
//! - "dog" and "god" WILL be similar (same characters) - **This is wrong!**
//! - **MUST integrate real embedding model for production** (ONNX, Candle, or API)
//! - See [`agenticdb`] module docs and `/examples/onnx-embeddings` for integration
//! - **Advanced Features**: Conformal prediction, hybrid search - functional but less tested
//!
//! ## What This Is NOT
//!
//! - This is NOT a complete RAG solution - you need external embedding models
//! - Examples use mock embeddings for demonstration only
// Crate-level lint policy: docs are encouraged but not enforced; all clippy
// lints warn; MSRV-incompatibility lint silenced (checked by CI instead —
// TODO confirm).
#![allow(missing_docs)]
#![warn(clippy::all)]
#![allow(clippy::incompatible_msrv)]
pub mod advanced_features;
// AgenticDB requires storage feature
#[cfg(feature = "storage")]
pub mod agenticdb;
pub mod distance;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod quantization;
// Storage backends - conditional compilation based on features:
// with "storage" the real backend is compiled; without it an in-memory
// stand-in is re-exported under the same `storage` path so downstream
// code compiles unchanged.
#[cfg(feature = "storage")]
pub mod storage;
#[cfg(not(feature = "storage"))]
pub mod storage_memory;
#[cfg(not(feature = "storage"))]
pub use storage_memory as storage;
pub mod types;
pub mod vector_db;
// Performance optimization modules
pub mod arena;
pub mod cache_optimized;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub mod lockfree;
pub mod simd_intrinsics;
/// Unified Memory Pool and Paging System (ADR-006)
///
/// High-performance paged memory management for LLM inference:
/// - 2MB page-granular allocation with best-fit strategy
/// - Reference-counted pinning with RAII guards
/// - LRU eviction with hysteresis for thrash prevention
/// - Multi-tenant isolation with Hot/Warm/Cold residency tiers
pub mod memory;
/// Advanced techniques: hypergraphs, learned indexes, neural hashing, TDA (Phase 6)
pub mod advanced;
// Re-exports: flatten the most commonly used types to the crate root.
pub use advanced_features::{
    ConformalConfig, ConformalPredictor, EnhancedPQ, FilterExpression, FilterStrategy,
    FilteredSearch, HybridConfig, HybridSearch, MMRConfig, MMRSearch, PQConfig, PredictionSet,
    BM25,
};
#[cfg(feature = "storage")]
pub use agenticdb::{
    AgenticDB, PolicyAction, PolicyEntry, PolicyMemoryStore, SessionStateIndex, SessionTurn,
    WitnessEntry, WitnessLog,
};
#[cfg(feature = "api-embeddings")]
pub use embeddings::ApiEmbedding;
pub use embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
#[cfg(feature = "real-embeddings")]
pub use embeddings::CandleEmbedding;
// Compile-time warning about AgenticDB limitations: referencing a
// #[deprecated] const inside a const block forces a deprecation warning to
// surface in every build that enables the "storage" feature.
#[cfg(feature = "storage")]
#[allow(deprecated, clippy::let_unit_value)]
const _: () = {
    #[deprecated(
        since = "0.1.0",
        note = "AgenticDB uses placeholder hash-based embeddings. For semantic search, integrate a real embedding model (ONNX, Candle, or API). See /examples/onnx-embeddings for production setup."
    )]
    const AGENTICDB_EMBEDDING_WARNING: () = ();
    let _ = AGENTICDB_EMBEDDING_WARNING;
};
pub use error::{Result, RuvectorError};
pub use types::{DistanceMetric, SearchQuery, SearchResult, VectorEntry, VectorId};
pub use vector_db::VectorDB;
// Quantization types (ADR-001)
pub use quantization::{
    BinaryQuantized, Int4Quantized, ProductQuantized, QuantizedVector, ScalarQuantized,
};
// Memory management types (ADR-001)
pub use arena::{Arena, ArenaVec, BatchVectorAllocator, CacheAlignedVec, CACHE_LINE_SIZE};
// Lock-free structures (requires parallel feature)
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub use lockfree::{
    AtomicVectorPool, BatchItem, BatchResult, LockFreeBatchProcessor, LockFreeCounter,
    LockFreeStats, LockFreeWorkQueue, ObjectPool, PooledObject, PooledVector, StatsSnapshot,
    VectorPoolStats,
};
// Cache-optimized storage
pub use cache_optimized::SoAVectorStorage;
#[cfg(test)]
mod tests {
    use super::*;

    /// The crate version must be non-empty and stay within the 0.1.x line.
    #[test]
    fn test_version() {
        // Read the version Cargo injects at compile time rather than
        // hardcoding a literal that would rot on every release.
        let pkg_version = env!("CARGO_PKG_VERSION");
        assert!(!pkg_version.is_empty(), "Version should not be empty");
        assert!(pkg_version.starts_with("0.1."), "Version should be 0.1.x");
    }
}

View File

@@ -0,0 +1,590 @@
//! Lock-free data structures for high-concurrency operations
//!
//! This module provides lock-free implementations of common data structures
//! to minimize contention and improve scalability.
//!
//! Note: This module requires the `parallel` feature and is not available on WASM.
#![cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
use crossbeam::queue::{ArrayQueue, SegQueue};
use crossbeam::utils::CachePadded;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
/// Lock-free counter with cache padding to prevent false sharing
///
/// `CachePadded` already pads to the platform cache line; the additional
/// `repr(align(64))` guarantees a 64-byte minimum alignment for the wrapper
/// itself.
#[repr(align(64))]
pub struct LockFreeCounter {
    value: CachePadded<AtomicU64>,
}
impl LockFreeCounter {
    /// Create a counter starting at `initial`.
    pub fn new(initial: u64) -> Self {
        Self {
            value: CachePadded::new(AtomicU64::new(initial)),
        }
    }
    /// Add one and return the PREVIOUS value (fetch-and-add semantics).
    #[inline]
    pub fn increment(&self) -> u64 {
        self.value.fetch_add(1, Ordering::Relaxed)
    }
    /// Current value. Relaxed load: no ordering guarantees w.r.t. other data.
    #[inline]
    pub fn get(&self) -> u64 {
        self.value.load(Ordering::Relaxed)
    }
    /// Add `delta` and return the previous value.
    #[inline]
    pub fn add(&self, delta: u64) -> u64 {
        self.value.fetch_add(delta, Ordering::Relaxed)
    }
}
/// Lock-free statistics collector
///
/// Each counter is cache-padded so concurrent writers on different cores do
/// not invalidate each other's cache lines.
pub struct LockFreeStats {
    queries: CachePadded<AtomicU64>,
    inserts: CachePadded<AtomicU64>,
    deletes: CachePadded<AtomicU64>,
    total_latency_ns: CachePadded<AtomicU64>,
}
impl LockFreeStats {
    /// Create a collector with all counters at zero.
    pub fn new() -> Self {
        Self {
            queries: CachePadded::new(AtomicU64::new(0)),
            inserts: CachePadded::new(AtomicU64::new(0)),
            deletes: CachePadded::new(AtomicU64::new(0)),
            total_latency_ns: CachePadded::new(AtomicU64::new(0)),
        }
    }
    /// Record one query and its latency in nanoseconds.
    #[inline]
    pub fn record_query(&self, latency_ns: u64) {
        self.queries.fetch_add(1, Ordering::Relaxed);
        self.total_latency_ns
            .fetch_add(latency_ns, Ordering::Relaxed);
    }
    /// Record one insert.
    #[inline]
    pub fn record_insert(&self) {
        self.inserts.fetch_add(1, Ordering::Relaxed);
    }
    /// Record one delete.
    #[inline]
    pub fn record_delete(&self) {
        self.deletes.fetch_add(1, Ordering::Relaxed);
    }
    /// Take a point-in-time snapshot of the counters.
    ///
    /// The fields are loaded independently, so under concurrent updates the
    /// snapshot is not a single atomic cut across all counters — each value is
    /// individually correct but they may reflect slightly different moments.
    pub fn snapshot(&self) -> StatsSnapshot {
        let queries = self.queries.load(Ordering::Relaxed);
        let total_latency = self.total_latency_ns.load(Ordering::Relaxed);
        StatsSnapshot {
            queries,
            inserts: self.inserts.load(Ordering::Relaxed),
            deletes: self.deletes.load(Ordering::Relaxed),
            // Integer average; zero queries yields zero latency.
            avg_latency_ns: if queries > 0 {
                total_latency / queries
            } else {
                0
            },
        }
    }
}
impl Default for LockFreeStats {
    fn default() -> Self {
        Self::new()
    }
}
/// Point-in-time copy of [`LockFreeStats`] counters.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    /// Total queries recorded.
    pub queries: u64,
    /// Total inserts recorded.
    pub inserts: u64,
    /// Total deletes recorded.
    pub deletes: u64,
    /// Mean query latency in nanoseconds (0 when no queries recorded).
    pub avg_latency_ns: u64,
}
/// Lock-free object pool for reducing allocations
///
/// Hands out at most `capacity` live objects; beyond that, `acquire`
/// busy-waits until another object is returned.
pub struct ObjectPool<T> {
    queue: Arc<SegQueue<T>>,
    factory: Arc<dyn Fn() -> T + Send + Sync>,
    capacity: usize,
    allocated: AtomicUsize,
}
impl<T> ObjectPool<T> {
    /// Create a pool that lazily builds objects with `factory`, capped at
    /// `capacity` live objects.
    pub fn new<F>(capacity: usize, factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            queue: Arc::new(SegQueue::new()),
            factory: Arc::new(factory),
            capacity,
            allocated: AtomicUsize::new(0),
        }
    }
    /// Get an object from the pool or create a new one
    ///
    /// If the pool is empty and `capacity` objects already exist, this SPINS
    /// until an object is returned — callers must ensure outstanding
    /// `PooledObject`s are eventually dropped or this blocks forever.
    /// Reused objects are NOT reset; any prior state is visible to the new
    /// holder.
    pub fn acquire(&self) -> PooledObject<T> {
        let object = self.queue.pop().unwrap_or_else(|| {
            // Optimistically reserve a slot; roll back if over capacity.
            // NOTE(review): two racing threads can transiently overshoot the
            // cap before one rolls back — appears benign, but confirm.
            let current = self.allocated.fetch_add(1, Ordering::Relaxed);
            if current < self.capacity {
                (self.factory)()
            } else {
                self.allocated.fetch_sub(1, Ordering::Relaxed);
                // Wait for an object to be returned
                loop {
                    if let Some(obj) = self.queue.pop() {
                        break obj;
                    }
                    std::hint::spin_loop();
                }
            }
        });
        PooledObject {
            object: Some(object),
            pool: Arc::clone(&self.queue),
        }
    }
}
/// RAII wrapper for pooled objects
///
/// Returns the wrapped object to its pool's queue on drop. The `Option` is
/// only `None` after `Drop` has taken the object; all accessors rely on that
/// invariant.
pub struct PooledObject<T> {
    object: Option<T>,
    pool: Arc<SegQueue<T>>,
}
impl<T> PooledObject<T> {
    /// Borrow the pooled object.
    pub fn get(&self) -> &T {
        self.object.as_ref().unwrap()
    }
    /// Mutably borrow the pooled object.
    pub fn get_mut(&mut self) -> &mut T {
        self.object.as_mut().unwrap()
    }
}
impl<T> Drop for PooledObject<T> {
    fn drop(&mut self) {
        // Hand the object back to the shared queue so a waiting acquire()
        // can pick it up.
        if let Some(object) = self.object.take() {
            self.pool.push(object);
        }
    }
}
impl<T> std::ops::Deref for PooledObject<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        self.object.as_ref().unwrap()
    }
}
impl<T> std::ops::DerefMut for PooledObject<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.object.as_mut().unwrap()
    }
}
/// Lock-free ring buffer for work distribution
///
/// Bounded MPMC queue; a thin wrapper over crossbeam's `ArrayQueue`.
pub struct LockFreeWorkQueue<T> {
    queue: ArrayQueue<T>,
}
impl<T> LockFreeWorkQueue<T> {
    /// Create a queue holding at most `capacity` items.
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: ArrayQueue::new(capacity),
        }
    }
    /// Push an item; returns it back as `Err` when the queue is full.
    #[inline]
    pub fn try_push(&self, item: T) -> Result<(), T> {
        self.queue.push(item)
    }
    /// Pop an item; `None` when the queue is empty.
    #[inline]
    pub fn try_pop(&self) -> Option<T> {
        self.queue.pop()
    }
    /// Approximate number of queued items (racy under concurrency).
    #[inline]
    pub fn len(&self) -> usize {
        self.queue.len()
    }
    /// Whether the queue is (momentarily) empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
/// Atomic vector pool for lock-free vector operations (ADR-001)
///
/// Provides a pool of pre-allocated vectors that can be acquired and released
/// without locking, ideal for high-throughput batch operations.
pub struct AtomicVectorPool {
    /// Pool of available vectors
    pool: SegQueue<Vec<f32>>,
    /// Dimensions per vector
    dimensions: usize,
    /// Maximum pool size
    max_size: usize,
    /// Size counter used to cap the number of retained idle vectors at
    /// `max_size`.
    size: AtomicUsize,
    /// Total allocations (counts every acquire, whether reused or fresh)
    total_allocations: AtomicU64,
    /// Pool hits (reused vectors)
    pool_hits: AtomicU64,
}
impl AtomicVectorPool {
    /// Create a new atomic vector pool.
    ///
    /// Pre-allocates `initial_size` zeroed vectors of length `dimensions`.
    /// The pool retains at most `max_size` idle vectors; surplus returns are
    /// dropped.
    pub fn new(dimensions: usize, initial_size: usize, max_size: usize) -> Self {
        let pool = SegQueue::new();
        // Pre-allocate vectors
        for _ in 0..initial_size {
            pool.push(vec![0.0; dimensions]);
        }
        Self {
            pool,
            dimensions,
            max_size,
            size: AtomicUsize::new(initial_size),
            total_allocations: AtomicU64::new(0),
            pool_hits: AtomicU64::new(0),
        }
    }
    /// Acquire a vector from the pool (or allocate a new one).
    ///
    /// Reused vectors are zeroed before being handed out. `total_allocations`
    /// counts every acquire; `pool_hits` counts only reuses.
    pub fn acquire(&self) -> PooledVector<'_> {
        self.total_allocations.fetch_add(1, Ordering::Relaxed);
        let vec = if let Some(mut v) = self.pool.pop() {
            self.pool_hits.fetch_add(1, Ordering::Relaxed);
            // BUGFIX: decrement the idle-vector count when a vector leaves
            // the pool. Previously `size` was only ever incremented (on
            // return), so after `max_size` acquire/release cycles the counter
            // saturated and `return_to_pool` refused every return,
            // permanently starving the pool even while it sat empty.
            // `checked_sub` makes the decrement saturating: if a racing
            // return has pushed a vector but not yet incremented `size`,
            // we must not wrap below zero.
            let _ = self
                .size
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
            // Clear the vector for reuse
            v.fill(0.0);
            v
        } else {
            // Allocate new vector
            vec![0.0; self.dimensions]
        };
        PooledVector {
            vec: Some(vec),
            pool: self,
        }
    }
    /// Return a vector to the pool.
    ///
    /// The size check and the push are not one atomic step, so under heavy
    /// contention the pool may briefly hold slightly more than `max_size`
    /// vectors; this is harmless and self-corrects.
    fn return_to_pool(&self, vec: Vec<f32>) {
        let current_size = self.size.load(Ordering::Relaxed);
        if current_size < self.max_size {
            self.pool.push(vec);
            self.size.fetch_add(1, Ordering::Relaxed);
        }
        // If pool is full, vector is dropped
    }
    /// Get pool statistics (individually consistent, racy as a set).
    pub fn stats(&self) -> VectorPoolStats {
        let total = self.total_allocations.load(Ordering::Relaxed);
        let hits = self.pool_hits.load(Ordering::Relaxed);
        let hit_rate = if total > 0 {
            hits as f64 / total as f64
        } else {
            0.0
        };
        VectorPoolStats {
            total_allocations: total,
            pool_hits: hits,
            hit_rate,
            current_size: self.size.load(Ordering::Relaxed),
            max_size: self.max_size,
        }
    }
    /// Get dimensions per pooled vector.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
/// Statistics for the vector pool
#[derive(Debug, Clone)]
pub struct VectorPoolStats {
    /// Total acquire calls served (reused or freshly allocated).
    pub total_allocations: u64,
    /// Acquires that reused a pooled vector.
    pub pool_hits: u64,
    /// `pool_hits / total_allocations`, or 0.0 before any acquire.
    pub hit_rate: f64,
    /// Value of the pool's size counter at snapshot time.
    pub current_size: usize,
    /// Maximum number of idle vectors the pool retains.
    pub max_size: usize,
}
/// RAII wrapper for pooled vectors
///
/// Returns its vector to the owning [`AtomicVectorPool`] on drop, unless it
/// was `detach`ed. The `Option` is only `None` after `detach`/`Drop`; all
/// accessors rely on that invariant.
pub struct PooledVector<'a> {
    vec: Option<Vec<f32>>,
    pool: &'a AtomicVectorPool,
}
impl<'a> PooledVector<'a> {
    /// Get as slice
    pub fn as_slice(&self) -> &[f32] {
        self.vec.as_ref().unwrap()
    }
    /// Get as mutable slice
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        self.vec.as_mut().unwrap()
    }
    /// Copy from source slice
    ///
    /// # Panics
    /// Panics if `src` does not have the pool's dimensionality.
    pub fn copy_from(&mut self, src: &[f32]) {
        let vec = self.vec.as_mut().unwrap();
        assert_eq!(vec.len(), src.len(), "Dimension mismatch");
        vec.copy_from_slice(src);
    }
    /// Detach the vector from the pool (it won't be returned)
    pub fn detach(mut self) -> Vec<f32> {
        self.vec.take().unwrap()
    }
}
impl<'a> Drop for PooledVector<'a> {
    fn drop(&mut self) {
        // Hand the buffer back for reuse; a detached vector has already
        // taken it, so this is a no-op in that case.
        if let Some(vec) = self.vec.take() {
            self.pool.return_to_pool(vec);
        }
    }
}
impl<'a> std::ops::Deref for PooledVector<'a> {
    type Target = [f32];
    fn deref(&self) -> &[f32] {
        self.as_slice()
    }
}
impl<'a> std::ops::DerefMut for PooledVector<'a> {
    fn deref_mut(&mut self) -> &mut [f32] {
        self.as_mut_slice()
    }
}
/// Lock-free batch processor for parallel vector operations (ADR-001)
///
/// Distributes work across multiple workers without contention.
pub struct LockFreeBatchProcessor {
    /// Bounded work queue for pending items
    work_queue: ArrayQueue<BatchItem>,
    /// Unbounded results queue
    results_queue: SegQueue<BatchResult>,
    /// Count of submitted items
    pending: AtomicUsize,
    /// Count of completed items
    completed: AtomicUsize,
}
/// Item in the batch work queue
#[derive(Debug)]
pub struct BatchItem {
    /// Caller-chosen identifier used to correlate with its [`BatchResult`].
    pub id: u64,
    /// Input payload for the worker.
    pub data: Vec<f32>,
}
/// Result from batch processing
pub struct BatchResult {
    /// Identifier of the [`BatchItem`] this result answers.
    pub id: u64,
    /// Output payload produced by the worker.
    pub result: Vec<f32>,
}
impl LockFreeBatchProcessor {
    /// Create a new batch processor with a bounded work queue of `capacity`.
    pub fn new(capacity: usize) -> Self {
        Self {
            work_queue: ArrayQueue::new(capacity),
            results_queue: SegQueue::new(),
            pending: AtomicUsize::new(0),
            completed: AtomicUsize::new(0),
        }
    }
    /// Submit a batch item for processing.
    ///
    /// Returns `Err(item)` (handing the item back) when the bounded work
    /// queue is full.
    ///
    /// BUGFIX: `pending` is now incremented only after a successful push.
    /// Previously it was incremented unconditionally BEFORE the push, so a
    /// rejected submit inflated the pending count with no matching
    /// completion, and `is_done()` could never return true again.
    pub fn submit(&self, item: BatchItem) -> Result<(), BatchItem> {
        match self.work_queue.push(item) {
            Ok(()) => {
                self.pending.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(rejected) => Err(rejected),
        }
    }
    /// Try to get a work item (for workers); `None` when the queue is empty.
    pub fn try_get_work(&self) -> Option<BatchItem> {
        self.work_queue.pop()
    }
    /// Submit a result (from workers) and count one completion.
    pub fn submit_result(&self, result: BatchResult) {
        self.completed.fetch_add(1, Ordering::Relaxed);
        self.results_queue.push(result);
    }
    /// Drain and return all currently available results.
    pub fn collect_results(&self) -> Vec<BatchResult> {
        let mut results = Vec::new();
        while let Some(result) = self.results_queue.pop() {
            results.push(result);
        }
        results
    }
    /// Number of successfully submitted items.
    pub fn pending(&self) -> usize {
        self.pending.load(Ordering::Relaxed)
    }
    /// Number of completed items.
    pub fn completed(&self) -> usize {
        self.completed.load(Ordering::Relaxed)
    }
    /// True when every submitted item has a result (vacuously true before
    /// anything is submitted).
    pub fn is_done(&self) -> bool {
        self.pending() == self.completed()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    /// Ten threads each incrementing 1000 times must land exactly on 10000.
    #[test]
    fn test_lockfree_counter() {
        let counter = Arc::new(LockFreeCounter::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            handles.push(thread::spawn(move || {
                for _ in 0..1000 {
                    counter_clone.increment();
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(counter.get(), 10000);
    }
    /// A released object's allocation should be reusable.
    #[test]
    fn test_object_pool() {
        let pool = ObjectPool::new(4, || Vec::<u8>::with_capacity(1024));
        let mut obj1 = pool.acquire();
        obj1.push(1);
        assert_eq!(obj1.len(), 1);
        drop(obj1);
        let obj2 = pool.acquire();
        // Object should be reused (but cleared state is not guaranteed)
        assert!(obj2.capacity() >= 1024);
    }
    /// Snapshot reflects counts and the integer-average latency.
    #[test]
    fn test_stats_collector() {
        let stats = LockFreeStats::new();
        stats.record_query(1000);
        stats.record_query(2000);
        stats.record_insert();
        let snapshot = stats.snapshot();
        assert_eq!(snapshot.queries, 2);
        assert_eq!(snapshot.inserts, 1);
        assert_eq!(snapshot.avg_latency_ns, 1500);
    }
    /// Acquired vectors are writable and every acquire is counted.
    #[test]
    fn test_atomic_vector_pool() {
        let pool = AtomicVectorPool::new(4, 2, 10);
        // Acquire first vector
        let mut v1 = pool.acquire();
        v1.copy_from(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(v1.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        // Acquire second vector
        let mut v2 = pool.acquire();
        v2.copy_from(&[5.0, 6.0, 7.0, 8.0]);
        // Stats should show allocations
        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
    }
    /// A vector returned on drop should be reused by the next acquire.
    #[test]
    fn test_vector_pool_reuse() {
        let pool = AtomicVectorPool::new(3, 1, 5);
        // Acquire and release
        {
            let mut v = pool.acquire();
            v.copy_from(&[1.0, 2.0, 3.0]);
        } // v is returned to pool here
        // Acquire again - should be a pool hit
        let _v2 = pool.acquire();
        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert!(stats.pool_hits >= 1, "Should have at least one pool hit");
    }
    /// Full submit/work/result round-trip through the processor.
    #[test]
    fn test_batch_processor() {
        let processor = LockFreeBatchProcessor::new(10);
        // Submit work items
        processor
            .submit(BatchItem {
                id: 1,
                data: vec![1.0, 2.0],
            })
            .unwrap();
        processor
            .submit(BatchItem {
                id: 2,
                data: vec![3.0, 4.0],
            })
            .unwrap();
        assert_eq!(processor.pending(), 2);
        // Process work
        while let Some(item) = processor.try_get_work() {
            let result = BatchResult {
                id: item.id,
                result: item.data.iter().map(|x| x * 2.0).collect(),
            };
            processor.submit_result(result);
        }
        assert!(processor.is_done());
        assert_eq!(processor.completed(), 2);
        // Collect results
        let results = processor.collect_results();
        assert_eq!(results.len(), 2);
    }
}

View File

@@ -0,0 +1,38 @@
//! Memory management utilities for ruvector-core
//!
//! This module provides memory-efficient data structures and utilities
//! for vector storage operations.
/// Memory pool for vector allocations.
///
/// Tracks a byte count of allocations and an optional upper bound on the
/// total. This type only records the accounting state; it does not itself
/// allocate.
#[derive(Debug, Default)]
pub struct MemoryPool {
    /// Total allocated bytes.
    allocated: usize,
    /// Maximum allocation limit.
    limit: Option<usize>,
}

impl MemoryPool {
    /// Create an unbounded memory pool with nothing allocated.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a memory pool capped at `limit` bytes.
    pub fn with_limit(limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..Self::default()
        }
    }

    /// Bytes currently accounted as allocated.
    pub fn allocated(&self) -> usize {
        self.allocated
    }

    /// The allocation cap, or `None` when unbounded.
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }
}

View File

@@ -0,0 +1,934 @@
//! Quantization techniques for memory compression
//!
//! This module provides tiered quantization strategies as specified in ADR-001:
//!
//! | Quantization | Compression | Use Case |
//! |--------------|-------------|----------|
//! | Scalar (u8) | 4x | Warm data (40-80% access) |
//! | Int4 | 8x | Cool data (10-40% access) |
//! | Product | 8-16x | Cold data (1-10% access) |
//! | Binary | 32x | Archive (<1% access) |
//!
//! ## Performance Optimizations v2
//!
//! - SIMD-accelerated distance calculations for scalar (int8) quantization
//! - SIMD popcnt for binary hamming distance
//! - 4x loop unrolling for better instruction-level parallelism
//! - Separate accumulator strategy to reduce data dependencies
use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Trait for quantized vector representations
///
/// Implementors trade precision for memory: `quantize` compresses a full
/// f32 vector, `distance` compares two compressed vectors directly (without
/// decompressing), and `reconstruct` recovers a lossy approximation.
pub trait QuantizedVector: Send + Sync {
    /// Quantize a full-precision vector
    fn quantize(vector: &[f32]) -> Self;
    /// Calculate distance to another quantized vector
    fn distance(&self, other: &Self) -> f32;
    /// Reconstruct approximate full-precision vector
    fn reconstruct(&self) -> Vec<f32>;
}
/// Scalar quantization to int8 (4x compression)
///
/// Each component is affinely mapped from `[min, max]` to `[0, 255]`;
/// `value ≈ min + data[i] * scale` on reconstruction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantized {
    /// Quantized values (int8)
    pub data: Vec<u8>,
    /// Minimum value for dequantization
    pub min: f32,
    /// Scale factor for dequantization
    pub scale: f32,
}
impl QuantizedVector for ScalarQuantized {
    /// Quantize to u8 via a min/max affine mapping.
    ///
    /// Robustness fix: an empty input previously produced `min = +inf` and a
    /// non-finite `scale` (min/max folds over no elements), which made later
    /// `distance` calls return NaN. Empty vectors now quantize to an empty
    /// payload with neutral `min = 0.0`, `scale = 1.0`.
    fn quantize(vector: &[f32]) -> Self {
        if vector.is_empty() {
            return Self {
                data: Vec::new(),
                min: 0.0,
                scale: 1.0,
            };
        }
        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Handle edge case where all values are the same (scale = 0)
        let scale = if (max - min).abs() < f32::EPSILON {
            1.0 // Arbitrary non-zero scale when all values are identical
        } else {
            (max - min) / 255.0
        };
        let data = vector
            .iter()
            .map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
            .collect();
        Self { data, min, scale }
    }
    /// Approximate Euclidean distance computed directly in the u8 domain.
    fn distance(&self, other: &Self) -> f32 {
        // Fast int8 distance calculation with SIMD optimization
        // Use i32 to avoid overflow: max diff is 255, and 255*255=65025 fits in i32
        // Scale handling: We use the average of both scales for balanced comparison.
        // Using max(scale) would bias toward the vector with larger range,
        // while average provides a more symmetric distance metric.
        // This ensures distance(a, b) ≈ distance(b, a) in the reconstructed space.
        let avg_scale = (self.scale + other.scale) / 2.0;
        // Use SIMD-optimized version for larger vectors
        #[cfg(target_arch = "aarch64")]
        {
            if self.data.len() >= 16 {
                return unsafe { scalar_distance_neon(&self.data, &other.data) }.sqrt() * avg_scale;
            }
        }
        #[cfg(target_arch = "x86_64")]
        {
            if self.data.len() >= 32 && is_x86_feature_detected!("avx2") {
                return unsafe { scalar_distance_avx2(&self.data, &other.data) }.sqrt() * avg_scale;
            }
        }
        // Scalar fallback with 4x loop unrolling for better ILP
        scalar_distance_scalar(&self.data, &other.data).sqrt() * avg_scale
    }
    /// Invert the affine mapping: `min + code * scale` per component.
    fn reconstruct(&self) -> Vec<f32> {
        self.data
            .iter()
            .map(|&v| self.min + (v as f32) * self.scale)
            .collect()
    }
}
/// Product quantization (8-16x compression)
///
/// The vector space is split into subspaces; each subvector is replaced by
/// the index of its nearest centroid in that subspace's trained codebook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantized {
    /// Quantized codes (one per subspace)
    pub codes: Vec<u8>,
    /// Codebooks for each subspace (codebooks[s][c] is centroid c of subspace s)
    pub codebooks: Vec<Vec<Vec<f32>>>,
}
impl ProductQuantized {
/// Train product quantization on a set of vectors
pub fn train(
vectors: &[Vec<f32>],
num_subspaces: usize,
codebook_size: usize,
iterations: usize,
) -> Result<Self> {
if vectors.is_empty() {
return Err(crate::error::RuvectorError::InvalidInput(
"Cannot train on empty vector set".into(),
));
}
if vectors[0].is_empty() {
return Err(crate::error::RuvectorError::InvalidInput(
"Cannot train on vectors with zero dimensions".into(),
));
}
if codebook_size > 256 {
return Err(crate::error::RuvectorError::InvalidParameter(format!(
"Codebook size {} exceeds u8 maximum of 256",
codebook_size
)));
}
let dimensions = vectors[0].len();
let subspace_dim = dimensions / num_subspaces;
let mut codebooks = Vec::with_capacity(num_subspaces);
// Train codebook for each subspace using k-means
for subspace_idx in 0..num_subspaces {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
// Extract subspace vectors
let subspace_vectors: Vec<Vec<f32>> =
vectors.iter().map(|v| v[start..end].to_vec()).collect();
// Run k-means
let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
codebooks.push(codebook);
}
Ok(Self {
codes: vec![],
codebooks,
})
}
/// Quantize a vector using trained codebooks
pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
let num_subspaces = self.codebooks.len();
let subspace_dim = vector.len() / num_subspaces;
let mut codes = Vec::with_capacity(num_subspaces);
for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
let subvector = &vector[start..end];
// Find nearest centroid
let code = codebook
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let dist_a = euclidean_squared(subvector, a);
let dist_b = euclidean_squared(subvector, b);
dist_a.partial_cmp(&dist_b).unwrap()
})
.map(|(idx, _)| idx as u8)
.unwrap_or(0);
codes.push(code);
}
codes
}
}
/// Int4 quantization (8x compression)
///
/// Quantizes f32 to 4-bit integers (0-15), packing 2 values per byte.
/// Provides 8x compression with better precision than binary.
/// Element `i` lives in the low nibble of byte `i/2` when `i` is even, the
/// high nibble when `i` is odd.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Int4Quantized {
    /// Packed 4-bit values (2 per byte)
    pub data: Vec<u8>,
    /// Minimum value for dequantization
    pub min: f32,
    /// Scale factor for dequantization
    pub scale: f32,
    /// Number of dimensions
    pub dimensions: usize,
}
impl Int4Quantized {
    /// Quantize a vector to packed 4-bit representation (2 values per byte).
    ///
    /// Robustness fix: an empty input previously produced non-finite
    /// `min`/`scale` (min/max folds over no elements yield +inf/-inf), which
    /// made `distance` between empty vectors return NaN. Empty vectors now
    /// quantize to an empty payload with neutral `min = 0.0`, `scale = 1.0`.
    pub fn quantize(vector: &[f32]) -> Self {
        if vector.is_empty() {
            return Self {
                data: Vec::new(),
                min: 0.0,
                scale: 1.0,
                dimensions: 0,
            };
        }
        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Handle edge case where all values are the same
        let scale = if (max - min).abs() < f32::EPSILON {
            1.0
        } else {
            (max - min) / 15.0 // 4-bit gives 0-15 range
        };
        let dimensions = vector.len();
        let num_bytes = dimensions.div_ceil(2);
        let mut data = vec![0u8; num_bytes];
        for (i, &v) in vector.iter().enumerate() {
            let quantized = ((v - min) / scale).round().clamp(0.0, 15.0) as u8;
            let byte_idx = i / 2;
            if i % 2 == 0 {
                // Low nibble
                data[byte_idx] |= quantized;
            } else {
                // High nibble
                data[byte_idx] |= quantized << 4;
            }
        }
        Self {
            data,
            min,
            scale,
            dimensions,
        }
    }
    /// Approximate Euclidean distance computed in the 4-bit domain,
    /// rescaled by the average of the two quantization scales.
    ///
    /// # Panics
    /// Panics if the two vectors have different dimensionality.
    pub fn distance(&self, other: &Self) -> f32 {
        assert_eq!(self.dimensions, other.dimensions);
        // Use average scale for balanced comparison
        let avg_scale = (self.scale + other.scale) / 2.0;
        let mut sum_sq = 0i32;
        for i in 0..self.dimensions {
            let byte_idx = i / 2;
            let shift = if i % 2 == 0 { 0 } else { 4 };
            let a = ((self.data[byte_idx] >> shift) & 0x0F) as i32;
            let b = ((other.data[byte_idx] >> shift) & 0x0F) as i32;
            let diff = a - b;
            sum_sq += diff * diff;
        }
        (sum_sq as f32).sqrt() * avg_scale
    }
    /// Reconstruct approximate full-precision vector (`min + nibble * scale`).
    pub fn reconstruct(&self) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.dimensions);
        for i in 0..self.dimensions {
            let byte_idx = i / 2;
            let shift = if i % 2 == 0 { 0 } else { 4 };
            let quantized = (self.data[byte_idx] >> shift) & 0x0F;
            result.push(self.min + (quantized as f32) * self.scale);
        }
        result
    }
    /// Get compression ratio (8x for Int4)
    pub fn compression_ratio() -> f32 {
        8.0 // f32 (4 bytes) -> 4 bits (0.5 bytes)
    }
}
/// Binary quantization (32x compression)
///
/// One sign bit per dimension: bit `i` of the packed buffer is set when the
/// original component was strictly positive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantized {
    /// Binary representation (1 bit per dimension, packed into bytes)
    pub bits: Vec<u8>,
    /// Number of dimensions
    pub dimensions: usize,
}
impl QuantizedVector for BinaryQuantized {
    /// Quantize by sign: strictly positive components become set bits.
    fn quantize(vector: &[f32]) -> Self {
        let dimensions = vector.len();
        // Bit i of the buffer lives at byte i/8, position i%8.
        let mut bits = vec![0u8; dimensions.div_ceil(8)];
        for (i, &component) in vector.iter().enumerate() {
            if component > 0.0 {
                bits[i / 8] |= 1 << (i % 8);
            }
        }
        Self { bits, dimensions }
    }
    /// Distance is the Hamming distance between the two bit patterns.
    fn distance(&self, other: &Self) -> f32 {
        Self::hamming_distance_fast(&self.bits, &other.bits) as f32
    }
    /// Map each stored bit back to +1.0 (set) or -1.0 (clear).
    fn reconstruct(&self) -> Vec<f32> {
        (0..self.dimensions)
            .map(|i| {
                if (self.bits[i / 8] >> (i % 8)) & 1 == 1 {
                    1.0
                } else {
                    -1.0
                }
            })
            .collect()
    }
}
impl BinaryQuantized {
    /// Fast hamming distance using SIMD-optimized operations
    ///
    /// Uses hardware POPCNT on x86_64 or NEON vcnt on ARM64 for optimal performance.
    /// Processes 16 bytes at a time on ARM64, 8 bytes at a time on x86_64.
    /// Falls back to 64-bit operations for remainders.
    ///
    /// Assumes `a.len() == b.len()`; with mismatched lengths the scalar
    /// fallback silently truncates to the shorter buffer (zip semantics) —
    /// callers always pass equal-length buffers.
    pub fn hamming_distance_fast(a: &[u8], b: &[u8]) -> u32 {
        // Use SIMD-optimized version based on architecture
        #[cfg(target_arch = "aarch64")]
        {
            if a.len() >= 16 {
                return unsafe { hamming_distance_neon(a, b) };
            }
        }
        #[cfg(target_arch = "x86_64")]
        {
            if a.len() >= 8 && is_x86_feature_detected!("popcnt") {
                return unsafe { hamming_distance_simd_x86(a, b) };
            }
        }
        // Scalar fallback using 64-bit operations
        let mut distance = 0u32;
        // Process 8 bytes at a time using u64
        let chunks_a = a.chunks_exact(8);
        let chunks_b = b.chunks_exact(8);
        let remainder_a = chunks_a.remainder();
        let remainder_b = chunks_b.remainder();
        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
            // XOR marks differing bits; count_ones tallies them per word.
            let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
            let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
            distance += (a_u64 ^ b_u64).count_ones();
        }
        // Handle remainder bytes
        for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
            distance += (a_byte ^ b_byte).count_ones();
        }
        distance
    }
    /// Compute normalized hamming similarity (0.0 to 1.0)
    ///
    /// 1.0 means identical bit patterns; 0.0 means every bit differs.
    pub fn similarity(&self, other: &Self) -> f32 {
        let distance = self.distance(other);
        1.0 - (distance / self.dimensions as f32)
    }
    /// Get compression ratio (32x for binary)
    pub fn compression_ratio() -> f32 {
        32.0 // f32 (4 bytes = 32 bits) -> 1 bit
    }
    /// Convert to bytes for storage
    pub fn to_bytes(&self) -> &[u8] {
        &self.bits
    }
    /// Create from bytes
    ///
    /// `dimensions` must match the bit count originally packed into `bits`.
    pub fn from_bytes(bits: Vec<u8>, dimensions: usize) -> Self {
        Self { bits, dimensions }
    }
}
// ============================================================================
// Helper functions for scalar quantization distance
// ============================================================================
/// Scalar fallback for scalar quantization distance (sum of squared differences).
///
/// Processes four elements per iteration for better instruction-level
/// parallelism, then finishes the tail one element at a time. Indexes `b`
/// by `a`'s length, so it panics if `b` is shorter (callers pass
/// equal-length buffers).
fn scalar_distance_scalar(a: &[u8], b: &[u8]) -> f32 {
    let mut acc = 0i32;
    // Largest multiple of 4 that fits in the input.
    let unrolled = a.len() / 4 * 4;
    let mut i = 0;
    while i < unrolled {
        let e0 = a[i] as i32 - b[i] as i32;
        let e1 = a[i + 1] as i32 - b[i + 1] as i32;
        let e2 = a[i + 2] as i32 - b[i + 2] as i32;
        let e3 = a[i + 3] as i32 - b[i + 3] as i32;
        acc += e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
        i += 4;
    }
    // Tail elements beyond the unrolled region.
    while i < a.len() {
        let e = a[i] as i32 - b[i] as i32;
        acc += e * e;
        i += 1;
    }
    acc as f32
}
/// NEON SIMD distance for scalar quantization
///
/// Returns the sum of squared byte differences as f32, processing 8 u8
/// lanes per iteration (widened to i16, squared into i32 accumulators).
///
/// # Safety
/// Caller must ensure a.len() == b.len(); the remainder loop reads `b`
/// with `get_unchecked` up to `a.len()`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn scalar_distance_neon(a: &[u8], b: &[u8]) -> f32 {
    use std::arch::aarch64::*;
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let mut sum = vdupq_n_s32(0);
    // Process 8 bytes at a time
    let chunks = len / 8;
    let mut idx = 0usize;
    for _ in 0..chunks {
        // Load 8 u8 values
        let va = vld1_u8(a_ptr.add(idx));
        let vb = vld1_u8(b_ptr.add(idx));
        // Zero-extend u8 to u16
        let va_u16 = vmovl_u8(va);
        let vb_u16 = vmovl_u8(vb);
        // Convert to signed for subtraction
        let va_s16 = vreinterpretq_s16_u16(va_u16);
        let vb_s16 = vreinterpretq_s16_u16(vb_u16);
        // Compute difference
        let diff = vsubq_s16(va_s16, vb_s16);
        // Square and accumulate
        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);
        idx += 8;
    }
    // Horizontal add of the four i32 accumulator lanes.
    let mut total = vaddvq_s32(sum);
    // Handle remainder with bounds-check elimination
    for i in (chunks * 8)..len {
        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
        total += diff * diff;
    }
    total as f32
}
/// AVX2 SIMD distance for scalar quantization
///
/// Computes the sum of squared byte differences, 16 lanes at a time:
/// u8 values are zero-extended to i16, subtracted, then `madd` squares and
/// pairwise-sums into i32 lanes which are horizontally reduced at the end.
///
/// NOTE(review): like the NEON path, the i32 lane accumulator can overflow
/// for extremely long vectors of maximal difference — confirm input bounds.
///
/// # Safety
/// Caller must ensure a.len() == b.len() and that AVX2 is available.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scalar_distance_avx2(a: &[u8], b: &[u8]) -> f32 {
    use std::arch::x86_64::*;
    let len = a.len();
    let mut sum = _mm256_setzero_si256();
    // Process 16 bytes at a time
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        // Load 16 u8 values
        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);
        // Zero-extend u8 to i16 (16 bytes -> 16 i16 lanes in a 256-bit reg)
        let va_lo = _mm256_cvtepu8_epi16(va);
        let vb_lo = _mm256_cvtepu8_epi16(vb);
        // Compute difference
        let diff = _mm256_sub_epi16(va_lo, vb_lo);
        // Square (multiply i16 * i16 -> i32, adjacent pairs summed)
        let prod = _mm256_madd_epi16(diff, diff);
        // Accumulate
        sum = _mm256_add_epi32(sum, prod);
    }
    // Horizontal sum: fold 256 -> 128 bits, then reduce 4 lanes via shuffles
    // (lane 0 of `final_sum` holds the total of all 8 original i32 lanes).
    let sum_lo = _mm256_castsi256_si128(sum);
    let sum_hi = _mm256_extracti128_si256(sum, 1);
    let sum_128 = _mm_add_epi32(sum_lo, sum_hi);
    let shuffle = _mm_shuffle_epi32(sum_128, 0b10_11_00_01);
    let sum_64 = _mm_add_epi32(sum_128, shuffle);
    let shuffle2 = _mm_shuffle_epi32(sum_64, 0b00_00_10_10);
    let final_sum = _mm_add_epi32(sum_64, shuffle2);
    let mut total = _mm_cvtsi128_si32(final_sum);
    // Handle remainder
    for i in (chunks * 16)..len {
        let diff = (a[i] as i32) - (b[i] as i32);
        total += diff * diff;
    }
    total as f32
}
// Helper functions
/// Squared Euclidean (L2) distance between two slices.
///
/// Pairs elements positionally; if the slices differ in length the extra
/// elements of the longer slice are ignored.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        let delta = lhs - rhs;
        total += delta * delta;
    }
    total
}
/// Lloyd's k-means clustering over `vectors`, returning up to `k` centroids.
///
/// Centroids are initialized by sampling (without replacement) from the
/// input, so fewer than `k` centroids are returned when
/// `vectors.len() < k`. Empty clusters keep their previous centroid.
///
/// Returns an empty vec when `k == 0` or `vectors` is empty.
fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
    use rand::seq::SliceRandom;
    use rand::thread_rng;
    let mut rng = thread_rng();
    // Initialize centroids randomly (sampling without replacement)
    let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
    if centroids.is_empty() {
        // Nothing to refine; avoids indexing an empty assignment table below.
        return centroids;
    }
    for _ in 0..iterations {
        // Assignment step: bucket each vector under its nearest centroid.
        // Bucket count matches the actual centroid count (may be < k).
        let mut assignments = vec![Vec::new(); centroids.len()];
        for vector in vectors {
            let nearest = centroids
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let dist_a = euclidean_squared(vector, a);
                    let dist_b = euclidean_squared(vector, b);
                    // total_cmp is NaN-safe; the previous
                    // partial_cmp().unwrap() panicked on NaN distances.
                    dist_a.total_cmp(&dist_b)
                })
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            assignments[nearest].push(vector.clone());
        }
        // Update step: move each centroid to the mean of its cluster.
        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
            if !assigned.is_empty() {
                let dim = centroid.len();
                *centroid = vec![0.0; dim];
                for vector in assigned {
                    for (i, &v) in vector.iter().enumerate() {
                        centroid[i] += v;
                    }
                }
                let count = assigned.len() as f32;
                for v in centroid.iter_mut() {
                    *v /= count;
                }
            }
        }
    }
    centroids
}
// =============================================================================
// SIMD-Optimized Distance Calculations for Quantized Vectors
// =============================================================================
// NOTE: scalar_distance_scalar, scalar_distance_neon and scalar_distance_avx2
// are already defined earlier in this file; this section reuses those
// existing implementations for consistency rather than duplicating them.
/// SIMD-optimized hamming distance using popcnt
///
/// Counts differing bits by XOR-ing 8-byte words and applying the hardware
/// `popcnt` instruction; leftover bytes use the scalar `count_ones`.
/// If the slices differ in length, comparison silently stops at the
/// shorter one (both loops pair elements with `zip`).
///
/// # Safety
/// Caller must ensure the `popcnt` CPU feature is available
/// (required by `#[target_feature]`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
unsafe fn hamming_distance_simd_x86(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::x86_64::*;
    let mut distance = 0u64;
    // Process 8 bytes at a time using u64 with hardware popcnt
    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    let remainder_a = chunks_a.remainder();
    let remainder_b = chunks_b.remainder();
    for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
        let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
        let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
        distance += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
    }
    // Handle remainder
    for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
        distance += (a_byte ^ b_byte).count_ones() as u64;
    }
    distance as u32
}
/// NEON-optimized hamming distance for ARM64
///
/// XORs 16 bytes at a time, popcounts each byte with `vcntq_u8`, widens the
/// per-byte counts (u8 -> u16 -> u32) and accumulates into 32-bit lanes.
///
/// BUGFIX: the previous version accumulated into a `uint8x16_t` and reduced
/// with `vaddvq_u8` (which returns u8), so any pair of inputs with 256 or
/// more differing bits wrapped modulo 256 (e.g. two all-complement 32-byte
/// codes returned 0 instead of 256). Widening before accumulation makes the
/// count exact for all practical vector lengths.
///
/// # Safety
/// Caller must ensure a.len() == b.len()
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::*;
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks = len / 16;
    let mut idx = 0usize;
    // 32-bit lane accumulator: cannot wrap for any realistic input length.
    let mut sum = vdupq_n_u32(0);
    for _ in 0..chunks {
        // Load 16 bytes
        let a_vec = vld1q_u8(a_ptr.add(idx));
        let b_vec = vld1q_u8(b_ptr.add(idx));
        // XOR and count bits using vcntq_u8 (per-byte population count, <= 8)
        let xor_result = veorq_u8(a_vec, b_vec);
        let bits = vcntq_u8(xor_result);
        // Pairwise widen u8x16 -> u16x8, then add-accumulate into u32x4.
        sum = vpadalq_u16(sum, vpaddlq_u8(bits));
        idx += 16;
    }
    // Horizontal sum of the four u32 lanes
    let mut total = vaddvq_u32(sum);
    // Handle remainder with bounds-check elimination
    for i in (chunks * 16)..len {
        total += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
    }
    total
}
#[cfg(test)]
mod tests {
    // Unit tests for the scalar (int8), int4 and binary quantizers:
    // round-trip accuracy, distance symmetry, edge cases and
    // compression-ratio constants.
    use super::*;
    #[test]
    fn test_scalar_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();
        // Check approximate reconstruction
        for (orig, recon) in vector.iter().zip(&reconstructed) {
            assert!((orig - recon).abs() < 0.1);
        }
    }
    #[test]
    fn test_binary_quantization() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5];
        let quantized = BinaryQuantized::quantize(&vector);
        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
    }
    #[test]
    fn test_binary_distance() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0];
        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);
        let dist = q1.distance(&q2);
        assert_eq!(dist, 2.0); // 2 bits differ
    }
    #[test]
    fn test_scalar_quantization_roundtrip() {
        // Test that quantize -> reconstruct produces values close to original
        let test_vectors = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];
        for vector in test_vectors {
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();
            assert_eq!(vector.len(), reconstructed.len());
            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
                // With 8-bit quantization, max error is roughly (max-min)/255
                let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
                let max_error = (max - min) / 255.0 * 2.0; // Allow 2x for rounding
                assert!(
                    (orig - recon).abs() < max_error,
                    "Roundtrip error too large: orig={}, recon={}, error={}",
                    orig,
                    recon,
                    (orig - recon).abs()
                );
            }
        }
    }
    #[test]
    fn test_scalar_distance_symmetry() {
        // Test that distance(a, b) == distance(b, a)
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);
        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);
        // Distance should be symmetric (within floating point precision)
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance is not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }
    #[test]
    fn test_scalar_distance_different_scales() {
        // Test distance calculation with vectors that have different scales
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // range: 4.0
        let v2 = vec![10.0, 20.0, 30.0, 40.0, 50.0]; // range: 40.0
        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);
        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);
        // With average scaling, symmetry should be maintained
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }
    #[test]
    fn test_scalar_quantization_edge_cases() {
        // Test with all same values (zero range must not divide by zero)
        let same_values = vec![5.0, 5.0, 5.0, 5.0];
        let quantized = ScalarQuantized::quantize(&same_values);
        let reconstructed = quantized.reconstruct();
        for (orig, recon) in same_values.iter().zip(reconstructed.iter()) {
            assert!((orig - recon).abs() < 0.01);
        }
        // Test with extreme ranges (only checks it doesn't panic or drop elements)
        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let quantized = ScalarQuantized::quantize(&extreme);
        let reconstructed = quantized.reconstruct();
        assert_eq!(extreme.len(), reconstructed.len());
    }
    #[test]
    fn test_binary_distance_symmetry() {
        // Test that binary distance is symmetric
        let v1 = vec![1.0, -1.0, 1.0, -1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0];
        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);
        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);
        assert_eq!(
            dist_ab, dist_ba,
            "Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab, dist_ba
        );
    }
    #[test]
    fn test_int4_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = Int4Quantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();
        assert_eq!(quantized.dimensions, 5);
        // 5 dimensions = 3 bytes (2 per byte, last byte has 1)
        assert_eq!(quantized.data.len(), 3);
        // Check approximate reconstruction
        for (orig, recon) in vector.iter().zip(&reconstructed) {
            // With 4-bit quantization, max error is roughly (max-min)/15
            let max_error = (5.0 - 1.0) / 15.0 * 2.0;
            assert!(
                (orig - recon).abs() < max_error,
                "Int4 roundtrip error too large: orig={}, recon={}",
                orig,
                recon
            );
        }
    }
    #[test]
    fn test_int4_distance() {
        // Use vectors with different quantized patterns
        // v1 spans [0.0, 15.0] -> quantizes to [0, 1, 2, ..., 15] (linear mapping)
        // v2 spans [0.0, 15.0] but with different distribution
        let v1 = vec![0.0, 5.0, 10.0, 15.0];
        let v2 = vec![0.0, 3.0, 12.0, 15.0]; // Different middle values
        let q1 = Int4Quantized::quantize(&v1);
        let q2 = Int4Quantized::quantize(&v2);
        let dist = q1.distance(&q2);
        // The quantized values differ in the middle, so distance should be positive
        assert!(
            dist > 0.0,
            "Distance should be positive, got {}. q1.data={:?}, q2.data={:?}",
            dist,
            q1.data,
            q2.data
        );
    }
    #[test]
    fn test_int4_distance_symmetry() {
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let q1 = Int4Quantized::quantize(&v1);
        let q2 = Int4Quantized::quantize(&v2);
        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Int4 distance not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }
    #[test]
    fn test_int4_compression_ratio() {
        assert_eq!(Int4Quantized::compression_ratio(), 8.0);
    }
    #[test]
    fn test_binary_fast_hamming() {
        // Test fast hamming distance with various sizes
        // (9 bytes exercises both the word loop and the remainder path)
        let a = vec![0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xAA];
        let b = vec![0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x55];
        let distance = BinaryQuantized::hamming_distance_fast(&a, &b);
        // All bits differ: 9 bytes * 8 bits = 72 bits
        assert_eq!(distance, 72);
    }
    #[test]
    fn test_binary_similarity() {
        let v1 = vec![1.0; 8]; // All positive
        let v2 = vec![1.0; 8]; // Same
        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);
        let sim = q1.similarity(&q2);
        assert!(
            (sim - 1.0).abs() < 0.001,
            "Same vectors should have similarity 1.0"
        );
    }
    #[test]
    fn test_binary_compression_ratio() {
        assert_eq!(BinaryQuantized::compression_ratio(), 32.0);
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,446 @@
//! Storage layer with redb for metadata and memory-mapped vectors
//!
//! This module is only available when the "storage" feature is enabled.
//! For WASM builds, use the in-memory storage backend instead.
#[cfg(feature = "storage")]
use crate::error::{Result, RuvectorError};
#[cfg(feature = "storage")]
use crate::types::{DbOptions, VectorEntry, VectorId};
#[cfg(feature = "storage")]
use bincode::config;
#[cfg(feature = "storage")]
use once_cell::sync::Lazy;
#[cfg(feature = "storage")]
use parking_lot::Mutex;
#[cfg(feature = "storage")]
use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
#[cfg(feature = "storage")]
use serde_json;
#[cfg(feature = "storage")]
use std::collections::HashMap;
#[cfg(feature = "storage")]
use std::path::{Path, PathBuf};
#[cfg(feature = "storage")]
use std::sync::Arc;
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
// CONSISTENCY FIX: the items below reference types that are only imported
// under the "storage" feature, so they carry the same cfg gate as
// VECTORS_TABLE instead of being left unconditional.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
#[cfg(feature = "storage")]
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");
/// Key used to store database configuration in CONFIG_TABLE
#[cfg(feature = "storage")]
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";
// Global database connection pool to allow multiple VectorDB instances
// to share the same underlying database file
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
/// Storage backend for vector database
///
/// Wraps a (possibly pooled/shared) redb database handle together with the
/// fixed dimensionality that every stored vector must match.
pub struct VectorStorage {
    // Shared handle; multiple VectorStorage instances opened on the same
    // path reuse one Database via DB_POOL.
    db: Arc<Database>,
    // Expected length of every inserted vector.
    dimensions: usize,
}
impl VectorStorage {
    /// Create or open a vector storage at the given path
    ///
    /// This method uses a global connection pool to allow multiple VectorDB
    /// instances to share the same underlying database file, fixing the
    /// "Database already open. Cannot acquire lock" error.
    ///
    /// NOTE(review): when an existing pooled database is reused, the
    /// `dimensions` argument is accepted as-is and is not validated against
    /// the dimensions the database was originally created with — confirm
    /// callers always pass a consistent value.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        // SECURITY: Validate path to prevent directory traversal attacks
        let path_ref = path.as_ref();
        // Create parent directories if they don't exist (needed for canonicalize)
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }
        // Convert to absolute path first, then validate
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };
        // SECURITY: Check for path traversal attempts (e.g., "../../../etc/passwd")
        // Only reject paths that contain ".." components trying to escape
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") {
            // Verify the resolved path doesn't escape intended boundaries
            // For absolute paths, we allow them as-is (user explicitly specified)
            // For relative paths with "..", check they don't escape cwd
            if !path_ref.is_absolute() {
                if let Ok(cwd) = std::env::current_dir() {
                    // Normalize the path by resolving .. components
                    let mut normalized = cwd.clone();
                    for component in path_ref.components() {
                        match component {
                            std::path::Component::ParentDir => {
                                // pop() returning false means we walked above
                                // the filesystem root; starts_with guards
                                // against escaping the working directory.
                                if !normalized.pop() || !normalized.starts_with(&cwd) {
                                    return Err(RuvectorError::InvalidPath(
                                        "Path traversal attempt detected".to_string(),
                                    ));
                                }
                            }
                            std::path::Component::Normal(c) => normalized.push(c),
                            _ => {}
                        }
                    }
                }
            }
        }
        // Check if we already have a Database instance for this path
        let db = {
            let mut pool = DB_POOL.lock();
            if let Some(existing_db) = pool.get(&path_buf) {
                // Reuse existing database connection
                Arc::clone(existing_db)
            } else {
                // Create new database and add to pool
                let new_db = Arc::new(Database::create(&path_buf)?);
                // Initialize tables (scope ends the table borrows before commit)
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;
                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };
        Ok(Self { db, dimensions })
    }
    /// Insert a vector entry
    ///
    /// Generates a UUID when the entry carries no id. An insert under an
    /// existing id replaces the stored vector (redb `insert` semantics).
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }
        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            // Serialize vector data
            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
            table.insert(id.as_str(), vector_data.as_slice())?;
            // Store metadata if present
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;
        Ok(id)
    }
    /// Insert multiple vectors in a batch
    ///
    /// Runs in a single write transaction: if any entry fails dimension
    /// validation or serialization, the transaction is dropped and nothing
    /// from the batch is persisted.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            for entry in entries {
                if entry.vector.len() != self.dimensions {
                    return Err(RuvectorError::DimensionMismatch {
                        expected: self.dimensions,
                        actual: entry.vector.len(),
                    });
                }
                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
                // Serialize and insert vector
                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;
                // Insert metadata if present
                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }
                ids.push(id);
            }
        }
        write_txn.commit()?;
        Ok(ids)
    }
    /// Get a vector by ID
    ///
    /// Returns `Ok(None)` when the id is unknown; metadata is attached when
    /// present in the metadata table.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };
        // decode_from_slice returns (value, bytes_read); the byte count is unused
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
        // Try to get metadata
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };
        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }
    /// Delete a vector by ID
    ///
    /// Returns `true` if a vector was removed; associated metadata is
    /// removed as well (its absence is not an error).
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }
        write_txn.commit()?;
        Ok(deleted)
    }
    /// Get the number of vectors stored
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }
    /// Check if storage is empty
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }
    /// Get all vector IDs
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }
        Ok(ids)
    }
    /// Save database configuration to persistent storage
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;
        Ok(())
    }
    /// Load database configuration from persistent storage
    ///
    /// Returns `Ok(None)` when no configuration was ever saved, or when the
    /// config table does not exist (databases created by older versions).
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;
        // Try to open config table - may not exist in older databases
        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };
        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };
        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
        Ok(Some(config))
    }
    /// Get the stored dimensions
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
// Add uuid dependency
use uuid;
#[cfg(test)]
mod tests {
    // Integration-style tests against a temporary on-disk redb database.
    use super::*;
    use tempfile::tempdir;
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };
        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");
        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);
        Ok(())
    }
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];
        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);
        Ok(())
    }
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };
        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);
        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);
        Ok(())
    }
    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        // This test verifies the fix for the database locking bug
        // Multiple VectorStorage instances should be able to share the same database file
        // (they share one pooled redb Database handle via DB_POOL).
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");
        // Create first instance
        let storage1 = VectorStorage::new(&db_path, 3)?;
        // Insert data with first instance
        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;
        // Create second instance with SAME path - this should NOT fail
        let storage2 = VectorStorage::new(&db_path, 3)?;
        // Both instances should see the same data
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);
        // Insert with second instance
        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;
        // Both instances should see both records
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);
        // Verify data integrity
        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());
        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());
        Ok(())
    }
}

View File

@@ -0,0 +1,79 @@
//! Storage compatibility layer
//!
//! This module provides a unified interface that works with both
//! file-based (redb) and in-memory storage backends.
use crate::error::Result;
use crate::types::{VectorEntry, VectorId};
#[cfg(feature = "storage")]
pub use crate::storage::VectorStorage;
#[cfg(not(feature = "storage"))]
pub use crate::storage_memory::MemoryStorage as VectorStorage;
/// Unified storage trait
///
/// Abstracts over the file-backed (redb) and in-memory storage backends so
/// higher layers can be written against a single interface.
pub trait StorageBackend {
    /// Insert a single entry, returning its (possibly generated) id.
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
    /// Insert several entries; returns ids in input order.
    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
    /// Fetch an entry by id, or `None` when absent.
    fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
    /// Remove an entry by id; `true` if something was removed.
    fn delete(&self, id: &str) -> Result<bool>;
    /// Number of stored vectors.
    fn len(&self) -> Result<usize>;
    /// Whether the store holds no vectors.
    fn is_empty(&self) -> Result<bool>;
}
// Implement trait for redb-based storage
//
// Pure delegation: each trait method forwards to the inherent method of the
// same name (inherent methods take precedence, so there is no recursion).
#[cfg(feature = "storage")]
impl StorageBackend for crate::storage::VectorStorage {
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.insert(entry)
    }
    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        self.insert_batch(entries)
    }
    fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.get(id)
    }
    fn delete(&self, id: &str) -> Result<bool> {
        self.delete(id)
    }
    fn len(&self) -> Result<usize> {
        self.len()
    }
    fn is_empty(&self) -> Result<bool> {
        self.is_empty()
    }
}
// Implement trait for memory storage
//
// Pure delegation to the inherent methods on MemoryStorage; used when the
// "storage" feature (and thus the redb backend) is disabled, e.g. on WASM.
#[cfg(not(feature = "storage"))]
impl StorageBackend for crate::storage_memory::MemoryStorage {
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.insert(entry)
    }
    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        self.insert_batch(entries)
    }
    fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.get(id)
    }
    fn delete(&self, id: &str) -> Result<bool> {
        self.delete(id)
    }
    fn len(&self) -> Result<usize> {
        self.len()
    }
    fn is_empty(&self) -> Result<bool> {
        self.is_empty()
    }
}

View File

@@ -0,0 +1,257 @@
//! In-memory storage backend for WASM and testing
//!
//! This storage implementation doesn't require file system access,
//! making it suitable for WebAssembly environments.
use crate::error::{Result, RuvectorError};
use crate::types::{VectorEntry, VectorId};
use dashmap::DashMap;
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, Ordering};
/// In-memory storage backend using DashMap for thread-safe concurrent access
pub struct MemoryStorage {
    // Vector data keyed by id.
    vectors: DashMap<String, Vec<f32>>,
    // Optional per-vector metadata stored as a JSON object, keyed by id.
    metadata: DashMap<String, JsonValue>,
    // Expected dimensionality of every stored vector.
    dimensions: usize,
    // Monotonic counter used to generate "vec_{n}" ids.
    counter: AtomicU64,
}
impl MemoryStorage {
    /// Create a new in-memory storage for vectors of `dimensions` length.
    pub fn new(dimensions: usize) -> Result<Self> {
        Ok(Self {
            vectors: DashMap::new(),
            metadata: DashMap::new(),
            dimensions,
            counter: AtomicU64::new(0),
        })
    }
    /// Generate a new unique ID (monotonic `vec_{n}`)
    fn generate_id(&self) -> String {
        let id = self.counter.fetch_add(1, Ordering::SeqCst);
        format!("vec_{}", id)
    }
    /// Validate that an entry matches the configured dimensionality.
    fn check_dimensions(&self, entry: &VectorEntry) -> Result<()> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }
        Ok(())
    }
    /// Store a pre-validated entry (vector + optional metadata), returning
    /// its id. Shared by `insert` and `insert_batch`.
    fn store_entry(&self, entry: &VectorEntry) -> VectorId {
        let id = entry.id.clone().unwrap_or_else(|| self.generate_id());
        self.vectors.insert(id.clone(), entry.vector.clone());
        if let Some(metadata) = &entry.metadata {
            self.metadata.insert(
                id.clone(),
                JsonValue::Object(
                    metadata
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone()))
                        .collect(),
                ),
            );
        }
        id
    }
    /// Insert a vector entry
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.check_dimensions(entry)?;
        Ok(self.store_entry(entry))
    }
    /// Insert multiple vectors in a batch
    ///
    /// All entries are validated up front, so a dimension mismatch no longer
    /// leaves the earlier entries of the batch inserted — matching the
    /// all-or-nothing behaviour of the transactional file-backed storage.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        for entry in entries {
            self.check_dimensions(entry)?;
        }
        Ok(entries.iter().map(|entry| self.store_entry(entry)).collect())
    }
    /// Get a vector by ID
    ///
    /// Returns `Ok(None)` when the id is unknown. Metadata is only attached
    /// when the stored value is a JSON object.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        if let Some(vector_ref) = self.vectors.get(id) {
            let vector = vector_ref.clone();
            let metadata = self.metadata.get(id).and_then(|m| {
                if let serde_json::Value::Object(map) = m.value() {
                    Some(map.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
                } else {
                    None
                }
            });
            Ok(Some(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata,
            }))
        } else {
            Ok(None)
        }
    }
    /// Delete a vector by ID
    ///
    /// Returns `true` when a vector was removed; metadata (if any) is
    /// removed alongside it.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let vector_removed = self.vectors.remove(id).is_some();
        self.metadata.remove(id);
        Ok(vector_removed)
    }
    /// Get the number of vectors stored
    pub fn len(&self) -> Result<usize> {
        Ok(self.vectors.len())
    }
    /// Check if the storage is empty
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.vectors.is_empty())
    }
    /// Get all vector IDs (for iteration)
    pub fn keys(&self) -> Vec<String> {
        self.vectors
            .iter()
            .map(|entry| entry.key().clone())
            .collect()
    }
    /// Get all vector IDs (alias for keys, for API compatibility with VectorStorage)
    pub fn all_ids(&self) -> Result<Vec<String>> {
        Ok(self.keys())
    }
    /// Clear all data
    pub fn clear(&self) -> Result<()> {
        self.vectors.clear();
        self.metadata.clear();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    #[test]
    fn test_insert_and_get() {
        let storage = MemoryStorage::new(128).unwrap();
        let entry = VectorEntry {
            id: Some("test_1".to_string()),
            vector: vec![0.1; 128],
            // FIX: `VectorEntry::metadata` is `Option<HashMap<String, Value>>`,
            // so build the map explicitly instead of passing a bare JSON object.
            metadata: Some(HashMap::from([("key".to_string(), json!("value"))])),
        };
        let id = storage.insert(&entry).unwrap();
        assert_eq!(id, "test_1");
        let retrieved = storage.get("test_1").unwrap().unwrap();
        assert_eq!(retrieved.vector.len(), 128);
        assert!(retrieved.metadata.is_some());
    }
    #[test]
    fn test_batch_insert() {
        let storage = MemoryStorage::new(64).unwrap();
        let entries: Vec<_> = (0..10)
            .map(|i| VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: vec![i as f32; 64],
                metadata: None,
            })
            .collect();
        let ids = storage.insert_batch(&entries).unwrap();
        assert_eq!(ids.len(), 10);
        assert_eq!(storage.len().unwrap(), 10);
    }
    #[test]
    fn test_delete() {
        let storage = MemoryStorage::new(32).unwrap();
        let entry = VectorEntry {
            id: Some("delete_me".to_string()),
            vector: vec![1.0; 32],
            metadata: None,
        };
        storage.insert(&entry).unwrap();
        assert_eq!(storage.len().unwrap(), 1);
        let deleted = storage.delete("delete_me").unwrap();
        assert!(deleted);
        assert_eq!(storage.len().unwrap(), 0);
    }
    #[test]
    fn test_auto_id_generation() {
        let storage = MemoryStorage::new(16).unwrap();
        let entry = VectorEntry {
            id: None,
            vector: vec![0.5; 16],
            metadata: None,
        };
        let id1 = storage.insert(&entry).unwrap();
        let id2 = storage.insert(&entry).unwrap();
        assert_ne!(id1, id2);
        assert!(id1.starts_with("vec_"));
        assert!(id2.starts_with("vec_"));
    }
    #[test]
    fn test_dimension_mismatch() {
        let storage = MemoryStorage::new(128).unwrap();
        let entry = VectorEntry {
            id: Some("bad".to_string()),
            vector: vec![0.1; 64], // Wrong dimension
            metadata: None,
        };
        let result = storage.insert(&entry);
        assert!(result.is_err());
        if let Err(RuvectorError::DimensionMismatch { expected, actual }) = result {
            assert_eq!(expected, 128);
            assert_eq!(actual, 64);
        } else {
            panic!("Expected DimensionMismatch error");
        }
    }
}

View File

@@ -0,0 +1,126 @@
//! Core types and data structures
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Unique identifier for vectors
///
/// Ids are either caller-supplied or generated by the storage layer
/// (UUID for file storage, `vec_{n}` for the in-memory backend).
pub type VectorId = String;
/// Distance metric for similarity calculation
///
/// All variants are interpreted as distances: lower is more similar.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance
    Euclidean,
    /// Cosine similarity (converted to distance)
    Cosine,
    /// Dot product (converted to distance for maximization)
    DotProduct,
    /// Manhattan (L1) distance
    Manhattan,
}
/// Vector entry with metadata
///
/// The unit of insertion and retrieval for every storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    /// Optional ID (auto-generated if not provided)
    pub id: Option<VectorId>,
    /// Vector data (its length must match the database's configured dimensions)
    pub vector: Vec<f32>,
    /// Optional metadata as arbitrary JSON values keyed by name
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Search query parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
    /// Query vector (same dimensionality as the stored vectors)
    pub vector: Vec<f32>,
    /// Number of results to return (top-k)
    pub k: usize,
    /// Optional metadata filters
    pub filter: Option<HashMap<String, serde_json::Value>>,
    /// Optional ef_search parameter for HNSW (overrides default)
    pub ef_search: Option<usize>,
}
/// Search result with similarity score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Vector ID
    pub id: VectorId,
    /// Distance/similarity score (lower is better for distance metrics)
    pub score: f32,
    /// Vector data (optional; populated only when the caller requests it)
    pub vector: Option<Vec<f32>>,
    /// Metadata (optional)
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Database configuration options
///
/// For persistent databases these options are saved on first creation and
/// reloaded on reopen, so the stored values take precedence over the ones
/// supplied later (except the storage path itself).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbOptions {
    /// Vector dimensions
    pub dimensions: usize,
    /// Distance metric
    pub distance_metric: DistanceMetric,
    /// Storage path
    pub storage_path: String,
    /// HNSW configuration; `None` selects the brute-force flat index instead
    pub hnsw_config: Option<HnswConfig>,
    /// Quantization configuration; `None` keeps full-precision vectors
    pub quantization: Option<QuantizationConfig>,
}
/// HNSW index configuration
///
/// Parameter names follow the standard HNSW literature (M, efConstruction,
/// efSearch). Larger values generally trade memory and build time for recall.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Number of connections per layer (M)
    pub m: usize,
    /// Size of dynamic candidate list during construction (efConstruction)
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search (efSearch); a query may
    /// override this per-search via `SearchQuery::ef_search`
    pub ef_search: usize,
    /// Maximum number of elements
    pub max_elements: usize,
}
impl Default for HnswConfig {
fn default() -> Self {
Self {
m: 32,
ef_construction: 200,
ef_search: 100,
max_elements: 10_000_000,
}
}
}
/// Quantization configuration
///
/// Controls how vectors are compressed for storage. `Scalar` is the default
/// used by `DbOptions::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// No quantization (full precision)
    None,
    /// Scalar quantization to int8 (4x compression)
    Scalar,
    /// Product quantization
    Product {
        /// Number of subspaces
        subspaces: usize,
        /// Codebook size (typically 256)
        k: usize,
    },
    /// Binary quantization (32x compression)
    Binary,
}
impl Default for DbOptions {
fn default() -> Self {
Self {
dimensions: 384,
distance_metric: DistanceMetric::Cosine,
storage_path: "./ruvector.db".to_string(),
hnsw_config: Some(HnswConfig::default()),
quantization: Some(QuantizationConfig::Scalar),
}
}
}

View File

@@ -0,0 +1,391 @@
//! Main VectorDB interface
use crate::error::Result;
use crate::index::flat::FlatIndex;
#[cfg(feature = "hnsw")]
use crate::index::hnsw::HnswIndex;
use crate::index::VectorIndex;
use crate::types::*;
use parking_lot::RwLock;
use std::sync::Arc;
// Import appropriate storage backend based on features
#[cfg(feature = "storage")]
use crate::storage::VectorStorage;
#[cfg(not(feature = "storage"))]
use crate::storage_memory::MemoryStorage as VectorStorage;
/// Main vector database
///
/// Pairs a storage backend (persistent or in-memory, selected by the
/// `storage` feature via the imports above) with an ANN index kept behind a
/// read-write lock for concurrent searches.
pub struct VectorDB {
    /// Backend holding vector data and metadata
    storage: Arc<VectorStorage>,
    /// Search index (HNSW or flat); writers take the lock on insert/delete
    index: Arc<RwLock<Box<dyn VectorIndex>>>,
    /// Effective configuration (may come from persisted config on reopen)
    options: DbOptions,
}
impl VectorDB {
    /// Create a new vector database with the given options
    ///
    /// If a storage path is provided and contains persisted vectors,
    /// the HNSW index will be automatically rebuilt from storage.
    /// If opening an existing database, the stored configuration (dimensions,
    /// distance metric, etc.) will be used instead of the provided options.
    #[allow(unused_mut)] // `options` is mutated only when feature = "storage"
    pub fn new(mut options: DbOptions) -> Result<Self> {
        #[cfg(feature = "storage")]
        let storage = {
            // First, try to load existing configuration from the database
            // We create a temporary storage to check for config
            let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;
            let stored_config = temp_storage.load_config()?;
            if let Some(config) = stored_config {
                // Existing database - use stored configuration
                tracing::info!(
                    "Loading existing database with {} dimensions",
                    config.dimensions
                );
                options = DbOptions {
                    // Keep the provided storage path (may have changed)
                    storage_path: options.storage_path.clone(),
                    // Use stored configuration for everything else
                    dimensions: config.dimensions,
                    distance_metric: config.distance_metric,
                    hnsw_config: config.hnsw_config,
                    quantization: config.quantization,
                };
                // Recreate storage with correct dimensions
                Arc::new(VectorStorage::new(
                    &options.storage_path,
                    options.dimensions,
                )?)
            } else {
                // New database - save the configuration
                tracing::info!(
                    "Creating new database with {} dimensions",
                    options.dimensions
                );
                temp_storage.save_config(&options)?;
                Arc::new(temp_storage)
            }
        };
        #[cfg(not(feature = "storage"))]
        let storage = Arc::new(VectorStorage::new(options.dimensions)?);
        // Choose index based on configuration and available features:
        // an HNSW config requests the ANN index, otherwise (or when the
        // `hnsw` feature is off) a brute-force flat index is used.
        #[allow(unused_mut)] // `index` is mutated only when feature = "storage"
        let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
            #[cfg(feature = "hnsw")]
            {
                Box::new(HnswIndex::new(
                    options.dimensions,
                    options.distance_metric,
                    hnsw_config.clone(),
                )?)
            }
            #[cfg(not(feature = "hnsw"))]
            {
                // Fall back to flat index if HNSW is not available
                tracing::warn!("HNSW requested but not available (WASM build), using flat index");
                Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
            }
        } else {
            Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
        };
        // Rebuild index from persisted vectors if storage is not empty
        // This fixes the bug where search() returns empty results after restart
        #[cfg(feature = "storage")]
        {
            let stored_ids = storage.all_ids()?;
            if !stored_ids.is_empty() {
                tracing::info!(
                    "Rebuilding index from {} persisted vectors",
                    stored_ids.len()
                );
                // Batch load all vectors for efficient index rebuilding
                let mut entries = Vec::with_capacity(stored_ids.len());
                for id in stored_ids {
                    if let Some(entry) = storage.get(&id)? {
                        entries.push((id, entry.vector));
                    }
                }
                // Add all vectors to index in batch for better performance
                index.add_batch(entries)?;
                tracing::info!("Index rebuilt successfully");
            }
        }
        Ok(Self {
            storage,
            index: Arc::new(RwLock::new(index)),
            options,
        })
    }
    /// Create with default options, overriding only the dimensionality.
    pub fn with_dimensions(dimensions: usize) -> Result<Self> {
        let options = DbOptions {
            dimensions,
            ..DbOptions::default()
        };
        Self::new(options)
    }
    /// Insert a vector entry, returning its (possibly auto-generated) ID.
    ///
    /// Writes to storage first, then to the index.
    /// NOTE(review): if the index add fails after the storage write, the
    /// vector is persisted but unsearchable until reopen — confirm intended.
    pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
        let id = self.storage.insert(&entry)?;
        // Add to index
        let mut index = self.index.write();
        index.add(id.clone(), entry.vector)?;
        Ok(id)
    }
    /// Insert multiple vectors in a batch; IDs are returned in input order.
    pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
        let ids = self.storage.insert_batch(&entries)?;
        // Add to index
        let mut index = self.index.write();
        let index_entries: Vec<_> = ids
            .iter()
            .zip(entries.iter())
            .map(|(id, entry)| (id.clone(), entry.vector.clone()))
            .collect();
        index.add_batch(index_entries)?;
        Ok(ids)
    }
    /// Search for similar vectors
    ///
    /// Retrieves the top-k nearest neighbors from the index, enriches them
    /// with stored vectors/metadata, then applies any metadata filter.
    /// NOTE: because filtering happens after top-k retrieval, fewer than
    /// `k` results may be returned when a filter excludes matches.
    pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
        let index = self.index.read();
        let mut results = index.search(&query.vector, query.k)?;
        // Enrich results with full data if needed
        for result in &mut results {
            if let Ok(Some(entry)) = self.storage.get(&result.id) {
                result.vector = Some(entry.vector);
                result.metadata = entry.metadata;
            }
        }
        // Apply metadata filters if specified (exact equality on every key;
        // results with no metadata at all are dropped)
        if let Some(filter) = &query.filter {
            results.retain(|r| {
                if let Some(metadata) = &r.metadata {
                    filter
                        .iter()
                        .all(|(key, value)| metadata.get(key).is_some_and(|v| v == value))
                } else {
                    false
                }
            });
        }
        Ok(results)
    }
    /// Delete a vector by ID
    ///
    /// Returns `true` if the vector existed in storage. The index entry is
    /// removed only after a successful storage delete.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let deleted_storage = self.storage.delete(id)?;
        if deleted_storage {
            let mut index = self.index.write();
            let _ = index.remove(&id.to_string())?;
        }
        Ok(deleted_storage)
    }
    /// Get a vector by ID (`None` if absent)
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.storage.get(id)
    }
    /// Get the number of vectors
    pub fn len(&self) -> Result<usize> {
        self.storage.len()
    }
    /// Check if database is empty
    pub fn is_empty(&self) -> Result<bool> {
        self.storage.is_empty()
    }
    /// Get database options (the effective, possibly reloaded configuration)
    pub fn options(&self) -> &DbOptions {
        &self.options
    }
    /// Get all vector IDs (for iteration/serialization)
    pub fn keys(&self) -> Result<Vec<String>> {
        self.storage.all_ids()
    }
}
#[cfg(test)]
mod tests {
    // Removed unused `use std::path::Path;` — tests only call Path methods
    // through `dir.path()`, which needs no import.
    use super::*;
    use tempfile::tempdir;

    /// A freshly created database should report itself as empty.
    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;
        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);
        Ok(())
    }

    /// Basic insert + top-k search round trip using the flat index and
    /// Euclidean distance (so an exact match has ~0 distance).
    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;
        options.distance_metric = DistanceMetric::Euclidean; // Use Euclidean for clearer test
        options.hnsw_config = None; // Use flat index for testing
        let db = VectorDB::new(options)?;
        // Insert three orthogonal unit vectors
        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;
        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;
        db.insert(VectorEntry {
            id: Some("v3".to_string()),
            vector: vec![0.0, 0.0, 1.0],
            metadata: None,
        })?;
        // Search for exact match
        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;
        assert!(results.len() >= 1);
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );
        Ok(())
    }

    /// Test that search works after simulated restart (new VectorDB instance)
    /// This verifies the fix for issue #30: HNSW index not rebuilt from storage
    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db").to_string_lossy().to_string();
        // Phase 1: Create database and insert vectors
        {
            let mut options = DbOptions::default();
            options.storage_path = db_path.clone();
            options.dimensions = 3;
            options.distance_metric = DistanceMetric::Euclidean;
            options.hnsw_config = None;
            let db = VectorDB::new(options)?;
            db.insert(VectorEntry {
                id: Some("v1".to_string()),
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            })?;
            db.insert(VectorEntry {
                id: Some("v2".to_string()),
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            })?;
            db.insert(VectorEntry {
                id: Some("v3".to_string()),
                vector: vec![0.7, 0.7, 0.0],
                metadata: None,
            })?;
            // Verify search works before "restart"
            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }
        // db is dropped here, simulating application shutdown
        // Phase 2: Create new database instance (simulates restart)
        {
            let mut options = DbOptions::default();
            options.storage_path = db_path.clone();
            options.dimensions = 3;
            options.distance_metric = DistanceMetric::Euclidean;
            options.hnsw_config = None;
            let db = VectorDB::new(options)?;
            // Verify vectors are still accessible
            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");
            // Verify get() works
            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");
            // Verify search() works - THIS WAS THE BUG
            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );
            // v3 should be closest to query [0.8, 0.6, 0.0]
            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }
        Ok(())
    }
}