Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
42
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/error.rs
vendored
Normal file
42
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/error.rs
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
//! Error types for hyperbolic HNSW operations
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during hyperbolic (Poincaré ball) operations.
///
/// All variants are cheap to `Clone` so errors can be propagated across
/// shard/thread boundaries without borrowing issues. `Display` messages
/// come from the `thiserror` `#[error]` attributes below.
#[derive(Error, Debug, Clone)]
pub enum HyperbolicError {
    /// Vector is outside the Poincaré ball: its Euclidean norm exceeds
    /// the ball radius `1/sqrt(c) - eps` for curvature `c`.
    #[error("Vector norm {norm} exceeds ball radius (1/sqrt(c) - eps) for curvature c={curvature}")]
    OutsideBall { norm: f32, curvature: f32 },

    /// Invalid curvature parameter; the model requires `c > 0`.
    #[error("Invalid curvature: {0}. Must be positive.")]
    InvalidCurvature(f32),

    /// Two vectors passed to a binary operation had different lengths.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// A computation produced a non-finite or otherwise unusable value.
    #[error("Numerical instability: {0}")]
    NumericalInstability(String),

    /// A shard was requested by name but does not exist.
    #[error("Shard not found: {0}")]
    ShardNotFound(String),

    /// An index was out of range for the addressed collection.
    #[error("Index {index} out of bounds for size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// The operation requires at least one element.
    #[error("Cannot perform operation on empty collection")]
    EmptyCollection,

    /// A search could not be completed; the payload explains why.
    #[error("Search failed: {0}")]
    SearchFailed(String),
}
|
||||
|
||||
/// Convenience alias: `Result` specialized to [`HyperbolicError`],
/// used by every fallible API in this crate.
pub type HyperbolicResult<T> = Result<T, HyperbolicError>;
|
||||
650
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/hnsw.rs
vendored
Normal file
650
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/hnsw.rs
vendored
Normal file
@@ -0,0 +1,650 @@
|
||||
//! HNSW Adapter with Hyperbolic Distance Support
|
||||
//!
|
||||
//! This module provides HNSW (Hierarchical Navigable Small World) graph
|
||||
//! implementation optimized for hyperbolic space using the Poincaré ball model.
|
||||
//!
|
||||
//! # Key Features
|
||||
//!
|
||||
//! - Hyperbolic distance metric for neighbor selection
|
||||
//! - Tangent space pruning for accelerated search
|
||||
//! - Configurable curvature per index
|
||||
//! - Dual-space search (Euclidean fallback)
|
||||
|
||||
use crate::error::{HyperbolicError, HyperbolicResult};
|
||||
use crate::poincare::{fused_norms, norm_squared, poincare_distance, poincare_distance_from_norms, project_to_ball, EPS};
|
||||
use crate::tangent::TangentCache;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Distance metric used by an index for neighbor selection and ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Geodesic distance in the Poincaré ball model.
    Poincare,
    /// Standard Euclidean (L2) distance.
    Euclidean,
    /// Cosine similarity converted to a distance (`1 - cos`).
    Cosine,
    /// Hybrid mode: intended as Euclidean for pruning, Poincaré for
    /// ranking. NOTE(review): the visible `distance` implementation
    /// treats `Hybrid` identically to `Poincare` — confirm intent.
    Hybrid,
}
|
||||
|
||||
/// HNSW configuration for hyperbolic space.
///
/// Parameter names follow the original HNSW paper: `M` bounds per-node
/// degree, `ef` controls beam width (quality/latency trade-off).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicHnswConfig {
    /// Maximum number of connections per node on layers > 0 (the `M` parameter).
    pub max_connections: usize,
    /// Maximum connections for layer 0 (`M0`, conventionally `2*M`).
    pub max_connections_0: usize,
    /// Size of the dynamic candidate list during construction (`ef_construction`).
    pub ef_construction: usize,
    /// Size of the dynamic candidate list during search (`ef`).
    pub ef_search: usize,
    /// Level multiplier for random layer assignment (`ml = 1/ln(M)`).
    pub level_mult: f32,
    /// Curvature parameter `c > 0` for the Poincaré ball.
    pub curvature: f32,
    /// Which distance function drives neighbor selection.
    pub metric: DistanceMetric,
    /// Candidate over-sampling factor for tangent-space pruning
    /// (`search_with_pruning` keeps `prune_factor * k` finalists).
    pub prune_factor: usize,
    /// Whether `search_with_pruning` may use the tangent cache at all.
    pub use_tangent_pruning: bool,
}
|
||||
|
||||
impl Default for HyperbolicHnswConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections: 16,
|
||||
max_connections_0: 32,
|
||||
ef_construction: 200,
|
||||
ef_search: 50,
|
||||
level_mult: 1.0 / (16.0_f32).ln(),
|
||||
curvature: 1.0,
|
||||
metric: DistanceMetric::Poincare,
|
||||
prune_factor: 10,
|
||||
use_tangent_pruning: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A node in the HNSW graph: the stored vector plus its per-layer
/// adjacency lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswNode {
    /// Node ID (index into the owning graph's node vector).
    pub id: usize,
    /// Vector in the Poincaré ball (projected on insert).
    pub vector: Vec<f32>,
    /// Connections at each level: `connections[l]` holds neighbor ids
    /// at layer `l`, for `l` in `0..=level`.
    pub connections: Vec<Vec<usize>>,
    /// Highest layer this node appears in.
    pub level: usize,
}
|
||||
|
||||
impl HnswNode {
|
||||
pub fn new(id: usize, vector: Vec<f32>, max_level: usize) -> Self {
|
||||
let connections = (0..=max_level).map(|_| Vec::new()).collect();
|
||||
Self {
|
||||
id,
|
||||
vector,
|
||||
connections,
|
||||
level: max_level,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single search hit: node id plus its distance to the query under
/// the index's configured metric (smaller is closer).
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Id of the matched node.
    pub id: usize,
    /// Distance from the query to that node.
    pub distance: f32,
}
|
||||
|
||||
impl PartialEq for SearchResult {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): `Eq` is asserted over an `f32`-based equality, which is
// only a true equivalence relation if distances are never NaN — confirm
// the distance functions cannot produce NaN before relying on this.
impl Eq for SearchResult {}
|
||||
|
||||
impl PartialOrd for SearchResult {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for SearchResult {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.distance.partial_cmp(&other.distance).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperbolic HNSW index: a layered small-world graph whose edges are
/// chosen under the configured (typically Poincaré) distance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicHnsw {
    /// Index configuration (public so callers can inspect/tune it).
    pub config: HyperbolicHnswConfig,
    /// All nodes in the graph; a node's id is its position here.
    nodes: Vec<HnswNode>,
    /// Entry point node ID (`None` only while the index is empty).
    entry_point: Option<usize>,
    /// Highest layer currently present in the graph.
    max_level: usize,
    /// Tangent-space cache for pruned search. Not serialized — it is a
    /// derived structure and is invalidated on every insert; rebuild
    /// with `build_tangent_cache` after loading or inserting.
    #[serde(skip)]
    tangent_cache: Option<TangentCache>,
}
|
||||
|
||||
impl HyperbolicHnsw {
|
||||
/// Create a new empty HNSW index
|
||||
pub fn new(config: HyperbolicHnswConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
nodes: Vec::new(),
|
||||
entry_point: None,
|
||||
max_level: 0,
|
||||
tangent_cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an empty index using [`HyperbolicHnswConfig::default`].
pub fn default_config() -> Self {
    Self::new(HyperbolicHnswConfig::default())
}
|
||||
|
||||
/// Number of vectors stored in the index.
pub fn len(&self) -> usize {
    self.nodes.len()
}
|
||||
|
||||
/// `true` when no vectors have been inserted yet.
pub fn is_empty(&self) -> bool {
    self.nodes.is_empty()
}
|
||||
|
||||
/// Dimensionality of the stored vectors, taken from the first node;
/// `None` when the index is empty.
// NOTE(review): assumes every stored vector has the same length —
// `insert` does not enforce this; confirm callers do.
pub fn dim(&self) -> Option<usize> {
    self.nodes.first().map(|n| n.vector.len())
}
|
||||
|
||||
/// Compute distance between two vectors (optimized with fused norms)
|
||||
#[inline]
|
||||
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
match self.config.metric {
|
||||
DistanceMetric::Poincare | DistanceMetric::Hybrid => {
|
||||
// Use fused_norms for single-pass computation
|
||||
let (diff_sq, norm_a_sq, norm_b_sq) = fused_norms(a, b);
|
||||
poincare_distance_from_norms(diff_sq, norm_a_sq, norm_b_sq, self.config.curvature)
|
||||
}
|
||||
DistanceMetric::Euclidean => {
|
||||
let (diff_sq, _, _) = fused_norms(a, b);
|
||||
diff_sq.sqrt()
|
||||
}
|
||||
DistanceMetric::Cosine => {
|
||||
let len = a.len().min(b.len());
|
||||
let mut dot_ab = 0.0f32;
|
||||
let mut norm_a_sq = 0.0f32;
|
||||
let mut norm_b_sq = 0.0f32;
|
||||
|
||||
// Fused computation
|
||||
for i in 0..len {
|
||||
let ai = a[i];
|
||||
let bi = b[i];
|
||||
dot_ab += ai * bi;
|
||||
norm_a_sq += ai * ai;
|
||||
norm_b_sq += bi * bi;
|
||||
}
|
||||
|
||||
let norm_prod = (norm_a_sq * norm_b_sq).sqrt();
|
||||
1.0 - dot_ab / (norm_prod + EPS)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute distance with pre-computed query norm (for batch search)
|
||||
#[inline]
|
||||
fn distance_with_query_norm(&self, query: &[f32], query_norm_sq: f32, point: &[f32]) -> f32 {
|
||||
match self.config.metric {
|
||||
DistanceMetric::Poincare | DistanceMetric::Hybrid => {
|
||||
let (diff_sq, _, point_norm_sq) = fused_norms(query, point);
|
||||
poincare_distance_from_norms(diff_sq, query_norm_sq, point_norm_sq, self.config.curvature)
|
||||
}
|
||||
_ => self.distance(query, point)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random level for a new node
|
||||
fn random_level(&self) -> usize {
|
||||
let r: f32 = rand::random();
|
||||
(-r.ln() * self.config.level_mult) as usize
|
||||
}
|
||||
|
||||
/// Insert a vector into the index, returning its assigned id.
///
/// Algorithm (standard HNSW insertion):
/// 1. project the vector into the Poincaré ball;
/// 2. draw a random top layer for the new node;
/// 3. greedily descend from the current entry point through the layers
///    above the node's layer to find a good starting neighbor;
/// 4. on each layer the node occupies, beam-search `ef_construction`
///    candidates, link bidirectionally, and prune over-full neighbors.
///
/// Any cached tangent data is invalidated because the point set changed.
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
    // Project to ball for safety — stored points must satisfy the
    // ball invariant regardless of what the caller passed in.
    let vector = project_to_ball(&vector, self.config.curvature, EPS);

    // The new node's id is its position in `nodes`.
    let id = self.nodes.len();
    let level = self.random_level();

    // Create the (initially unconnected) node up-front so neighbor
    // bookkeeping below can index it.
    let node = HnswNode::new(id, vector.clone(), level);
    self.nodes.push(node);

    // First insert: the node simply becomes the entry point.
    if self.entry_point.is_none() {
        self.entry_point = Some(id);
        self.max_level = level;
        return Ok(id);
    }

    let entry_id = self.entry_point.unwrap();

    // Greedy descent through layers above `level` to locate the best
    // starting point for construction.
    let mut current = entry_id;
    for l in (level + 1..=self.max_level).rev() {
        current = self.search_layer_single(&vector, current, l)?;
    }

    // Link the node at every layer it occupies, from the top of the
    // overlap [0, min(level, max_level)] down to layer 0.
    let insert_level = level.min(self.max_level);
    for l in (0..=insert_level).rev() {
        let neighbors = self.search_layer(&vector, current, self.config.ef_construction, l)?;

        // Layer 0 allows twice the connectivity of upper layers.
        let max_conn = if l == 0 {
            self.config.max_connections_0
        } else {
            self.config.max_connections
        };

        // `neighbors` is sorted ascending by distance, so the first
        // `max_conn` entries are the closest candidates.
        let selected: Vec<usize> = neighbors.iter().take(max_conn).map(|r| r.id).collect();

        // Add bidirectional connections.
        self.nodes[id].connections[l] = selected.clone();

        for &neighbor_id in &selected {
            self.nodes[neighbor_id].connections[l].push(id);

            // Keep each neighbor's degree bounded by re-selecting its
            // `max_conn` nearest connections.
            if self.nodes[neighbor_id].connections[l].len() > max_conn {
                self.prune_connections(neighbor_id, l, max_conn)?;
            }
        }

        // Carry the best candidate down as the next layer's entry.
        if !neighbors.is_empty() {
            current = neighbors[0].id;
        }
    }

    // A node taller than the current graph becomes the new entry point.
    if level > self.max_level {
        self.entry_point = Some(id);
        self.max_level = level;
    }

    // The point set changed, so any precomputed tangent data is stale.
    self.tangent_cache = None;

    Ok(id)
}
|
||||
|
||||
/// Insert batch of vectors
|
||||
pub fn insert_batch(&mut self, vectors: Vec<Vec<f32>>) -> HyperbolicResult<Vec<usize>> {
|
||||
let mut ids = Vec::with_capacity(vectors.len());
|
||||
for vector in vectors {
|
||||
ids.push(self.insert(vector)?);
|
||||
}
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
/// Greedy search for a single nearest neighbor at one layer: repeatedly
/// hop to any strictly closer neighbor of the current node until no
/// neighbor improves on the current distance (a local minimum).
/// Used to descend through the upper layers of the graph.
fn search_layer_single(&self, query: &[f32], entry: usize, level: usize) -> HyperbolicResult<usize> {
    let mut current = entry;
    let mut current_dist = self.distance(query, &self.nodes[current].vector);

    loop {
        let mut changed = false;

        // Scan the current node's neighbors; keep the running best.
        for &neighbor in &self.nodes[current].connections[level] {
            let dist = self.distance(query, &self.nodes[neighbor].vector);
            if dist < current_dist {
                current_dist = dist;
                current = neighbor;
                changed = true;
            }
        }

        // No neighbor improved — we are at a local minimum; stop.
        if !changed {
            break;
        }
    }

    Ok(current)
}
|
||||
|
||||
/// Beam search at one layer, keeping up to `ef` best results.
///
/// Classic HNSW layer search with two heaps:
/// - `candidates`: min-heap (via `Reverse`) of nodes still to expand,
///   popped closest-first;
/// - `results`: max-heap of the best `ef` nodes found, whose top is the
///   current worst kept result (cheap to pop when over capacity).
///
/// Returns the kept results sorted ascending by distance.
fn search_layer(
    &self,
    query: &[f32],
    entry: usize,
    ef: usize,
    level: usize,
) -> HyperbolicResult<Vec<SearchResult>> {
    use std::collections::{BinaryHeap, HashSet};

    let entry_dist = self.distance(query, &self.nodes[entry].vector);

    let mut visited = HashSet::new();
    visited.insert(entry);

    // Candidates (min-heap by distance).
    let mut candidates: BinaryHeap<std::cmp::Reverse<SearchResult>> = BinaryHeap::new();
    candidates.push(std::cmp::Reverse(SearchResult {
        id: entry,
        distance: entry_dist,
    }));

    // Results (max-heap by distance so the worst kept hit is on top).
    let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
    results.push(SearchResult {
        id: entry,
        distance: entry_dist,
    });

    while let Some(std::cmp::Reverse(current)) = candidates.pop() {
        // Early exit: the closest unexpanded candidate is already
        // farther than the worst of a full result set — no remaining
        // candidate can improve the results.
        if let Some(furthest) = results.peek() {
            if current.distance > furthest.distance && results.len() >= ef {
                break;
            }
        }

        // Expand the current node's neighborhood.
        for &neighbor in &self.nodes[current.id].connections[level] {
            if visited.contains(&neighbor) {
                continue;
            }
            visited.insert(neighbor);

            let dist = self.distance(query, &self.nodes[neighbor].vector);

            // Accept when the result set has room, or when this hit
            // beats the current worst kept result.
            let should_add = results.len() < ef
                || results
                    .peek()
                    .map(|r| dist < r.distance)
                    .unwrap_or(true);

            if should_add {
                candidates.push(std::cmp::Reverse(SearchResult {
                    id: neighbor,
                    distance: dist,
                }));
                results.push(SearchResult {
                    id: neighbor,
                    distance: dist,
                });

                // Evict the worst result to stay within the beam width.
                if results.len() > ef {
                    results.pop();
                }
            }
        }
    }

    // Heap iteration order is arbitrary; sort ascending for callers.
    let mut result_vec: Vec<SearchResult> = results.into_iter().collect();
    result_vec.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

    Ok(result_vec)
}
|
||||
|
||||
/// Prune connections to keep only the best
|
||||
fn prune_connections(
|
||||
&mut self,
|
||||
node_id: usize,
|
||||
level: usize,
|
||||
max_conn: usize,
|
||||
) -> HyperbolicResult<()> {
|
||||
let node_vector = self.nodes[node_id].vector.clone();
|
||||
let connections = &self.nodes[node_id].connections[level];
|
||||
|
||||
let mut scored: Vec<(usize, f32)> = connections
|
||||
.iter()
|
||||
.map(|&id| (id, self.distance(&node_vector, &self.nodes[id].vector)))
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
|
||||
self.nodes[node_id].connections[level] =
|
||||
scored.into_iter().take(max_conn).map(|(id, _)| id).collect();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search for k nearest neighbors
|
||||
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
|
||||
if self.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let query = project_to_ball(query, self.config.curvature, EPS);
|
||||
let entry = self.entry_point.unwrap();
|
||||
|
||||
// Navigate to lowest level from top
|
||||
let mut current = entry;
|
||||
for l in (1..=self.max_level).rev() {
|
||||
current = self.search_layer_single(&query, current, l)?;
|
||||
}
|
||||
|
||||
// Search at layer 0 with ef_search candidates
|
||||
let ef = self.config.ef_search.max(k);
|
||||
let mut results = self.search_layer(&query, current, ef, 0)?;
|
||||
|
||||
results.truncate(k);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Search with tangent-space pruning: a cheap Euclidean pass in the
/// tangent space shortlists `prune_factor * k` candidates, then exact
/// Poincaré distances rank the finalists.
///
/// Falls back to [`Self::search`] when no tangent cache has been built
/// (see [`Self::build_tangent_cache`]) or pruning is disabled in the
/// config. Note this scans every cached point rather than walking the
/// graph, so it is a brute-force-then-refine strategy.
pub fn search_with_pruning(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
    // Fall back to regular graph search if pruning is unavailable.
    if self.tangent_cache.is_none() || !self.config.use_tangent_pruning {
        return self.search(query, k);
    }

    // Safe: checked `is_none()` just above.
    let cache = self.tangent_cache.as_ref().unwrap();
    let query = project_to_ball(query, self.config.curvature, EPS);

    // Phase 1: map the query into tangent space and score every cached
    // point with the cheap squared Euclidean tangent distance.
    let query_tangent = cache.query_tangent(&query);

    let mut candidates: Vec<(usize, f32)> = (0..cache.len())
        .map(|i| {
            let tangent_dist = cache.tangent_distance_squared(&query_tangent, i);
            (cache.point_indices[i], tangent_dist)
        })
        .collect();

    // Sort by tangent distance (approximate ranking).
    candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    // Keep an over-sampled shortlist of prune_factor * k candidates.
    let num_candidates = (k * self.config.prune_factor).min(candidates.len());
    candidates.truncate(num_candidates);

    // Phase 2: exact Poincaré distance for the finalists only.
    let mut results: Vec<SearchResult> = candidates
        .into_iter()
        .map(|(id, _)| {
            let dist = self.distance(&query, &self.nodes[id].vector);
            SearchResult { id, distance: dist }
        })
        .collect();

    // Final exact ranking, then keep the k best.
    results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
    results.truncate(k);

    Ok(results)
}
|
||||
|
||||
/// Build tangent cache for all points
|
||||
pub fn build_tangent_cache(&mut self) -> HyperbolicResult<()> {
|
||||
if self.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let vectors: Vec<Vec<f32>> = self.nodes.iter().map(|n| n.vector.clone()).collect();
|
||||
let indices: Vec<usize> = (0..self.nodes.len()).collect();
|
||||
|
||||
self.tangent_cache = Some(TangentCache::new(&vectors, &indices, self.config.curvature)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a reference to a node's vector
|
||||
pub fn get_vector(&self, id: usize) -> Option<&[f32]> {
|
||||
self.nodes.get(id).map(|n| n.vector.as_slice())
|
||||
}
|
||||
|
||||
/// Update curvature and rebuild tangent cache
|
||||
pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> {
|
||||
if curvature <= 0.0 {
|
||||
return Err(HyperbolicError::InvalidCurvature(curvature));
|
||||
}
|
||||
|
||||
self.config.curvature = curvature;
|
||||
|
||||
// Reproject all vectors
|
||||
for node in &mut self.nodes {
|
||||
node.vector = project_to_ball(&node.vector, curvature, EPS);
|
||||
}
|
||||
|
||||
// Rebuild tangent cache
|
||||
if self.tangent_cache.is_some() {
|
||||
self.build_tangent_cache()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all vectors as a slice
|
||||
pub fn vectors(&self) -> Vec<&[f32]> {
|
||||
self.nodes.iter().map(|n| n.vector.as_slice()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Dual-space index: a primary hyperbolic index plus a synchronized
/// Euclidean index, combined at query time by reciprocal-rank fusion.
#[derive(Debug)]
pub struct DualSpaceIndex {
    /// Hyperbolic (Poincaré) index — the primary space; reported
    /// distances come from here.
    pub hyperbolic: HyperbolicHnsw,
    /// Euclidean index — the fallback/secondary space.
    pub euclidean: HyperbolicHnsw,
    /// Weight of hyperbolic ranks in the fusion, clamped to `[0, 1]`;
    /// Euclidean ranks get `1 - fusion_weight`.
    pub fusion_weight: f32,
}
|
||||
|
||||
impl DualSpaceIndex {
|
||||
/// Create a new dual-space index
|
||||
pub fn new(curvature: f32, fusion_weight: f32) -> Self {
|
||||
let mut hyp_config = HyperbolicHnswConfig::default();
|
||||
hyp_config.curvature = curvature;
|
||||
hyp_config.metric = DistanceMetric::Poincare;
|
||||
|
||||
let mut euc_config = HyperbolicHnswConfig::default();
|
||||
euc_config.metric = DistanceMetric::Euclidean;
|
||||
|
||||
Self {
|
||||
hyperbolic: HyperbolicHnsw::new(hyp_config),
|
||||
euclidean: HyperbolicHnsw::new(euc_config),
|
||||
fusion_weight: fusion_weight.clamp(0.0, 1.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert into both indices
|
||||
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
|
||||
self.euclidean.insert(vector.clone())?;
|
||||
self.hyperbolic.insert(vector)
|
||||
}
|
||||
|
||||
/// Search with mutual ranking fusion
|
||||
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
|
||||
let hyp_results = self.hyperbolic.search(query, k * 2)?;
|
||||
let euc_results = self.euclidean.search(query, k * 2)?;
|
||||
|
||||
// Combine and re-rank using fusion
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut scores: HashMap<usize, f32> = HashMap::new();
|
||||
|
||||
// Add hyperbolic scores
|
||||
for (rank, r) in hyp_results.iter().enumerate() {
|
||||
let score = self.fusion_weight * (1.0 / (rank as f32 + 1.0));
|
||||
*scores.entry(r.id).or_insert(0.0) += score;
|
||||
}
|
||||
|
||||
// Add Euclidean scores
|
||||
for (rank, r) in euc_results.iter().enumerate() {
|
||||
let score = (1.0 - self.fusion_weight) * (1.0 / (rank as f32 + 1.0));
|
||||
*scores.entry(r.id).or_insert(0.0) += score;
|
||||
}
|
||||
|
||||
// Sort by combined score (higher is better)
|
||||
let mut combined: Vec<(usize, f32)> = scores.into_iter().collect();
|
||||
combined.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
// Return top k with hyperbolic distances
|
||||
Ok(combined
|
||||
.into_iter()
|
||||
.take(k)
|
||||
.map(|(id, _)| {
|
||||
let dist = self.hyperbolic.distance(
|
||||
query,
|
||||
self.hyperbolic.get_vector(id).unwrap_or(&[]),
|
||||
);
|
||||
SearchResult { id, distance: dist }
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: insert 10 collinear points, then verify
    /// a 3-NN query returns 3 hits sorted ascending by distance.
    #[test]
    fn test_hnsw_insert_search() {
        let mut hnsw = HyperbolicHnsw::default_config();

        // Insert some vectors along a line inside the ball.
        for i in 0..10 {
            let v = vec![0.1 * i as f32, 0.05 * i as f32];
            hnsw.insert(v).unwrap();
        }

        assert_eq!(hnsw.len(), 10);

        // Query a point on the same line.
        let query = vec![0.3, 0.15];
        let results = hnsw.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].distance <= results[1].distance);
    }

    /// The dual-space index should fuse both rankings and still return
    /// exactly k results.
    #[test]
    fn test_dual_space() {
        let mut dual = DualSpaceIndex::new(1.0, 0.5);

        for i in 0..10 {
            let v = vec![0.1 * i as f32, 0.05 * i as f32];
            dual.insert(v).unwrap();
        }

        let query = vec![0.3, 0.15];
        let results = dual.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }
}
|
||||
210
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/lib.rs
vendored
Normal file
210
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/lib.rs
vendored
Normal file
@@ -0,0 +1,210 @@
|
||||
//! Hyperbolic Embeddings with HNSW Integration for RuVector
|
||||
//!
|
||||
//! This crate provides hyperbolic (Poincaré ball) embeddings integrated with
|
||||
//! HNSW (Hierarchical Navigable Small World) graphs for hierarchy-aware
|
||||
//! vector search.
|
||||
//!
|
||||
//! # Overview
|
||||
//!
|
||||
//! Hierarchies compress naturally in hyperbolic space. Taxonomies, catalogs,
|
||||
//! ICD trees, product facets, org charts, and long-tail tags all fit better
|
||||
//! than in Euclidean space, which means higher recall on deep leaves without
|
||||
//! blowing up memory or latency.
|
||||
//!
|
||||
//! # Key Features
|
||||
//!
|
||||
//! - **Poincaré Ball Model**: Store vectors in the Poincaré ball with proper
|
||||
//! geometric operations (Möbius addition, exp/log maps)
|
||||
//! - **Tangent Space Pruning**: Prune HNSW candidates with cheap Euclidean
|
||||
//! distance in tangent space before exact hyperbolic ranking
|
||||
//! - **Per-Shard Curvature**: Different parts of the hierarchy can have
|
||||
//! different optimal curvatures
|
||||
//! - **Dual-Space Index**: Keep a synchronized Euclidean index for fallback
|
||||
//! and mutual ranking fusion
|
||||
//!
|
||||
//! # Quick Start
|
||||
//!
|
||||
//! ```rust
|
||||
//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
|
||||
//!
|
||||
//! // Create index with default settings
|
||||
//! let mut index = HyperbolicHnsw::default_config();
|
||||
//!
|
||||
//! // Insert vectors (automatically projected to Poincaré ball)
|
||||
//! index.insert(vec![0.1, 0.2, 0.3]).unwrap();
|
||||
//! index.insert(vec![-0.1, 0.15, 0.25]).unwrap();
|
||||
//! index.insert(vec![0.2, -0.1, 0.1]).unwrap();
|
||||
//!
|
||||
//! // Search for nearest neighbors
|
||||
//! let results = index.search(&[0.15, 0.1, 0.2], 2).unwrap();
|
||||
//! for r in results {
|
||||
//! println!("ID: {}, Distance: {:.4}", r.id, r.distance);
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! # HNSW Speed Trick
|
||||
//!
|
||||
//! The core optimization is:
|
||||
//! 1. Precompute `u = log_c(x)` at a shard centroid `c`
|
||||
//! 2. During neighbor selection, use Euclidean `||u_q - u_p||` to prune
|
||||
//! 3. Run exact Poincaré distance only on top N candidates before final ranking
|
||||
//!
|
||||
//! ```rust
|
||||
//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig};
|
||||
//!
|
||||
//! let mut config = HyperbolicHnswConfig::default();
|
||||
//! config.use_tangent_pruning = true;
|
||||
//! config.prune_factor = 10; // Consider 10x candidates in tangent space
|
||||
//!
|
||||
//! let mut index = HyperbolicHnsw::new(config);
|
||||
//! // ... insert vectors ...
|
||||
//!
|
||||
//! // Build tangent cache for pruning optimization
|
||||
//! # index.insert(vec![0.1, 0.2]).unwrap();
|
||||
//! index.build_tangent_cache().unwrap();
|
||||
//!
|
||||
//! // Search with pruning
|
||||
//! let results = index.search_with_pruning(&[0.1, 0.15], 5).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! # Sharded Index with Per-Shard Curvature
|
||||
//!
|
||||
//! ```rust
|
||||
//! use ruvector_hyperbolic_hnsw::{ShardedHyperbolicHnsw, ShardStrategy};
|
||||
//!
|
||||
//! let mut manager = ShardedHyperbolicHnsw::new(1.0);
|
||||
//!
|
||||
//! // Insert with hierarchy depth information
|
||||
//! manager.insert(vec![0.1, 0.2], Some(0)).unwrap(); // Root level
|
||||
//! manager.insert(vec![0.3, 0.1], Some(3)).unwrap(); // Deeper level
|
||||
//!
|
||||
//! // Update curvature for specific shard
|
||||
//! manager.update_curvature("radius_1", 0.5).unwrap();
|
||||
//!
|
||||
//! // Search across all shards
|
||||
//! let results = manager.search(&[0.2, 0.15], 5).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! # Mathematical Operations
|
||||
//!
|
||||
//! The `poincare` module provides low-level hyperbolic geometry operations:
|
||||
//!
|
||||
//! ```rust
|
||||
//! use ruvector_hyperbolic_hnsw::poincare::{
|
||||
//! mobius_add, exp_map, log_map, poincare_distance, project_to_ball
|
||||
//! };
|
||||
//!
|
||||
//! let x = vec![0.3, 0.2];
|
||||
//! let y = vec![-0.1, 0.4];
|
||||
//! let c = 1.0; // Curvature
|
||||
//!
|
||||
//! // Möbius addition (hyperbolic vector addition)
|
||||
//! let z = mobius_add(&x, &y, c);
|
||||
//!
|
||||
//! // Geodesic distance in hyperbolic space
|
||||
//! let d = poincare_distance(&x, &y, c);
|
||||
//!
|
||||
//! // Map to tangent space at x
|
||||
//! let v = log_map(&y, &x, c);
|
||||
//!
|
||||
//! // Map back to manifold
|
||||
//! let y_recovered = exp_map(&v, &x, c);
|
||||
//! ```
|
||||
//!
|
||||
//! # Numerical Stability
|
||||
//!
|
||||
//! All operations include numerical safeguards:
|
||||
//! - Norm clamping with `eps = 1e-5`
|
||||
//! - Projection after every update
|
||||
//! - Stable `acosh` and `log1p` implementations
|
||||
//!
|
||||
//! # Feature Flags
|
||||
//!
|
||||
//! - `simd`: Enable SIMD acceleration (default)
|
||||
//! - `parallel`: Enable parallel processing with rayon (default)
|
||||
//! - `wasm`: Enable WebAssembly compatibility
|
||||
|
||||
pub mod error;
|
||||
pub mod hnsw;
|
||||
pub mod poincare;
|
||||
pub mod shard;
|
||||
pub mod tangent;
|
||||
|
||||
// Re-exports
|
||||
pub use error::{HyperbolicError, HyperbolicResult};
|
||||
pub use hnsw::{
|
||||
DistanceMetric, DualSpaceIndex, HnswNode, HyperbolicHnsw, HyperbolicHnswConfig, SearchResult,
|
||||
};
|
||||
pub use poincare::{
|
||||
conformal_factor, conformal_factor_from_norm_sq, dot, exp_map, frechet_mean, fused_norms,
|
||||
hyperbolic_midpoint, log_map, log_map_at_centroid, mobius_add, mobius_add_inplace,
|
||||
mobius_scalar_mult, norm, norm_squared, parallel_transport, poincare_distance,
|
||||
poincare_distance_batch, poincare_distance_from_norms, poincare_distance_squared,
|
||||
project_to_ball, project_to_ball_inplace, PoincareConfig, DEFAULT_CURVATURE, EPS,
|
||||
};
|
||||
pub use shard::{
|
||||
CurvatureRegistry, HierarchyMetrics, HyperbolicShard, ShardCurvature, ShardStrategy,
|
||||
ShardedHyperbolicHnsw,
|
||||
};
|
||||
pub use tangent::{tangent_micro_update, PrunedCandidate, TangentCache, TangentPruner};
|
||||
|
||||
/// Library version, captured from `Cargo.toml` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Prelude: one-stop glob import (`use ruvector_hyperbolic_hnsw::prelude::*`)
/// of the types and functions most callers need.
pub mod prelude {
    pub use crate::error::{HyperbolicError, HyperbolicResult};
    pub use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult};
    pub use crate::poincare::{exp_map, log_map, mobius_add, poincare_distance, project_to_ball};
    pub use crate::shard::{ShardedHyperbolicHnsw, ShardStrategy};
    pub use crate::tangent::{TangentCache, TangentPruner};
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Crate-level smoke test: build, fill, query, and check ordering.
    #[test]
    fn test_basic_workflow() {
        // Create index with default configuration.
        let mut index = HyperbolicHnsw::default_config();

        // Insert 10 points along a line inside the ball.
        for i in 0..10 {
            let v = vec![0.1 * i as f32, 0.05 * i as f32, 0.02 * i as f32];
            index.insert(v).unwrap();
        }

        // Query near the middle of the inserted range.
        let query = vec![0.35, 0.175, 0.07];
        let results = index.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        // Results must be sorted ascending by distance.
        for i in 1..results.len() {
            assert!(results[i - 1].distance <= results[i].distance);
        }
    }

    /// Points placed at increasing radius per "depth" should show a
    /// positive radius-depth correlation in the hierarchy metrics.
    #[test]
    fn test_hierarchy_preservation() {
        // 20 points in 5 depth rings of 4 points each; radius grows
        // with depth, angle cycles through the four quadrant axes.
        let points: Vec<Vec<f32>> = (0..20)
            .map(|i| {
                let depth = i / 4;
                let radius = 0.1 + 0.15 * depth as f32;
                let angle = (i % 4) as f32 * std::f32::consts::PI / 2.0;
                vec![radius * angle.cos(), radius * angle.sin()]
            })
            .collect();

        let depths: Vec<usize> = (0..20).map(|i| i / 4).collect();

        // Compute metrics at unit curvature.
        let metrics = HierarchyMetrics::compute(&points, &depths, 1.0).unwrap();

        // Radius should correlate positively with depth.
        assert!(metrics.radius_depth_correlation > 0.5);
    }
}
|
||||
627
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/poincare.rs
vendored
Normal file
627
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/poincare.rs
vendored
Normal file
@@ -0,0 +1,627 @@
|
||||
//! Poincaré Ball Model Operations for Hyperbolic Geometry
|
||||
//!
|
||||
//! This module implements core operations in the Poincaré ball model of hyperbolic space,
|
||||
//! providing mathematically correct implementations with numerical stability guarantees.
|
||||
//!
|
||||
//! # Mathematical Background
|
||||
//!
|
||||
//! The Poincaré ball model represents hyperbolic space as the interior of a unit ball
|
||||
//! in Euclidean space. Points are constrained to satisfy ||x|| < 1/√c where c > 0 is
|
||||
//! the curvature parameter.
|
||||
//!
|
||||
//! # Key Operations
|
||||
//!
|
||||
//! - **Möbius Addition**: The hyperbolic analog of vector addition
|
||||
//! - **Exponential Map**: Maps tangent vectors to the manifold
|
||||
//! - **Logarithmic Map**: Maps manifold points to tangent space
|
||||
//! - **Poincaré Distance**: The geodesic distance in hyperbolic space
|
||||
|
||||
use crate::error::{HyperbolicError, HyperbolicResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Small epsilon for numerical stability (as specified: eps=1e-5)
|
||||
pub const EPS: f32 = 1e-5;
|
||||
|
||||
/// Default curvature parameter (negative curvature, c > 0)
|
||||
pub const DEFAULT_CURVATURE: f32 = 1.0;
|
||||
|
||||
/// Configuration for Poincaré ball operations
///
/// Bundles the curvature and the numerical-stability knobs shared by the
/// iterative algorithms in this module (e.g. [`frechet_mean`]).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct PoincareConfig {
    /// Curvature parameter (c > 0 for hyperbolic space)
    pub curvature: f32,
    /// Numerical stability epsilon (used to keep points strictly inside the ball)
    pub eps: f32,
    /// Maximum iterations for iterative algorithms (e.g., Fréchet mean)
    pub max_iter: usize,
    /// Convergence tolerance for gradient-norm stopping criteria
    pub tol: f32,
}
|
||||
|
||||
impl Default for PoincareConfig {
    /// Standard configuration: unit curvature, module-wide `EPS`, and
    /// conservative iteration/tolerance settings for the Fréchet mean.
    fn default() -> Self {
        Self {
            curvature: DEFAULT_CURVATURE,
            eps: EPS,
            max_iter: 100,
            tol: 1e-6,
        }
    }
}
|
||||
|
||||
impl PoincareConfig {
|
||||
/// Create configuration with custom curvature
|
||||
pub fn with_curvature(curvature: f32) -> HyperbolicResult<Self> {
|
||||
if curvature <= 0.0 {
|
||||
return Err(HyperbolicError::InvalidCurvature(curvature));
|
||||
}
|
||||
Ok(Self {
|
||||
curvature,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Maximum allowed norm for points in the ball
|
||||
#[inline]
|
||||
pub fn max_norm(&self) -> f32 {
|
||||
(1.0 / self.curvature.sqrt()) - self.eps
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Optimized Core Operations (SIMD-friendly)
|
||||
// ============================================================================
|
||||
|
||||
/// Compute the squared Euclidean norm of a slice (optimized with unrolling)
///
/// Accumulates four elements per step — the same grouping as a manually
/// unrolled loop — while `chunks_exact` lets the compiler elide per-element
/// bounds checks.
#[inline]
pub fn norm_squared(x: &[f32]) -> f32 {
    let mut quads = x.chunks_exact(4);
    let mut acc = 0.0f32;

    for q in quads.by_ref() {
        acc += q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
    }
    // Tail of 0–3 leftover elements.
    for &v in quads.remainder() {
        acc += v * v;
    }

    acc
}
|
||||
|
||||
/// Compute the Euclidean norm of a slice
///
/// Delegates to [`norm_squared`] (reusing its unrolled loop) and takes the
/// square root of the result.
#[inline]
pub fn norm(x: &[f32]) -> f32 {
    norm_squared(x).sqrt()
}
|
||||
|
||||
/// Compute the dot product of two slices (optimized with unrolling)
///
/// If the slices have different lengths only the common prefix is used.
#[inline]
pub fn dot(x: &[f32], y: &[f32]) -> f32 {
    let len = x.len().min(y.len());
    let mut xc = x[..len].chunks_exact(4);
    let mut yc = y[..len].chunks_exact(4);
    let mut acc = 0.0f32;

    // Four products per step, same accumulation grouping as manual unrolling.
    for (a, b) in xc.by_ref().zip(yc.by_ref()) {
        acc += a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
    }
    for (&a, &b) in xc.remainder().iter().zip(yc.remainder()) {
        acc += a * b;
    }

    acc
}
|
||||
|
||||
/// Fused computation of ||u-v||², ||u||², ||v||² in single pass (3x faster)
///
/// Only the common prefix of `u` and `v` is consumed. The three accumulators
/// use the same four-wide grouping as the other unrolled kernels here, so
/// results are bit-identical to a manually unrolled loop.
#[inline]
pub fn fused_norms(u: &[f32], v: &[f32]) -> (f32, f32, f32) {
    let len = u.len().min(v.len());
    let mut uc = u[..len].chunks_exact(4);
    let mut vc = v[..len].chunks_exact(4);

    let mut diff_sq = 0.0f32;
    let mut norm_u_sq = 0.0f32;
    let mut norm_v_sq = 0.0f32;

    for (a, b) in uc.by_ref().zip(vc.by_ref()) {
        let (d0, d1, d2, d3) = (a[0] - b[0], a[1] - b[1], a[2] - b[2], a[3] - b[3]);
        diff_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
        norm_u_sq += a[0] * a[0] + a[1] * a[1] + a[2] * a[2] + a[3] * a[3];
        norm_v_sq += b[0] * b[0] + b[1] * b[1] + b[2] * b[2] + b[3] * b[3];
    }
    for (&a, &b) in uc.remainder().iter().zip(vc.remainder()) {
        let d = a - b;
        diff_sq += d * d;
        norm_u_sq += a * a;
        norm_v_sq += b * b;
    }

    (diff_sq, norm_u_sq, norm_v_sq)
}
|
||||
|
||||
/// Project a point back into the Poincaré ball
|
||||
///
|
||||
/// Ensures ||x|| < 1/√c - eps for numerical stability
|
||||
#[inline]
|
||||
pub fn project_to_ball(x: &[f32], c: f32, eps: f32) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
let norm_sq = norm_squared(x);
|
||||
let max_norm = (1.0 / c.sqrt()) - eps;
|
||||
let max_norm_sq = max_norm * max_norm;
|
||||
|
||||
if norm_sq < max_norm_sq || norm_sq < eps * eps {
|
||||
x.to_vec()
|
||||
} else {
|
||||
let scale = max_norm / norm_sq.sqrt();
|
||||
x.iter().map(|&xi| scale * xi).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Project in-place (avoids allocation when possible)
|
||||
#[inline]
|
||||
pub fn project_to_ball_inplace(x: &mut [f32], c: f32, eps: f32) {
|
||||
let c = c.abs().max(EPS);
|
||||
let norm_sq = norm_squared(x);
|
||||
let max_norm = (1.0 / c.sqrt()) - eps;
|
||||
let max_norm_sq = max_norm * max_norm;
|
||||
|
||||
if norm_sq >= max_norm_sq && norm_sq >= eps * eps {
|
||||
let scale = max_norm / norm_sq.sqrt();
|
||||
for xi in x.iter_mut() {
|
||||
*xi *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the conformal factor λ_x at point x
///
/// λ_x = 2 / (1 - c||x||²)
///
/// The denominator is clamped to at least `EPS`, so the result stays finite
/// even for points at (or numerically past) the ball boundary.
#[inline]
pub fn conformal_factor(x: &[f32], c: f32) -> f32 {
    let norm_sq = norm_squared(x);
    2.0 / (1.0 - c * norm_sq).max(EPS)
}
|
||||
|
||||
/// Conformal factor from pre-computed norm squared
///
/// Variant of [`conformal_factor`] for callers that already have ||x||²
/// (e.g. from [`fused_norms`]); avoids a second pass over the vector.
#[inline]
pub fn conformal_factor_from_norm_sq(norm_sq: f32, c: f32) -> f32 {
    2.0 / (1.0 - c * norm_sq).max(EPS)
}
|
||||
|
||||
// ============================================================================
|
||||
// Poincaré Distance (Optimized)
|
||||
// ============================================================================
|
||||
|
||||
/// Poincaré distance between two points (optimized with fused norms)
///
/// Uses the formula:
/// d(u, v) = (1/√c) acosh(1 + 2c ||u - v||² / ((1 - c||u||²)(1 - c||v||²)))
///
/// `c` is sanitized with `abs().max(EPS)`, so non-positive curvatures cannot
/// produce NaN. If `u` and `v` differ in length, only the common prefix
/// contributes (a property inherited from `fused_norms`).
#[inline]
pub fn poincare_distance(u: &[f32], v: &[f32], c: f32) -> f32 {
    let c = c.abs().max(EPS);

    // Fused computation: single pass for all three norms
    let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(u, v);

    poincare_distance_from_norms(diff_sq, norm_u_sq, norm_v_sq, c)
}
|
||||
|
||||
/// Poincaré distance from pre-computed norms (for batch operations)
|
||||
#[inline]
|
||||
pub fn poincare_distance_from_norms(diff_sq: f32, norm_u_sq: f32, norm_v_sq: f32, c: f32) -> f32 {
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
let lambda_u = (1.0 - c * norm_u_sq).max(EPS);
|
||||
let lambda_v = (1.0 - c * norm_v_sq).max(EPS);
|
||||
|
||||
let numerator = 2.0 * c * diff_sq;
|
||||
let denominator = lambda_u * lambda_v;
|
||||
|
||||
let arg = 1.0 + numerator / denominator;
|
||||
|
||||
if arg <= 1.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Stable acosh computation
|
||||
(1.0 / sqrt_c) * fast_acosh(arg)
|
||||
}
|
||||
|
||||
/// Fast acosh with numerical stability
///
/// Three regimes: a Taylor expansion near 1 (where the naive log formula
/// cancels catastrophically), the standard log formula in the mid range, and
/// ln(2x) for very large inputs (where x² would lose precision/overflow).
#[inline]
fn fast_acosh(x: f32) -> f32 {
    if x <= 1.0 {
        // acosh is undefined below 1; clamp to 0.
        0.0
    } else if x - 1.0 < 1e-4 {
        // Taylor expansion for small delta: acosh(1+δ) ≈ √(2δ)
        (2.0 * (x - 1.0)).sqrt()
    } else if x < 1e6 {
        // Standard formula: acosh(x) = ln(x + √(x²-1))
        (x + (x * x - 1.0).sqrt()).ln()
    } else {
        // For very large x: acosh(x) ≈ ln(2x)
        (2.0 * x).ln()
    }
}
|
||||
|
||||
/// Squared Poincaré distance (faster for comparisons)
///
/// NOTE(review): this still pays the full `acosh` cost of
/// [`poincare_distance`]; it only saves callers a `sqrt` when comparing
/// squared distances — it is not a cheaper distance kernel.
#[inline]
pub fn poincare_distance_squared(u: &[f32], v: &[f32], c: f32) -> f32 {
    let d = poincare_distance(u, v, c);
    d * d
}
|
||||
|
||||
/// Batch distance computation (processes multiple pairs efficiently)
|
||||
pub fn poincare_distance_batch(
|
||||
query: &[f32],
|
||||
points: &[&[f32]],
|
||||
c: f32,
|
||||
) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
let query_norm_sq = norm_squared(query);
|
||||
|
||||
points
|
||||
.iter()
|
||||
.map(|point| {
|
||||
let (diff_sq, _, point_norm_sq) = fused_norms(query, point);
|
||||
poincare_distance_from_norms(diff_sq, query_norm_sq, point_norm_sq, c)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Möbius Operations (Optimized)
|
||||
// ============================================================================
|
||||
|
||||
/// Möbius addition in the Poincaré ball (optimized)
|
||||
///
|
||||
/// x ⊕_c y = ((1 + 2c⟨x,y⟩ + c||y||²)x + (1 - c||x||²)y) / (1 + 2c⟨x,y⟩ + c²||x||²||y||²)
|
||||
#[inline]
|
||||
pub fn mobius_add(x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
|
||||
// Fused computation of norms and dot product
|
||||
let len = x.len().min(y.len());
|
||||
let mut norm_x_sq = 0.0f32;
|
||||
let mut norm_y_sq = 0.0f32;
|
||||
let mut dot_xy = 0.0f32;
|
||||
|
||||
// Process 4 elements at a time
|
||||
let chunks = len / 4;
|
||||
let remainder = len % 4;
|
||||
|
||||
let mut i = 0;
|
||||
for _ in 0..chunks {
|
||||
let (x0, x1, x2, x3) = (x[i], x[i+1], x[i+2], x[i+3]);
|
||||
let (y0, y1, y2, y3) = (y[i], y[i+1], y[i+2], y[i+3]);
|
||||
|
||||
norm_x_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
|
||||
norm_y_sq += y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3;
|
||||
dot_xy += x0 * y0 + x1 * y1 + x2 * y2 + x3 * y3;
|
||||
i += 4;
|
||||
}
|
||||
|
||||
for j in 0..remainder {
|
||||
let xi = x[i + j];
|
||||
let yi = y[i + j];
|
||||
norm_x_sq += xi * xi;
|
||||
norm_y_sq += yi * yi;
|
||||
dot_xy += xi * yi;
|
||||
}
|
||||
|
||||
// Compute coefficients
|
||||
let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq;
|
||||
let coef_y = 1.0 - c * norm_x_sq;
|
||||
let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS);
|
||||
let inv_denom = 1.0 / denom;
|
||||
|
||||
// Compute result
|
||||
let mut result = Vec::with_capacity(len);
|
||||
for j in 0..len {
|
||||
result.push((coef_x * x[j] + coef_y * y[j]) * inv_denom);
|
||||
}
|
||||
|
||||
// Project back into ball
|
||||
project_to_ball_inplace(&mut result, c, EPS);
|
||||
result
|
||||
}
|
||||
|
||||
/// Möbius addition in-place (modifies first argument)
|
||||
#[inline]
|
||||
pub fn mobius_add_inplace(x: &mut [f32], y: &[f32], c: f32) {
|
||||
let c = c.abs().max(EPS);
|
||||
let len = x.len().min(y.len());
|
||||
|
||||
let norm_x_sq = norm_squared(x);
|
||||
let norm_y_sq = norm_squared(y);
|
||||
let dot_xy = dot(x, y);
|
||||
|
||||
let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq;
|
||||
let coef_y = 1.0 - c * norm_x_sq;
|
||||
let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS);
|
||||
let inv_denom = 1.0 / denom;
|
||||
|
||||
for j in 0..len {
|
||||
x[j] = (coef_x * x[j] + coef_y * y[j]) * inv_denom;
|
||||
}
|
||||
|
||||
project_to_ball_inplace(x, c, EPS);
|
||||
}
|
||||
|
||||
/// Möbius scalar multiplication
|
||||
///
|
||||
/// r ⊗_c x = (1/√c) tanh(r · arctanh(√c ||x||)) · (x / ||x||)
|
||||
pub fn mobius_scalar_mult(r: f32, x: &[f32], c: f32) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
let sqrt_c = c.sqrt();
|
||||
let norm_x = norm(x);
|
||||
|
||||
if norm_x < EPS {
|
||||
return x.to_vec();
|
||||
}
|
||||
|
||||
let arctanh_arg = (sqrt_c * norm_x).min(1.0 - EPS);
|
||||
let arctanh_val = arctanh_arg.atanh();
|
||||
let scale = (1.0 / sqrt_c) * (r * arctanh_val).tanh() / norm_x;
|
||||
|
||||
x.iter().map(|&xi| scale * xi).collect()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Exp/Log Maps (Optimized)
|
||||
// ============================================================================
|
||||
|
||||
/// Exponential map at point p
|
||||
///
|
||||
/// exp_p(v) = p ⊕_c (tanh(√c λ_p ||v|| / 2) · v / (√c ||v||))
|
||||
pub fn exp_map(v: &[f32], p: &[f32], c: f32) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
let norm_p_sq = norm_squared(p);
|
||||
let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c);
|
||||
|
||||
let norm_v = norm(v);
|
||||
|
||||
if norm_v < EPS {
|
||||
return p.to_vec();
|
||||
}
|
||||
|
||||
let scaled_norm = sqrt_c * lambda_p * norm_v / 2.0;
|
||||
let coef = scaled_norm.tanh() / (sqrt_c * norm_v);
|
||||
|
||||
let transported: Vec<f32> = v.iter().map(|&vi| coef * vi).collect();
|
||||
|
||||
mobius_add(p, &transported, c)
|
||||
}
|
||||
|
||||
/// Logarithmic map at point p
|
||||
///
|
||||
/// log_p(y) = (2 / (√c λ_p)) arctanh(√c ||−p ⊕_c y||) · (−p ⊕_c y) / ||−p ⊕_c y||
|
||||
pub fn log_map(y: &[f32], p: &[f32], c: f32) -> Vec<f32> {
|
||||
let c = c.abs().max(EPS);
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
// Compute -p ⊕_c y
|
||||
let neg_p: Vec<f32> = p.iter().map(|&pi| -pi).collect();
|
||||
let diff = mobius_add(&neg_p, y, c);
|
||||
let norm_diff = norm(&diff);
|
||||
|
||||
if norm_diff < EPS {
|
||||
return vec![0.0; y.len()];
|
||||
}
|
||||
|
||||
let norm_p_sq = norm_squared(p);
|
||||
let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c);
|
||||
|
||||
let arctanh_arg = (sqrt_c * norm_diff).min(1.0 - EPS);
|
||||
let coef = (2.0 / (sqrt_c * lambda_p)) * arctanh_arg.atanh() / norm_diff;
|
||||
|
||||
diff.iter().map(|&di| coef * di).collect()
|
||||
}
|
||||
|
||||
/// Logarithmic map at a shard centroid for tangent space coordinates
///
/// Thin alias over [`log_map`] that names the intent at shard call sites.
pub fn log_map_at_centroid(x: &[f32], centroid: &[f32], c: f32) -> Vec<f32> {
    log_map(x, centroid, c)
}
|
||||
|
||||
// ============================================================================
|
||||
// Fréchet Mean & Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Compute the Fréchet mean (hyperbolic centroid) of points
///
/// Runs Riemannian gradient descent from a Euclidean-weighted-mean warm
/// start, stepping via exp/log maps until the gradient norm drops below
/// `config.tol` or `config.max_iter` iterations elapse.
///
/// # Errors
/// - `EmptyCollection` when `points` is empty.
/// - `DimensionMismatch` when any point's length differs from the first
///   point's, or when `weights` (if given) does not have one entry per point.
///
/// NOTE(review): caller-supplied weights are used as-is (not normalized to
/// sum to 1); confirm callers pass normalized weights if that is expected.
pub fn frechet_mean(
    points: &[&[f32]],
    weights: Option<&[f32]>,
    config: &PoincareConfig,
) -> HyperbolicResult<Vec<f32>> {
    if points.is_empty() {
        return Err(HyperbolicError::EmptyCollection);
    }

    let dim = points[0].len();
    let c = config.curvature;

    // Validate dimensions
    for p in points.iter() {
        if p.len() != dim {
            return Err(HyperbolicError::DimensionMismatch {
                expected: dim,
                got: p.len(),
            });
        }
    }

    // Set up weights: caller-provided, or uniform 1/n.
    let uniform_weights: Vec<f32>;
    let w = if let Some(weights) = weights {
        if weights.len() != points.len() {
            return Err(HyperbolicError::DimensionMismatch {
                expected: points.len(),
                got: weights.len(),
            });
        }
        weights
    } else {
        uniform_weights = vec![1.0 / points.len() as f32; points.len()];
        &uniform_weights
    };

    // Initialize with Euclidean weighted mean, projected to ball
    let mut mean = vec![0.0; dim];
    for (point, &weight) in points.iter().zip(w) {
        for (i, &val) in point.iter().enumerate() {
            mean[i] += weight * val;
        }
    }
    project_to_ball_inplace(&mut mean, c, config.eps);

    // Riemannian gradient descent with a fixed step size; the gradient
    // buffer is reused across iterations to avoid per-iteration allocation.
    let learning_rate = 0.1;
    let mut grad = vec![0.0; dim];

    for _ in 0..config.max_iter {
        // Reset gradient
        for g in grad.iter_mut() {
            *g = 0.0;
        }

        // Riemannian gradient: weighted sum of log-mapped points at `mean`.
        for (point, &weight) in points.iter().zip(w) {
            let log_result = log_map(point, &mean, c);
            for (i, &val) in log_result.iter().enumerate() {
                grad[i] += weight * val;
            }
        }

        // Check convergence before stepping.
        if norm(&grad) < config.tol {
            break;
        }

        // Update step: move along the gradient direction via the exp map.
        let update: Vec<f32> = grad.iter().map(|&g| learning_rate * g).collect();
        mean = exp_map(&update, &mean, c);
    }

    Ok(mean)
}
|
||||
|
||||
/// Hyperbolic midpoint between two points
|
||||
pub fn hyperbolic_midpoint(x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
|
||||
let log_y = log_map(y, x, c);
|
||||
let half_log: Vec<f32> = log_y.iter().map(|&v| 0.5 * v).collect();
|
||||
exp_map(&half_log, x, c)
|
||||
}
|
||||
|
||||
/// Parallel transport a tangent vector from p to q
///
/// Rescales `v` by the ratio of conformal factors λ_p / λ_q.
///
/// NOTE(review): this is a simplified scaling — exact parallel transport in
/// the Poincaré ball also involves a gyration (rotation) term; confirm the
/// approximation is acceptable for the call sites that use it.
pub fn parallel_transport(v: &[f32], p: &[f32], q: &[f32], c: f32) -> Vec<f32> {
    let c = c.abs().max(EPS);

    let lambda_p = conformal_factor(p, c);
    let lambda_q = conformal_factor(q, c);
    let scale = lambda_p / lambda_q;

    v.iter().map(|&vi| scale * vi).collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Projection must pull an out-of-ball point strictly inside the unit ball.
    #[test]
    fn test_project_to_ball() {
        let x = vec![0.5, 0.5, 0.5];
        let projected = project_to_ball(&x, 1.0, EPS);
        assert!(norm(&projected) < 1.0 - EPS);
    }

    // The origin is the identity element of Möbius addition: x ⊕ 0 = x.
    #[test]
    fn test_mobius_add_identity() {
        let x = vec![0.3, 0.2, 0.1];
        let zero = vec![0.0, 0.0, 0.0];

        let result = mobius_add(&x, &zero, 1.0);
        for (a, b) in x.iter().zip(result.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }

    // log_p must invert exp_p (up to float tolerance).
    #[test]
    fn test_exp_log_inverse() {
        let p = vec![0.1, 0.2, 0.1];
        let v = vec![0.1, -0.1, 0.05];

        let q = exp_map(&v, &p, 1.0);
        let v_recovered = log_map(&q, &p, 1.0);

        for (a, b) in v.iter().zip(v_recovered.iter()) {
            assert!((a - b).abs() < 1e-4);
        }
    }

    // The Poincaré distance is symmetric: d(u, v) == d(v, u).
    #[test]
    fn test_poincare_distance_symmetry() {
        let u = vec![0.3, 0.2];
        let v = vec![-0.1, 0.4];

        let d1 = poincare_distance(&u, &v, 1.0);
        let d2 = poincare_distance(&v, &u, 1.0);

        assert!((d1 - d2).abs() < 1e-6);
    }

    // A point is at distance 0 from itself.
    #[test]
    fn test_poincare_distance_origin() {
        let origin = vec![0.0, 0.0];
        let d = poincare_distance(&origin, &origin, 1.0);
        assert!(d.abs() < 1e-6);
    }

    // The fused single-pass kernel must agree with the separate computations.
    #[test]
    fn test_fused_norms() {
        let u = vec![0.3, 0.2, 0.1];
        let v = vec![0.1, 0.4, 0.2];

        let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(&u, &v);

        let expected_diff_sq: f32 = u.iter().zip(v.iter())
            .map(|(a, b)| (a - b) * (a - b)).sum();
        let expected_norm_u_sq = norm_squared(&u);
        let expected_norm_v_sq = norm_squared(&v);

        assert!((diff_sq - expected_diff_sq).abs() < 1e-6);
        assert!((norm_u_sq - expected_norm_u_sq).abs() < 1e-6);
        assert!((norm_v_sq - expected_norm_v_sq).abs() < 1e-6);
    }
}
|
||||
575
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/shard.rs
vendored
Normal file
575
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/shard.rs
vendored
Normal file
@@ -0,0 +1,575 @@
|
||||
//! Shard Management with Curvature Registry
|
||||
//!
|
||||
//! This module implements per-shard curvature management for hierarchical data.
|
||||
//! Different parts of the hierarchy may have different optimal curvatures.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - Per-shard curvature configuration
|
||||
//! - Hot reload of curvature parameters
|
||||
//! - Canary testing for curvature updates
|
||||
//! - Hierarchy preservation metrics
|
||||
|
||||
use crate::error::{HyperbolicError, HyperbolicResult};
|
||||
use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult};
|
||||
use crate::poincare::{frechet_mean, poincare_distance, project_to_ball, PoincareConfig, EPS};
|
||||
use crate::tangent::TangentCache;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Curvature configuration for a shard
///
/// Holds the active curvature plus optional canary and learned values used
/// by the hot-reload / canary-testing flow in [`CurvatureRegistry`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardCurvature {
    /// Current active curvature
    pub current: f32,
    /// Canary curvature (for testing)
    pub canary: Option<f32>,
    /// Traffic percentage for canary (0-100)
    pub canary_traffic: u8,
    /// Learned curvature from data
    pub learned: Option<f32>,
    /// Last update timestamp (Unix seconds; 0 = never updated)
    pub updated_at: i64,
}
|
||||
|
||||
impl Default for ShardCurvature {
    /// Neutral configuration: unit curvature, no canary, nothing learned,
    /// and a zero ("never updated") timestamp.
    fn default() -> Self {
        Self {
            current: 1.0,
            canary: None,
            canary_traffic: 0,
            learned: None,
            updated_at: 0,
        }
    }
}
|
||||
|
||||
impl ShardCurvature {
|
||||
/// Get the effective curvature (considering canary traffic)
|
||||
pub fn effective(&self, use_canary: bool) -> f32 {
|
||||
if use_canary && self.canary.is_some() && self.canary_traffic > 0 {
|
||||
self.canary.unwrap()
|
||||
} else {
|
||||
self.current
|
||||
}
|
||||
}
|
||||
|
||||
/// Promote canary to current
|
||||
pub fn promote_canary(&mut self) {
|
||||
if let Some(c) = self.canary {
|
||||
self.current = c;
|
||||
self.canary = None;
|
||||
self.canary_traffic = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback canary
|
||||
pub fn rollback_canary(&mut self) {
|
||||
self.canary = None;
|
||||
self.canary_traffic = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Curvature registry for managing per-shard curvatures
///
/// Unknown shard IDs fall back to `default_curvature`; `version` increments
/// on mutations so hot-reload consumers can detect changes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CurvatureRegistry {
    /// Shard curvatures by shard ID
    pub shards: HashMap<String, ShardCurvature>,
    /// Global default curvature
    pub default_curvature: f32,
    /// Registry version (for hot reload)
    pub version: u64,
}
|
||||
|
||||
impl CurvatureRegistry {
|
||||
/// Create a new registry with default curvature
|
||||
pub fn new(default_curvature: f32) -> Self {
|
||||
Self {
|
||||
shards: HashMap::new(),
|
||||
default_curvature,
|
||||
version: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get curvature for a shard
|
||||
pub fn get(&self, shard_id: &str) -> f32 {
|
||||
self.shards
|
||||
.get(shard_id)
|
||||
.map(|s| s.current)
|
||||
.unwrap_or(self.default_curvature)
|
||||
}
|
||||
|
||||
/// Get curvature with canary consideration
|
||||
pub fn get_effective(&self, shard_id: &str, use_canary: bool) -> f32 {
|
||||
self.shards
|
||||
.get(shard_id)
|
||||
.map(|s| s.effective(use_canary))
|
||||
.unwrap_or(self.default_curvature)
|
||||
}
|
||||
|
||||
/// Set curvature for a shard
|
||||
pub fn set(&mut self, shard_id: &str, curvature: f32) {
|
||||
let entry = self.shards.entry(shard_id.to_string()).or_default();
|
||||
entry.current = curvature;
|
||||
entry.updated_at = chrono_timestamp();
|
||||
self.version += 1;
|
||||
}
|
||||
|
||||
/// Set canary curvature
|
||||
pub fn set_canary(&mut self, shard_id: &str, curvature: f32, traffic: u8) {
|
||||
let entry = self.shards.entry(shard_id.to_string()).or_default();
|
||||
entry.canary = Some(curvature);
|
||||
entry.canary_traffic = traffic.min(100);
|
||||
entry.updated_at = chrono_timestamp();
|
||||
self.version += 1;
|
||||
}
|
||||
|
||||
/// Promote all canaries
|
||||
pub fn promote_all_canaries(&mut self) {
|
||||
for (_, shard) in self.shards.iter_mut() {
|
||||
shard.promote_canary();
|
||||
}
|
||||
self.version += 1;
|
||||
}
|
||||
|
||||
/// Rollback all canaries
|
||||
pub fn rollback_all_canaries(&mut self) {
|
||||
for (_, shard) in self.shards.iter_mut() {
|
||||
shard.rollback_canary();
|
||||
}
|
||||
self.version += 1;
|
||||
}
|
||||
|
||||
/// Record learned curvature
|
||||
pub fn set_learned(&mut self, shard_id: &str, curvature: f32) {
|
||||
let entry = self.shards.entry(shard_id.to_string()).or_default();
|
||||
entry.learned = Some(curvature);
|
||||
entry.updated_at = chrono_timestamp();
|
||||
}
|
||||
}
|
||||
|
||||
/// Seconds since the Unix epoch, or 0 if the system clock reads before the epoch.
fn chrono_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
|
||||
|
||||
/// A single shard in the sharded HNSW system
///
/// Wraps one HNSW index plus derived data (tangent cache, centroid) that is
/// rebuilt lazily via `build_cache` and invalidated on insert.
#[derive(Debug)]
pub struct HyperbolicShard {
    /// Shard ID
    pub id: String,
    /// HNSW index for this shard
    pub index: HyperbolicHnsw,
    /// Tangent cache (None until built, or after any insert invalidates it)
    pub tangent_cache: Option<TangentCache>,
    /// Shard centroid (populated when the tangent cache is built)
    pub centroid: Vec<f32>,
    /// Hierarchy depth range (min, max)
    pub depth_range: (usize, usize),
    /// Number of vectors in shard
    pub count: usize,
}
|
||||
|
||||
impl HyperbolicShard {
|
||||
/// Create a new shard
|
||||
pub fn new(id: String, curvature: f32) -> Self {
|
||||
let mut config = HyperbolicHnswConfig::default();
|
||||
config.curvature = curvature;
|
||||
|
||||
Self {
|
||||
id,
|
||||
index: HyperbolicHnsw::new(config),
|
||||
tangent_cache: None,
|
||||
centroid: Vec::new(),
|
||||
depth_range: (0, 0),
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a vector
|
||||
pub fn insert(&mut self, vector: Vec<f32>) -> HyperbolicResult<usize> {
|
||||
let id = self.index.insert(vector)?;
|
||||
self.count += 1;
|
||||
// Invalidate tangent cache
|
||||
self.tangent_cache = None;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Build tangent cache
|
||||
pub fn build_cache(&mut self) -> HyperbolicResult<()> {
|
||||
if self.count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let vectors: Vec<Vec<f32>> = self
|
||||
.index
|
||||
.vectors()
|
||||
.iter()
|
||||
.map(|v| v.to_vec())
|
||||
.collect();
|
||||
let indices: Vec<usize> = (0..vectors.len()).collect();
|
||||
|
||||
self.tangent_cache = Some(TangentCache::new(
|
||||
&vectors,
|
||||
&indices,
|
||||
self.index.config.curvature,
|
||||
)?);
|
||||
|
||||
if let Some(cache) = &self.tangent_cache {
|
||||
self.centroid = cache.centroid.clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search with tangent pruning
|
||||
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<SearchResult>> {
|
||||
self.index.search(query, k)
|
||||
}
|
||||
|
||||
/// Update curvature
|
||||
pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> {
|
||||
self.index.set_curvature(curvature)?;
|
||||
// Rebuild cache with new curvature
|
||||
if self.tangent_cache.is_some() {
|
||||
self.build_cache()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Sharded hyperbolic HNSW manager
///
/// Routes vectors to shards via `strategy`, tracks global→(shard, local) ID
/// mapping in `id_to_shard` (global ID = insertion order), and keeps shard
/// curvatures in sync with `registry`.
#[derive(Debug)]
pub struct ShardedHyperbolicHnsw {
    /// Shards by ID
    pub shards: HashMap<String, HyperbolicShard>,
    /// Curvature registry
    pub registry: CurvatureRegistry,
    /// Global ID to shard mapping: index = global ID, value = (shard ID, local ID)
    pub id_to_shard: Vec<(String, usize)>,
    /// Shard assignment strategy
    pub strategy: ShardStrategy,
}
|
||||
|
||||
/// Strategy for assigning vectors to shards
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardStrategy {
    /// Assign by hash of the vector's float bits
    Hash,
    /// Assign by hierarchy depth (bucketed)
    Depth,
    /// Assign by radius (distance from origin, bucketed)
    Radius,
    /// Round-robin over existing shards
    RoundRobin,
}
|
||||
|
||||
impl Default for ShardStrategy {
    /// Radius-based assignment is the default: in the Poincaré ball the
    /// distance from the origin is the natural hierarchy signal.
    fn default() -> Self {
        Self::Radius
    }
}
|
||||
|
||||
impl ShardedHyperbolicHnsw {
|
||||
/// Create a new sharded manager
|
||||
pub fn new(default_curvature: f32) -> Self {
|
||||
Self {
|
||||
shards: HashMap::new(),
|
||||
registry: CurvatureRegistry::new(default_curvature),
|
||||
id_to_shard: Vec::new(),
|
||||
strategy: ShardStrategy::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create or get a shard
|
||||
pub fn get_or_create_shard(&mut self, shard_id: &str) -> &mut HyperbolicShard {
|
||||
let curvature = self.registry.get(shard_id);
|
||||
self.shards
|
||||
.entry(shard_id.to_string())
|
||||
.or_insert_with(|| HyperbolicShard::new(shard_id.to_string(), curvature))
|
||||
}
|
||||
|
||||
/// Determine shard for a vector
|
||||
pub fn assign_shard(&self, vector: &[f32], depth: Option<usize>) -> String {
|
||||
match self.strategy {
|
||||
ShardStrategy::Hash => {
|
||||
let hash: u64 = vector.iter().fold(0u64, |acc, &v| {
|
||||
acc.wrapping_add((v.to_bits() as u64).wrapping_mul(31))
|
||||
});
|
||||
format!("shard_{}", hash % (self.shards.len().max(1) as u64))
|
||||
}
|
||||
ShardStrategy::Depth => {
|
||||
let d = depth.unwrap_or(0);
|
||||
format!("depth_{}", d / 10) // Group by depth buckets
|
||||
}
|
||||
ShardStrategy::Radius => {
|
||||
let radius: f32 = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
|
||||
let bucket = (radius * 10.0) as usize;
|
||||
format!("radius_{}", bucket)
|
||||
}
|
||||
ShardStrategy::RoundRobin => {
|
||||
let idx = self.id_to_shard.len() % self.shards.len().max(1);
|
||||
self.shards
|
||||
.keys()
|
||||
.nth(idx)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "default".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert vector with automatic shard assignment
|
||||
pub fn insert(&mut self, vector: Vec<f32>, depth: Option<usize>) -> HyperbolicResult<usize> {
|
||||
let shard_id = self.assign_shard(&vector, depth);
|
||||
let shard = self.get_or_create_shard(&shard_id);
|
||||
let local_id = shard.insert(vector)?;
|
||||
|
||||
let global_id = self.id_to_shard.len();
|
||||
self.id_to_shard.push((shard_id, local_id));
|
||||
|
||||
Ok(global_id)
|
||||
}
|
||||
|
||||
/// Insert into specific shard
|
||||
pub fn insert_to_shard(
|
||||
&mut self,
|
||||
shard_id: &str,
|
||||
vector: Vec<f32>,
|
||||
) -> HyperbolicResult<usize> {
|
||||
let shard = self.get_or_create_shard(shard_id);
|
||||
let local_id = shard.insert(vector)?;
|
||||
|
||||
let global_id = self.id_to_shard.len();
|
||||
self.id_to_shard.push((shard_id.to_string(), local_id));
|
||||
|
||||
Ok(global_id)
|
||||
}
|
||||
|
||||
/// Search across all shards
|
||||
pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult<Vec<(usize, SearchResult)>> {
|
||||
let mut all_results: Vec<(usize, SearchResult)> = Vec::new();
|
||||
|
||||
for (shard_id, shard) in &self.shards {
|
||||
let results = shard.search(query, k)?;
|
||||
for result in results {
|
||||
// Map local ID to global ID
|
||||
if let Some((global_id, _)) = self.id_to_shard.iter().enumerate().find(|(_, (s, l))| s == shard_id && *l == result.id) {
|
||||
all_results.push((global_id, result));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by distance and take top k
|
||||
all_results.sort_by(|a, b| a.1.distance.partial_cmp(&b.1.distance).unwrap());
|
||||
all_results.truncate(k);
|
||||
|
||||
Ok(all_results)
|
||||
}
|
||||
|
||||
/// Build all tangent caches
|
||||
pub fn build_caches(&mut self) -> HyperbolicResult<()> {
|
||||
for shard in self.shards.values_mut() {
|
||||
shard.build_cache()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update curvature for a shard
|
||||
pub fn update_curvature(&mut self, shard_id: &str, curvature: f32) -> HyperbolicResult<()> {
|
||||
self.registry.set(shard_id, curvature);
|
||||
if let Some(shard) = self.shards.get_mut(shard_id) {
|
||||
shard.set_curvature(curvature)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Hot reload curvatures from registry
|
||||
pub fn reload_curvatures(&mut self) -> HyperbolicResult<()> {
|
||||
for (shard_id, shard) in self.shards.iter_mut() {
|
||||
let curvature = self.registry.get(shard_id);
|
||||
shard.set_curvature(curvature)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get total vector count
|
||||
pub fn len(&self) -> usize {
|
||||
self.id_to_shard.len()
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.id_to_shard.is_empty()
|
||||
}
|
||||
|
||||
/// Get number of shards
|
||||
pub fn num_shards(&self) -> usize {
|
||||
self.shards.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Metrics for hierarchy preservation
///
/// Quantifies how faithfully an embedding in the Poincaré ball reflects a
/// known hierarchy. `compute` fills in the correlation and distortion
/// fields; the ranking fields need external ground truth and are left at
/// their defaults otherwise.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HierarchyMetrics {
    /// Spearman correlation between radius and depth
    /// (how monotonically a point's norm tracks its hierarchy depth)
    pub radius_depth_correlation: f32,
    /// Average distance distortion
    /// (mean relative error between hyperbolic distance and depth difference)
    pub distance_distortion: f32,
    /// Ancestor preservation (AUPRC); requires ground truth, 0.0 when unset
    pub ancestor_auprc: f32,
    /// Mean rank; requires ground truth, 0.0 when unset
    pub mean_rank: f32,
    /// NDCG scores, keyed by name (presumably cutoff labels such as
    /// "ndcg@10" — populated externally; TODO confirm against callers)
    pub ndcg: HashMap<String, f32>,
}
|
||||
|
||||
impl HierarchyMetrics {
|
||||
/// Compute hierarchy metrics
|
||||
pub fn compute(
|
||||
points: &[Vec<f32>],
|
||||
depths: &[usize],
|
||||
curvature: f32,
|
||||
) -> HyperbolicResult<Self> {
|
||||
if points.is_empty() || points.len() != depths.len() {
|
||||
return Err(HyperbolicError::EmptyCollection);
|
||||
}
|
||||
|
||||
// Compute radii
|
||||
let radii: Vec<f32> = points
|
||||
.iter()
|
||||
.map(|p| p.iter().map(|v| v * v).sum::<f32>().sqrt())
|
||||
.collect();
|
||||
|
||||
// Spearman correlation between radius and depth
|
||||
let radius_depth_correlation = spearman_correlation(&radii, depths);
|
||||
|
||||
// Distance distortion (sample-based for efficiency)
|
||||
let sample_size = points.len().min(100);
|
||||
let mut distortion_sum = 0.0;
|
||||
let mut distortion_count = 0;
|
||||
|
||||
for i in 0..sample_size {
|
||||
for j in (i + 1)..sample_size {
|
||||
let hyp_dist = poincare_distance(&points[i], &points[j], curvature);
|
||||
let depth_diff = (depths[i] as f32 - depths[j] as f32).abs();
|
||||
|
||||
if depth_diff > 0.0 {
|
||||
distortion_sum += (hyp_dist - depth_diff).abs() / depth_diff;
|
||||
distortion_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let distance_distortion = if distortion_count > 0 {
|
||||
distortion_sum / distortion_count as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
radius_depth_correlation,
|
||||
distance_distortion,
|
||||
ancestor_auprc: 0.0, // Requires ground truth
|
||||
mean_rank: 0.0, // Requires ground truth
|
||||
ndcg: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute Spearman rank correlation between values `x` and integer labels `y`.
///
/// Ties receive the conventional average (fractional) rank, so equal values
/// no longer get arbitrary, input-order-dependent ranks that bias the
/// coefficient. Returns 0.0 for empty or mismatched-length inputs, and 0.0
/// when either rank vector has zero variance.
fn spearman_correlation(x: &[f32], y: &[usize]) -> f32 {
    // Assign 0-based ranks to `values`, averaging the ranks of tied runs.
    // NaN-safe: incomparable pairs are ordered as equal instead of panicking.
    fn average_ranks<T: PartialOrd + Copy>(values: &[T]) -> Vec<f32> {
        let n = values.len();
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            values[a]
                .partial_cmp(&values[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut ranks = vec![0.0f32; n];
        let mut i = 0;
        while i < n {
            // Extend `j` to the end of the run of values tied with order[i].
            let mut j = i;
            while j + 1 < n && values[order[j + 1]] == values[order[i]] {
                j += 1;
            }
            // Every member of the tied run shares the mean of positions i..=j.
            let avg = (i + j) as f32 / 2.0;
            for &idx in &order[i..=j] {
                ranks[idx] = avg;
            }
            i = j + 1;
        }
        ranks
    }

    if x.len() != y.len() || x.is_empty() {
        return 0.0;
    }

    let n = x.len();
    let x_ranks = average_ranks(x);
    let y_ranks = average_ranks(y);

    // Pearson correlation of the two rank vectors.
    let mean_x: f32 = x_ranks.iter().sum::<f32>() / n as f32;
    let mean_y: f32 = y_ranks.iter().sum::<f32>() / n as f32;

    let mut cov = 0.0;
    let mut var_x = 0.0;
    let mut var_y = 0.0;

    for i in 0..n {
        let dx = x_ranks[i] - mean_x;
        let dy = y_ranks[i] - mean_y;
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }

    // Constant input (all values tied) has no defined correlation; report 0.
    if var_x == 0.0 || var_y == 0.0 {
        return 0.0;
    }

    cov / (var_x * var_y).sqrt()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The registry serves per-shard overrides, falls back to its default for
    // unknown shards, and only exposes canary curvatures to canary traffic.
    #[test]
    fn test_curvature_registry() {
        let mut registry = CurvatureRegistry::new(1.0);

        registry.set("shard_1", 0.5);
        assert_eq!(registry.get("shard_1"), 0.5);
        assert_eq!(registry.get("shard_2"), 1.0); // Default

        registry.set_canary("shard_1", 0.3, 50);
        assert_eq!(registry.get_effective("shard_1", false), 0.5);
        assert_eq!(registry.get_effective("shard_1", true), 0.3);
    }

    // End-to-end smoke test: insert 20 small vectors spread over 4 shard
    // keys (i / 5), then search and expect a non-empty result set.
    #[test]
    fn test_sharded_hnsw() {
        let mut manager = ShardedHyperbolicHnsw::new(1.0);

        for i in 0..20 {
            let v = vec![0.1 * i as f32, 0.05 * i as f32];
            manager.insert(v, Some(i / 5)).unwrap();
        }

        assert_eq!(manager.len(), 20);

        let query = vec![0.3, 0.15];
        let results = manager.search(&query, 5).unwrap();
        assert!(!results.is_empty());
    }

    // Spearman correlation is +1 for a perfectly monotone increasing pairing
    // and -1 for a perfectly reversed one.
    #[test]
    fn test_spearman() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1, 2, 3, 4, 5];
        let corr = spearman_correlation(&x, &y);
        assert!((corr - 1.0).abs() < 0.01);

        let y_rev = vec![5, 4, 3, 2, 1];
        let corr_rev = spearman_correlation(&x, &y_rev);
        assert!((corr_rev + 1.0).abs() < 0.01);
    }
}
|
||||
348
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/tangent.rs
vendored
Normal file
348
vendor/ruvector/crates/ruvector-hyperbolic-hnsw/src/tangent.rs
vendored
Normal file
@@ -0,0 +1,348 @@
|
||||
//! Tangent Space Operations for HNSW Pruning Optimization
|
||||
//!
|
||||
//! This module implements the key optimization for hyperbolic HNSW:
|
||||
//! - Precompute tangent space coordinates at shard centroids
|
||||
//! - Use cheap Euclidean distance in tangent space for pruning
|
||||
//! - Only compute exact Poincaré distance for final ranking
|
||||
//!
|
||||
//! # HNSW Speed Trick
|
||||
//!
|
||||
//! The core insight is that for points near a centroid c:
|
||||
//! 1. Map points to tangent space: u = log_c(x)
|
||||
//! 2. Euclidean distance ||u_q - u_p|| approximates hyperbolic distance
|
||||
//! 3. Prune candidates using fast Euclidean comparisons
|
||||
//! 4. Rank final top-N candidates with exact Poincaré distance
|
||||
|
||||
use crate::error::{HyperbolicError, HyperbolicResult};
|
||||
use crate::poincare::{
|
||||
conformal_factor, frechet_mean, log_map, norm, norm_squared, poincare_distance,
|
||||
project_to_ball, PoincareConfig, EPS,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Tangent space cache for a shard
///
/// Stores precomputed tangent coordinates (log-map images at the shard
/// centroid) for fast Euclidean pruning before exact Poincaré ranking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TangentCache {
    /// Centroid point (base of tangent space)
    pub centroid: Vec<f32>,
    /// Precomputed tangent coordinates for all points in shard,
    /// parallel to `point_indices`
    pub tangent_coords: Vec<Vec<f32>>,
    /// Original point indices (global ids), parallel to `tangent_coords`
    pub point_indices: Vec<usize>,
    /// Curvature parameter
    pub curvature: f32,
    /// Cached conformal factor at centroid
    /// (computed on construction; not read in this module's visible code)
    conformal: f32,
}
|
||||
|
||||
impl TangentCache {
|
||||
/// Create a new tangent cache for a shard
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `points` - Points in the shard (Poincaré ball coordinates)
|
||||
/// * `indices` - Original indices of the points
|
||||
/// * `curvature` - Curvature parameter
|
||||
pub fn new(points: &[Vec<f32>], indices: &[usize], curvature: f32) -> HyperbolicResult<Self> {
|
||||
if points.is_empty() {
|
||||
return Err(HyperbolicError::EmptyCollection);
|
||||
}
|
||||
|
||||
let config = PoincareConfig::with_curvature(curvature)?;
|
||||
|
||||
// Compute centroid as Fréchet mean
|
||||
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
|
||||
let centroid = frechet_mean(&point_refs, None, &config)?;
|
||||
|
||||
// Precompute tangent coordinates
|
||||
let tangent_coords: Vec<Vec<f32>> = points
|
||||
.iter()
|
||||
.map(|p| log_map(p, ¢roid, curvature))
|
||||
.collect();
|
||||
|
||||
let conformal = conformal_factor(¢roid, curvature);
|
||||
|
||||
Ok(Self {
|
||||
centroid,
|
||||
tangent_coords,
|
||||
point_indices: indices.to_vec(),
|
||||
curvature,
|
||||
conformal,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from centroid directly (for incremental updates)
|
||||
pub fn from_centroid(
|
||||
centroid: Vec<f32>,
|
||||
points: &[Vec<f32>],
|
||||
indices: &[usize],
|
||||
curvature: f32,
|
||||
) -> HyperbolicResult<Self> {
|
||||
let tangent_coords: Vec<Vec<f32>> = points
|
||||
.iter()
|
||||
.map(|p| log_map(p, ¢roid, curvature))
|
||||
.collect();
|
||||
|
||||
let conformal = conformal_factor(¢roid, curvature);
|
||||
|
||||
Ok(Self {
|
||||
centroid,
|
||||
tangent_coords,
|
||||
point_indices: indices.to_vec(),
|
||||
curvature,
|
||||
conformal,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get tangent coordinates for a query point
|
||||
pub fn query_tangent(&self, query: &[f32]) -> Vec<f32> {
|
||||
log_map(query, &self.centroid, self.curvature)
|
||||
}
|
||||
|
||||
/// Fast Euclidean distance in tangent space (for pruning)
|
||||
#[inline]
|
||||
pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 {
|
||||
if idx >= self.tangent_coords.len() {
|
||||
return f32::MAX;
|
||||
}
|
||||
|
||||
let p = &self.tangent_coords[idx];
|
||||
query_tangent
|
||||
.iter()
|
||||
.zip(p.iter())
|
||||
.map(|(&a, &b)| (a - b) * (a - b))
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Exact Poincaré distance for final ranking
|
||||
pub fn exact_distance(&self, query: &[f32], idx: usize, points: &[Vec<f32>]) -> f32 {
|
||||
if idx >= points.len() {
|
||||
return f32::MAX;
|
||||
}
|
||||
poincare_distance(query, &points[idx], self.curvature)
|
||||
}
|
||||
|
||||
/// Add a new point to the cache (for incremental updates)
|
||||
pub fn add_point(&mut self, point: &[f32], index: usize) {
|
||||
let tangent = log_map(point, &self.centroid, self.curvature);
|
||||
self.tangent_coords.push(tangent);
|
||||
self.point_indices.push(index);
|
||||
}
|
||||
|
||||
/// Update centroid and recompute all tangent coordinates
|
||||
pub fn recompute_centroid(&mut self, points: &[Vec<f32>]) -> HyperbolicResult<()> {
|
||||
if points.is_empty() {
|
||||
return Err(HyperbolicError::EmptyCollection);
|
||||
}
|
||||
|
||||
let config = PoincareConfig::with_curvature(self.curvature)?;
|
||||
let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect();
|
||||
self.centroid = frechet_mean(&point_refs, None, &config)?;
|
||||
|
||||
self.tangent_coords = points
|
||||
.iter()
|
||||
.map(|p| log_map(p, &self.centroid, self.curvature))
|
||||
.collect();
|
||||
|
||||
self.conformal = conformal_factor(&self.centroid, self.curvature);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get number of points in cache
|
||||
pub fn len(&self) -> usize {
|
||||
self.tangent_coords.len()
|
||||
}
|
||||
|
||||
/// Check if cache is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tangent_coords.is_empty()
|
||||
}
|
||||
|
||||
/// Get the dimension of the tangent space
|
||||
pub fn dim(&self) -> usize {
|
||||
self.centroid.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tangent space pruning result
///
/// One candidate surviving the tangent-space pruning phase of
/// `TangentPruner::search`.
#[derive(Debug, Clone)]
pub struct PrunedCandidate {
    /// Original (global) index of the candidate point
    pub index: usize,
    /// Squared tangent-space distance (used for the initial, cheap ranking)
    pub tangent_dist: f32,
    /// Exact Poincaré distance (computed lazily; `None` until the candidate
    /// reaches the exact-ranking phase)
    pub exact_dist: Option<f32>,
}
|
||||
|
||||
/// Tangent space pruner for HNSW neighbor selection
///
/// Implements the two-phase search:
/// 1. Fast pruning using squared Euclidean distance in tangent space
/// 2. Exact ranking using Poincaré distance for the surviving
///    `top_n * prune_factor` candidates
pub struct TangentPruner {
    /// Tangent caches, one per shard
    caches: Vec<TangentCache>,
    /// Number of candidates returned from the exact phase
    top_n: usize,
    /// Pruning factor: the exact phase sees `top_n * prune_factor` candidates
    prune_factor: usize,
}
|
||||
|
||||
impl TangentPruner {
|
||||
/// Create a new pruner
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `top_n` - Number of final results
|
||||
/// * `prune_factor` - Multiplier for candidates to consider (e.g., 10 means consider 10*top_n)
|
||||
pub fn new(top_n: usize, prune_factor: usize) -> Self {
|
||||
Self {
|
||||
caches: Vec::new(),
|
||||
top_n,
|
||||
prune_factor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a shard cache
|
||||
pub fn add_cache(&mut self, cache: TangentCache) {
|
||||
self.caches.push(cache);
|
||||
}
|
||||
|
||||
/// Get shard caches
|
||||
pub fn caches(&self) -> &[TangentCache] {
|
||||
&self.caches
|
||||
}
|
||||
|
||||
/// Get mutable shard caches
|
||||
pub fn caches_mut(&mut self) -> &mut [TangentCache] {
|
||||
&mut self.caches
|
||||
}
|
||||
|
||||
/// Search across all shards with tangent pruning
|
||||
///
|
||||
/// Returns top_n candidates sorted by exact Poincaré distance.
|
||||
pub fn search(
|
||||
&self,
|
||||
query: &[f32],
|
||||
points: &[Vec<f32>],
|
||||
curvature: f32,
|
||||
) -> Vec<PrunedCandidate> {
|
||||
let num_prune = self.top_n * self.prune_factor;
|
||||
let mut candidates: Vec<PrunedCandidate> = Vec::with_capacity(num_prune);
|
||||
|
||||
// Phase 1: Tangent space pruning across all shards
|
||||
for cache in &self.caches {
|
||||
let query_tangent = cache.query_tangent(query);
|
||||
|
||||
for (local_idx, &global_idx) in cache.point_indices.iter().enumerate() {
|
||||
let tangent_dist = cache.tangent_distance_squared(&query_tangent, local_idx);
|
||||
candidates.push(PrunedCandidate {
|
||||
index: global_idx,
|
||||
tangent_dist,
|
||||
exact_dist: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by tangent distance and keep top prune_factor * top_n
|
||||
candidates.sort_by(|a, b| a.tangent_dist.partial_cmp(&b.tangent_dist).unwrap());
|
||||
candidates.truncate(num_prune);
|
||||
|
||||
// Phase 2: Exact Poincaré distance for finalists
|
||||
for candidate in &mut candidates {
|
||||
if candidate.index < points.len() {
|
||||
candidate.exact_dist =
|
||||
Some(poincare_distance(query, &points[candidate.index], curvature));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by exact distance and return top_n
|
||||
candidates.sort_by(|a, b| {
|
||||
a.exact_dist
|
||||
.unwrap_or(f32::MAX)
|
||||
.partial_cmp(&b.exact_dist.unwrap_or(f32::MAX))
|
||||
.unwrap()
|
||||
});
|
||||
candidates.truncate(self.top_n);
|
||||
|
||||
candidates
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute micro tangent update for incremental operations
|
||||
///
|
||||
/// For small updates (reflex loop), compute tangent-space delta
|
||||
/// that keeps the point inside the ball.
|
||||
pub fn tangent_micro_update(
|
||||
point: &[f32],
|
||||
delta: &[f32],
|
||||
centroid: &[f32],
|
||||
curvature: f32,
|
||||
max_step: f32,
|
||||
) -> Vec<f32> {
|
||||
// Get current tangent coordinates
|
||||
let tangent = log_map(point, centroid, curvature);
|
||||
|
||||
// Apply bounded delta in tangent space
|
||||
let delta_norm = norm(delta);
|
||||
let scale = if delta_norm > max_step {
|
||||
max_step / delta_norm
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let new_tangent: Vec<f32> = tangent
|
||||
.iter()
|
||||
.zip(delta.iter())
|
||||
.map(|(&t, &d)| t + scale * d)
|
||||
.collect();
|
||||
|
||||
// Map back to ball and project
|
||||
let new_point = crate::poincare::exp_map(&new_tangent, centroid, curvature);
|
||||
project_to_ball(&new_point, curvature, EPS)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Building a cache over 3 points in a 3-D ball: len() counts cached
    // points and dim() reports the tangent-space dimension.
    #[test]
    fn test_tangent_cache_creation() {
        let points = vec![
            vec![0.1, 0.2, 0.1],
            vec![-0.1, 0.15, 0.05],
            vec![0.2, -0.1, 0.1],
        ];
        let indices: Vec<usize> = (0..3).collect();

        let cache = TangentCache::new(&points, &indices, 1.0).unwrap();

        assert_eq!(cache.len(), 3);
        assert_eq!(cache.dim(), 3);
    }

    // Two-phase search: with top_n=2 and prune_factor=2 the pruner keeps at
    // most 4 tangent-phase candidates and returns exactly 2 finalists,
    // ordered by exact Poincaré distance.
    #[test]
    fn test_tangent_pruning() {
        let points = vec![
            vec![0.1, 0.2],
            vec![-0.1, 0.15],
            vec![0.2, -0.1],
            vec![0.05, 0.05],
        ];
        let indices: Vec<usize> = (0..4).collect();

        let cache = TangentCache::new(&points, &indices, 1.0).unwrap();

        let mut pruner = TangentPruner::new(2, 2);
        pruner.add_cache(cache);

        let query = vec![0.08, 0.1];
        let results = pruner.search(&query, &points, 1.0);

        assert_eq!(results.len(), 2);
        // Results should be sorted by exact distance
        assert!(results[0].exact_dist.unwrap() <= results[1].exact_dist.unwrap());
    }
}
|
||||
Reference in New Issue
Block a user